From f7d3efe3db3421c05936d3812917c6add47820f6 Mon Sep 17 00:00:00 2001
From: Dmytro Kozlov <kozlovdmitriyy@gmail.com>
Date: Thu, 9 Mar 2023 15:53:29 +0200
Subject: [PATCH] app/vmctl: add support of basic auth and barer token (#3921)

app/vmctl: add support of basic auth and bearer token
---
 app/vmctl/auth/auth.go     | 222 +++++++++++++++++++++++++++++++++++++
 app/vmctl/flags.go         |  26 +++--
 app/vmctl/main.go          |  36 ++++--
 app/vmctl/native/client.go |  63 +----------
 docs/CHANGELOG.md          |   1 +
 5 files changed, 273 insertions(+), 75 deletions(-)
 create mode 100644 app/vmctl/auth/auth.go

diff --git a/app/vmctl/auth/auth.go b/app/vmctl/auth/auth.go
new file mode 100644
index 000000000..7c1413cce
--- /dev/null
+++ b/app/vmctl/auth/auth.go
@@ -0,0 +1,222 @@
+package auth
+
+import (
+	"encoding/base64"
+	"fmt"
+	"net/http"
+	"strings"
+	"sync"
+
+	"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
+)
+
+// HTTPClientConfig represents http client config.
+type HTTPClientConfig struct {
+	BasicAuth   *BasicAuthConfig
+	BearerToken string
+	Headers     string
+}
+
+// NewConfig creates auth config for the given hcc.
+func (hcc *HTTPClientConfig) NewConfig() (*Config, error) {
+	opts := &Options{
+		BasicAuth:   hcc.BasicAuth,
+		BearerToken: hcc.BearerToken,
+		Headers:     hcc.Headers,
+	}
+	return opts.NewConfig()
+}
+
+// BasicAuthConfig represents basic auth config.
+type BasicAuthConfig struct {
+	Username     string
+	Password     string
+	PasswordFile string
+}
+
+// ConfigOptions options which helps build Config
+type ConfigOptions func(config *HTTPClientConfig)
+
+// Generate returns Config based on the given params
+func Generate(filterOptions ...ConfigOptions) (*Config, error) {
+	authCfg := &HTTPClientConfig{}
+	for _, option := range filterOptions {
+		option(authCfg)
+	}
+
+	return authCfg.NewConfig()
+}
+
+// WithBasicAuth returns AuthConfigOptions and initialized BasicAuthConfig based on given params
+func WithBasicAuth(username, password string) ConfigOptions {
+	return func(config *HTTPClientConfig) {
+		if username != "" || password != "" {
+			config.BasicAuth = &BasicAuthConfig{
+				Username: username,
+				Password: password,
+			}
+		}
+	}
+}
+
+// WithBearer returns AuthConfigOptions and set BearerToken or BearerTokenFile based on given params
+func WithBearer(token string) ConfigOptions {
+	return func(config *HTTPClientConfig) {
+		if token != "" {
+			config.BearerToken = token
+		}
+	}
+}
+
+// WithHeaders returns AuthConfigOptions and set Headers based on the given params
+func WithHeaders(headers string) ConfigOptions {
+	return func(config *HTTPClientConfig) {
+		if headers != "" {
+			config.Headers = headers
+		}
+	}
+}
+
+// Config is auth config.
+type Config struct {
+	getAuthHeader      func() string
+	authHeaderLock     sync.Mutex
+	authHeader         string
+	authHeaderDeadline uint64
+
+	headers []keyValue
+
+	authDigest string
+}
+
+// SetHeaders sets the configured ac headers to req.
+func (ac *Config) SetHeaders(req *http.Request, setAuthHeader bool) {
+	reqHeaders := req.Header
+	for _, h := range ac.headers {
+		reqHeaders.Set(h.key, h.value)
+	}
+	if setAuthHeader {
+		if ah := ac.GetAuthHeader(); ah != "" {
+			reqHeaders.Set("Authorization", ah)
+		}
+	}
+}
+
+// GetAuthHeader returns optional `Authorization: ...` http header.
+func (ac *Config) GetAuthHeader() string {
+	f := ac.getAuthHeader
+	if f == nil {
+		return ""
+	}
+	ac.authHeaderLock.Lock()
+	defer ac.authHeaderLock.Unlock()
+	if fasttime.UnixTimestamp() > ac.authHeaderDeadline {
+		ac.authHeader = f()
+		// Cache the authHeader for a second.
+		ac.authHeaderDeadline = fasttime.UnixTimestamp() + 1
+	}
+	return ac.authHeader
+}
+
+type authContext struct {
+	// getAuthHeader must return <value> for 'Authorization: <value>' http request header
+	getAuthHeader func() string
+
+	// authDigest must contain the digest for the used authorization
+	// The digest must be changed whenever the original config changes.
+	authDigest string
+}
+
+func (ac *authContext) initFromBasicAuthConfig(ba *BasicAuthConfig) error {
+	if ba.Username == "" {
+		return fmt.Errorf("missing `username` in `basic_auth` section")
+	}
+	if ba.Password != "" {
+		ac.getAuthHeader = func() string {
+			// See https://en.wikipedia.org/wiki/Basic_access_authentication
+			token := ba.Username + ":" + ba.Password
+			token64 := base64.StdEncoding.EncodeToString([]byte(token))
+			return "Basic " + token64
+		}
+		ac.authDigest = fmt.Sprintf("basic(username=%q, password=%q)", ba.Username, ba.Password)
+		return nil
+	}
+	return nil
+}
+
+func (ac *authContext) initFromBearerToken(bearerToken string) error {
+	ac.getAuthHeader = func() string {
+		return "Bearer " + bearerToken
+	}
+	ac.authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
+	return nil
+}
+
+// Options contain options, which must be passed to NewConfig.
+type Options struct {
+	// BasicAuth contains optional BasicAuthConfig.
+	BasicAuth *BasicAuthConfig
+
+	// BearerToken contains optional bearer token.
+	BearerToken string
+
+	// Headers contains optional http request headers in the form 'Foo: bar'.
+	Headers string
+}
+
+// NewConfig creates auth config from the given opts.
+func (opts *Options) NewConfig() (*Config, error) {
+	var ac authContext
+	if opts.BasicAuth != nil {
+		if ac.getAuthHeader != nil {
+			return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`")
+		}
+		if err := ac.initFromBasicAuthConfig(opts.BasicAuth); err != nil {
+			return nil, err
+		}
+	}
+	if opts.BearerToken != "" {
+		if ac.getAuthHeader != nil {
+			return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
+		}
+		if err := ac.initFromBearerToken(opts.BearerToken); err != nil {
+			return nil, err
+		}
+	}
+
+	headers, err := parseHeaders(opts.Headers)
+	if err != nil {
+		return nil, err
+	}
+	c := &Config{
+		getAuthHeader: ac.getAuthHeader,
+		headers:       headers,
+		authDigest:    ac.authDigest,
+	}
+	return c, nil
+}
+
+type keyValue struct {
+	key   string
+	value string
+}
+
+func parseHeaders(headers string) ([]keyValue, error) {
+	if len(headers) == 0 {
+		return nil, nil
+	}
+
+	var headersSplitByDelimiter = strings.Split(headers, "^^")
+
+	kvs := make([]keyValue, len(headersSplitByDelimiter))
+	for i, h := range headersSplitByDelimiter {
+		n := strings.IndexByte(h, ':')
+		if n < 0 {
+			return nil, fmt.Errorf(`missing ':' in header %q; expecting "key: value" format`, h)
+		}
+		kv := &kvs[i]
+		kv.key = strings.TrimSpace(h[:n])
+		kv.value = strings.TrimSpace(h[n+1:])
+	}
+	return kvs, nil
+}
diff --git a/app/vmctl/flags.go b/app/vmctl/flags.go
index 1fcc3033e..5eae96c9c 100644
--- a/app/vmctl/flags.go
+++ b/app/vmctl/flags.go
@@ -327,15 +327,17 @@ const (
 
 	vmNativeDisableHTTPKeepAlive = "vm-native-disable-http-keep-alive"
 
-	vmNativeSrcAddr     = "vm-native-src-addr"
-	vmNativeSrcUser     = "vm-native-src-user"
-	vmNativeSrcPassword = "vm-native-src-password"
-	vmNativeSrcHeaders  = "vm-native-src-headers"
+	vmNativeSrcAddr        = "vm-native-src-addr"
+	vmNativeSrcUser        = "vm-native-src-user"
+	vmNativeSrcPassword    = "vm-native-src-password"
+	vmNativeSrcHeaders     = "vm-native-src-headers"
+	vmNativeSrcBearerToken = "vm-native-src-bearer-token"
 
-	vmNativeDstAddr     = "vm-native-dst-addr"
-	vmNativeDstUser     = "vm-native-dst-user"
-	vmNativeDstPassword = "vm-native-dst-password"
-	vmNativeDstHeaders  = "vm-native-dst-headers"
+	vmNativeDstAddr        = "vm-native-dst-addr"
+	vmNativeDstUser        = "vm-native-dst-user"
+	vmNativeDstPassword    = "vm-native-dst-password"
+	vmNativeDstHeaders     = "vm-native-dst-headers"
+	vmNativeDstBearerToken = "vm-native-dst-bearer-token"
 )
 
 var (
@@ -388,6 +390,10 @@ var (
 				"For example, --vm-native-src-headers='My-Auth:foobar' would send 'My-Auth: foobar' HTTP header with every request to the corresponding source address. \n" +
 				"Multiple headers must be delimited by '^^': --vm-native-src-headers='header1:value1^^header2:value2'",
 		},
+		&cli.StringFlag{
+			Name:  vmNativeSrcBearerToken,
+			Usage: "Optional bearer auth token to use for the corresponding `--vm-native-src-addr`",
+		},
 		&cli.StringFlag{
 			Name: vmNativeDstAddr,
 			Usage: "VictoriaMetrics address to perform import to. \n" +
@@ -411,6 +417,10 @@ var (
 				"For example, --vm-native-dst-headers='My-Auth:foobar' would send 'My-Auth: foobar' HTTP header with every request to the corresponding destination address. \n" +
 				"Multiple headers must be delimited by '^^': --vm-native-dst-headers='header1:value1^^header2:value2'",
 		},
+		&cli.StringFlag{
+			Name:  vmNativeDstBearerToken,
+			Usage: "Optional bearer auth token to use for the corresponding `--vm-native-dst-addr`",
+		},
 		&cli.StringSliceFlag{
 			Name:  vmExtraLabel,
 			Value: nil,
diff --git a/app/vmctl/main.go b/app/vmctl/main.go
index 98db814b1..6fc1c2dd9 100644
--- a/app/vmctl/main.go
+++ b/app/vmctl/main.go
@@ -11,6 +11,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/VictoriaMetrics/VictoriaMetrics/app/vmctl/auth"
 	"github.com/VictoriaMetrics/VictoriaMetrics/app/vmctl/backoff"
 	"github.com/VictoriaMetrics/VictoriaMetrics/app/vmctl/native"
 	"github.com/VictoriaMetrics/VictoriaMetrics/app/vmctl/remoteread"
@@ -199,6 +200,26 @@ func main() {
 						return fmt.Errorf("flag %q can't be empty", vmNativeFilterMatch)
 					}
 
+					var srcExtraLabels []string
+					srcAddr := strings.Trim(c.String(vmNativeSrcAddr), "/")
+					srcAuthConfig, err := auth.Generate(
+						auth.WithBasicAuth(c.String(vmNativeSrcUser), c.String(vmNativeSrcPassword)),
+						auth.WithBearer(c.String(vmNativeSrcBearerToken)),
+						auth.WithHeaders(c.String(vmNativeSrcHeaders)))
+					if err != nil {
+						return fmt.Errorf("error initilize auth config for source: %s", srcAddr)
+					}
+
+					dstAddr := strings.Trim(c.String(vmNativeDstAddr), "/")
+					dstExtraLabels := c.StringSlice(vmExtraLabel)
+					dstAuthConfig, err := auth.Generate(
+						auth.WithBasicAuth(c.String(vmNativeDstUser), c.String(vmNativeDstPassword)),
+						auth.WithBearer(c.String(vmNativeDstBearerToken)),
+						auth.WithHeaders(c.String(vmNativeDstHeaders)))
+					if err != nil {
+						return fmt.Errorf("error initilize auth config for destination: %s", dstAddr)
+					}
+
 					p := vmNativeProcessor{
 						rateLimit:    c.Int64(vmRateLimit),
 						interCluster: c.Bool(vmInterCluster),
@@ -209,18 +230,15 @@ func main() {
 							Chunk:     c.String(vmNativeStepInterval),
 						},
 						src: &native.Client{
-							Addr:                 strings.Trim(c.String(vmNativeSrcAddr), "/"),
-							User:                 c.String(vmNativeSrcUser),
-							Password:             c.String(vmNativeSrcPassword),
-							Headers:              c.String(vmNativeSrcHeaders),
+							AuthCfg:              srcAuthConfig,
+							Addr:                 srcAddr,
+							ExtraLabels:          srcExtraLabels,
 							DisableHTTPKeepAlive: c.Bool(vmNativeDisableHTTPKeepAlive),
 						},
 						dst: &native.Client{
-							Addr:                 strings.Trim(c.String(vmNativeDstAddr), "/"),
-							User:                 c.String(vmNativeDstUser),
-							Password:             c.String(vmNativeDstPassword),
-							ExtraLabels:          c.StringSlice(vmExtraLabel),
-							Headers:              c.String(vmNativeDstHeaders),
+							AuthCfg:              dstAuthConfig,
+							Addr:                 dstAddr,
+							ExtraLabels:          dstExtraLabels,
 							DisableHTTPKeepAlive: c.Bool(vmNativeDisableHTTPKeepAlive),
 						},
 						backoff: backoff.New(),
diff --git a/app/vmctl/native/client.go b/app/vmctl/native/client.go
index 474dbe81f..1b9f779a7 100644
--- a/app/vmctl/native/client.go
+++ b/app/vmctl/native/client.go
@@ -6,7 +6,8 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"strings"
+
+	"github.com/VictoriaMetrics/VictoriaMetrics/app/vmctl/auth"
 )
 
 const (
@@ -18,11 +19,9 @@ const (
 // Client is an HTTP client for exporting and importing
 // time series via native protocol.
 type Client struct {
+	AuthCfg              *auth.Config
 	Addr                 string
-	User                 string
-	Password             string
 	ExtraLabels          []string
-	Headers              string
 	DisableHTTPKeepAlive bool
 }
 
@@ -93,15 +92,6 @@ func (c *Client) ImportPipe(ctx context.Context, dstURL string, pr *io.PipeReade
 		return fmt.Errorf("cannot create import request to %q: %s", c.Addr, err)
 	}
 
-	parsedHeaders, err := parseHeaders(c.Headers)
-	if err != nil {
-		return err
-	}
-
-	for _, header := range parsedHeaders {
-		req.Header.Set(header.key, header.value)
-	}
-
 	importResp, err := c.do(req, http.StatusNoContent)
 	if err != nil {
 		return fmt.Errorf("import request failed: %s", err)
@@ -132,15 +122,6 @@ func (c *Client) ExportPipe(ctx context.Context, url string, f Filter) (io.ReadC
 	// disable compression since it is meaningless for native format
 	req.Header.Set("Accept-Encoding", "identity")
 
-	parsedHeaders, err := parseHeaders(c.Headers)
-	if err != nil {
-		return nil, err
-	}
-
-	for _, header := range parsedHeaders {
-		req.Header.Set(header.key, header.value)
-	}
-
 	resp, err := c.do(req, http.StatusOK)
 	if err != nil {
 		return nil, fmt.Errorf("export request failed: %w", err)
@@ -165,15 +146,6 @@ func (c *Client) GetSourceTenants(ctx context.Context, f Filter) ([]string, erro
 	}
 	req.URL.RawQuery = params.Encode()
 
-	parsedHeaders, err := parseHeaders(c.Headers)
-	if err != nil {
-		return nil, err
-	}
-
-	for _, header := range parsedHeaders {
-		req.Header.Set(header.key, header.value)
-	}
-
 	resp, err := c.do(req, http.StatusOK)
 	if err != nil {
 		return nil, fmt.Errorf("tenants request failed: %s", err)
@@ -194,8 +166,8 @@ func (c *Client) GetSourceTenants(ctx context.Context, f Filter) ([]string, erro
 }
 
 func (c *Client) do(req *http.Request, expSC int) (*http.Response, error) {
-	if c.User != "" {
-		req.SetBasicAuth(c.User, c.Password)
+	if c.AuthCfg != nil {
+		c.AuthCfg.SetHeaders(req, true)
 	}
 	var httpClient = &http.Client{Transport: &http.Transport{DisableKeepAlives: c.DisableHTTPKeepAlive}}
 	resp, err := httpClient.Do(req)
@@ -212,28 +184,3 @@ func (c *Client) do(req *http.Request, expSC int) (*http.Response, error) {
 	}
 	return resp, err
 }
-
-type keyValue struct {
-	key   string
-	value string
-}
-
-func parseHeaders(headers string) ([]keyValue, error) {
-	if len(headers) == 0 {
-		return nil, nil
-	}
-
-	var headersSplitByDelimiter = strings.Split(headers, "^^")
-
-	kvs := make([]keyValue, len(headersSplitByDelimiter))
-	for i, h := range headersSplitByDelimiter {
-		n := strings.IndexByte(h, ':')
-		if n < 0 {
-			return nil, fmt.Errorf(`missing ':' in header %q; expecting "key: value" format`, h)
-		}
-		kv := &kvs[i]
-		kv.key = strings.TrimSpace(h[:n])
-		kv.value = strings.TrimSpace(h[n+1:])
-	}
-	return kvs, nil
-}
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index 63dab58b1..c160e5e5c 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -22,6 +22,7 @@ The following tip changes can be tested by building VictoriaMetrics components f
 * FEATURE: [vmctl](https://docs.victoriametrics.com/vmctl.html): add `--vm-native-src-headers` and `--vm-native-dst-headers` command-line flags, which can be used for setting custom HTTP headers during [vm-native migration mode](https://docs.victoriametrics.com/vmctl.html#native-protocol). Thanks to @baconmania for [the pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/3906).
 * FEATURE: [vmalert](https://docs.victoriametrics.com/vmalert.html): log number of configration files found for each specified `-rule` command-line flag. 
 * FEATURE: [vmctl](https://docs.victoriametrics.com/vmctl.html): add `--vm-native-disable-http-keep-alive` command-line flags to allow `vmctl` to use non-persistent HTTP connections in [vm-native migration mode](https://docs.victoriametrics.com/vmctl.html#native-protocol). Thanks to @baconmania for [the pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/3909).
+* FEATURE: [vmctl](https://docs.victoriametrics.com/vmctl.html): add `--vm-native-src-bearer-token` and `--vm-native-dst-bearer-token` command-line flags, which can be used for setting custom HTTP headers during [vm-native migration mode](https://docs.victoriametrics.com/vmctl.html#native-protocol). See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/3835).
 
 * BUGFIX: [vmagent](https://docs.victoriametrics.com/vmagent.html): fix panic when [writing data to Kafka](https://docs.victoriametrics.com/vmagent.html#writing-metrics-to-kafka). The panic has been introduced in [v1.88.0](https://docs.victoriametrics.com/CHANGELOG.html#v1880).
 * BUGFIX: prevent from possible `invalid memory address or nil pointer dereference` panic during [background merge](https://docs.victoriametrics.com/#storage). The issue has been introduced at [v1.85.0](https://docs.victoriametrics.com/CHANGELOG.html#v1850). See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/3897).