lib/promscrape: reload auth tokens from files every second

Previously auth tokens were loaded at startup and couldn't be updated without vmagent restart.
Now there is no need in vmagent restart.

Updates https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1297
This commit is contained in:
Aliaksandr Valialkin 2021-05-14 20:00:05 +03:00
parent a15145e597
commit 733706e6c6
7 changed files with 142 additions and 118 deletions

View file

@ -5,6 +5,7 @@ sort: 15
# CHANGELOG # CHANGELOG
* FEATURE: vminsert: add support for data ingestion via other `vminsert` nodes. This allows building multi-level data ingestion paths in VictoriaMetrics cluster by writing data from one level of `vminsert` nodes to another level of `vminsert` nodes. See [these docs](https://docs.victoriametrics.com/Cluster-VictoriaMetrics.html#multi-level-cluster-setup) and [this comment](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/541#issuecomment-835487858) for details. * FEATURE: vminsert: add support for data ingestion via other `vminsert` nodes. This allows building multi-level data ingestion paths in VictoriaMetrics cluster by writing data from one level of `vminsert` nodes to another level of `vminsert` nodes. See [these docs](https://docs.victoriametrics.com/Cluster-VictoriaMetrics.html#multi-level-cluster-setup) and [this comment](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/541#issuecomment-835487858) for details.
* FEATURE: vmagent: reload `bearer_token_file`, `credentials_file` and `password_file` contents every second. This allows dynamically changing the contents of these files during target scraping and service discovery without the need to restart `vmagent`. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1297).
* FEATURE: vmalert: add flag to control behaviour on startup for state restore errors. Such errors were returned and logged before as well. But now user can specify whether to just log these errors (`-remoteRead.ignoreRestoreErrors=true`) or to stop the process (`-remoteRead.ignoreRestoreErrors=false`). The latter is important when VM isn't ready yet to serve queries from vmalert and it needs to wait. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1252). * FEATURE: vmalert: add flag to control behaviour on startup for state restore errors. Such errors were returned and logged before as well. But now user can specify whether to just log these errors (`-remoteRead.ignoreRestoreErrors=true`) or to stop the process (`-remoteRead.ignoreRestoreErrors=false`). The latter is important when VM isn't ready yet to serve queries from vmalert and it needs to wait. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1252).
* FEATURE: vmalert: add ability to pass `round_digits` query arg to datasource via `-datasource.roundDigits` command-line flag. This can be used for limiting the number of decimal digits after the point in recording rule results. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/525). * FEATURE: vmalert: add ability to pass `round_digits` query arg to datasource via `-datasource.roundDigits` command-line flag. This can be used for limiting the number of decimal digits after the point in recording rule results. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/525).
* FEATURE: return `X-Server-Hostname` header in http responses of all the VictoriaMetrics components. This should simplify tracing the origin server behind a load balancer or behind auth proxy during troubleshooting. * FEATURE: return `X-Server-Hostname` header in http responses of all the VictoriaMetrics components. This should simplify tracing the origin server behind a load balancer or behind auth proxy during troubleshooting.

View file

@ -7,6 +7,10 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"sync"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
) )
// TLSConfig represents TLS config. // TLSConfig represents TLS config.
@ -56,22 +60,40 @@ type ProxyClientConfig struct {
// Config is auth config. // Config is auth config.
type Config struct { type Config struct {
// Optional `Authorization` header.
//
// It may contain `Basic ....` or `Bearer ....` string.
Authorization string
// Optional TLS config // Optional TLS config
TLSRootCA *x509.CertPool TLSRootCA *x509.CertPool
TLSCertificate *tls.Certificate TLSCertificate *tls.Certificate
TLSServerName string TLSServerName string
TLSInsecureSkipVerify bool TLSInsecureSkipVerify bool
getAuthHeader func() string
authHeaderLock sync.Mutex
authHeader string
authHeaderDeadline uint64
authDigest string
} }
// String returns human-(un)readable representation for cfg. // GetAuthHeader returns optional `Authorization: ...` http header.
func (ac *Config) GetAuthHeader() string {
f := ac.getAuthHeader
if f == nil {
return ""
}
ac.authHeaderLock.Lock()
defer ac.authHeaderLock.Unlock()
if fasttime.UnixTimestamp() > ac.authHeaderDeadline {
ac.authHeader = f()
// Cache the authHeader for a second.
ac.authHeaderDeadline = fasttime.UnixTimestamp() + 1
}
return ac.authHeader
}
// String returns human-readable representation for ac.
func (ac *Config) String() string { func (ac *Config) String() string {
return fmt.Sprintf("Authorization=%s, TLSRootCA=%s, TLSCertificate=%s, TLSServerName=%s, TLSInsecureSkipVerify=%v", return fmt.Sprintf("AuthDigest=%s, TLSRootCA=%s, TLSCertificate=%s, TLSServerName=%s, TLSInsecureSkipVerify=%v",
ac.Authorization, ac.tlsRootCAString(), ac.tlsCertificateString(), ac.TLSServerName, ac.TLSInsecureSkipVerify) ac.authDigest, ac.tlsRootCAString(), ac.tlsCertificateString(), ac.TLSServerName, ac.TLSInsecureSkipVerify)
} }
func (ac *Config) tlsRootCAString() string { func (ac *Config) tlsRootCAString() string {
@ -119,70 +141,94 @@ func (pcc *ProxyClientConfig) NewConfig(baseDir string) (*Config, error) {
// NewConfig creates auth config from the given args. // NewConfig creates auth config from the given args.
func NewConfig(baseDir string, az *Authorization, basicAuth *BasicAuthConfig, bearerToken, bearerTokenFile string, tlsConfig *TLSConfig) (*Config, error) { func NewConfig(baseDir string, az *Authorization, basicAuth *BasicAuthConfig, bearerToken, bearerTokenFile string, tlsConfig *TLSConfig) (*Config, error) {
var authorization string var getAuthHeader func() string
authDigest := ""
if az != nil { if az != nil {
azType := "Bearer" azType := "Bearer"
if az.Type != "" { if az.Type != "" {
azType = az.Type azType = az.Type
} }
azToken := az.Credentials
if az.CredentialsFile != "" { if az.CredentialsFile != "" {
if az.Credentials != "" { if az.Credentials != "" {
return nil, fmt.Errorf("both `credentials`=%q and `credentials_file`=%q are set", az.Credentials, az.CredentialsFile) return nil, fmt.Errorf("both `credentials`=%q and `credentials_file`=%q are set", az.Credentials, az.CredentialsFile)
} }
path := getFilepath(baseDir, az.CredentialsFile) filePath := getFilepath(baseDir, az.CredentialsFile)
token, err := readPasswordFromFile(path) getAuthHeader = func() string {
if err != nil { token, err := readPasswordFromFile(filePath)
return nil, fmt.Errorf("cannot read credentials from `credentials_file`=%q: %w", az.CredentialsFile, err) if err != nil {
logger.Errorf("cannot read credentials from `credentials_file`=%q: %s", az.CredentialsFile, err)
return ""
}
return azType + " " + token
} }
azToken = token authDigest = fmt.Sprintf("custom(type=%q, credsFile=%q)", az.Type, filePath)
} else {
getAuthHeader = func() string {
return azType + " " + az.Credentials
}
authDigest = fmt.Sprintf("custom(type=%q, creds=%q)", az.Type, az.Credentials)
} }
authorization = azType + " " + azToken
} }
if basicAuth != nil { if basicAuth != nil {
if authorization != "" { if getAuthHeader != nil {
return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`") return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`")
} }
if basicAuth.Username == "" { if basicAuth.Username == "" {
return nil, fmt.Errorf("missing `username` in `basic_auth` section") return nil, fmt.Errorf("missing `username` in `basic_auth` section")
} }
username := basicAuth.Username
password := basicAuth.Password
if basicAuth.PasswordFile != "" { if basicAuth.PasswordFile != "" {
if basicAuth.Password != "" { if basicAuth.Password != "" {
return nil, fmt.Errorf("both `password`=%q and `password_file`=%q are set in `basic_auth` section", basicAuth.Password, basicAuth.PasswordFile) return nil, fmt.Errorf("both `password`=%q and `password_file`=%q are set in `basic_auth` section", basicAuth.Password, basicAuth.PasswordFile)
} }
path := getFilepath(baseDir, basicAuth.PasswordFile) filePath := getFilepath(baseDir, basicAuth.PasswordFile)
pass, err := readPasswordFromFile(path) getAuthHeader = func() string {
if err != nil { password, err := readPasswordFromFile(filePath)
return nil, fmt.Errorf("cannot read password from `password_file`=%q set in `basic_auth` section: %w", basicAuth.PasswordFile, err) if err != nil {
logger.Errorf("cannot read password from `password_file`=%q set in `basic_auth` section: %s", basicAuth.PasswordFile, err)
return ""
}
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := basicAuth.Username + ":" + password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
} }
password = pass authDigest = fmt.Sprintf("basic(username=%q, passwordFile=%q)", basicAuth.Username, filePath)
} else {
getAuthHeader = func() string {
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := basicAuth.Username + ":" + basicAuth.Password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
authDigest = fmt.Sprintf("basic(username=%q, password=%q)", basicAuth.Username, basicAuth.Password)
} }
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := username + ":" + password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
authorization = "Basic " + token64
} }
if bearerTokenFile != "" { if bearerTokenFile != "" {
if authorization != "" { if getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token_file`") return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token_file`")
} }
if bearerToken != "" { if bearerToken != "" {
return nil, fmt.Errorf("both `bearer_token`=%q and `bearer_token_file`=%q are set", bearerToken, bearerTokenFile) return nil, fmt.Errorf("both `bearer_token`=%q and `bearer_token_file`=%q are set", bearerToken, bearerTokenFile)
} }
path := getFilepath(baseDir, bearerTokenFile) filePath := getFilepath(baseDir, bearerTokenFile)
token, err := readPasswordFromFile(path) getAuthHeader = func() string {
if err != nil { token, err := readPasswordFromFile(filePath)
return nil, fmt.Errorf("cannot read bearer token from `bearer_token_file`=%q: %w", bearerTokenFile, err) if err != nil {
logger.Errorf("cannot read bearer token from `bearer_token_file`=%q: %s", bearerTokenFile, err)
return ""
}
return "Bearer " + token
} }
authorization = "Bearer " + token authDigest = fmt.Sprintf("bearer(tokenFile=%q)", filePath)
} }
if bearerToken != "" { if bearerToken != "" {
if authorization != "" { if getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`") return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
} }
authorization = "Bearer " + bearerToken getAuthHeader = func() string {
return "Bearer " + bearerToken
}
authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
} }
var tlsRootCA *x509.CertPool var tlsRootCA *x509.CertPool
var tlsCertificate *tls.Certificate var tlsCertificate *tls.Certificate
@ -213,11 +259,13 @@ func NewConfig(baseDir string, az *Authorization, basicAuth *BasicAuthConfig, be
} }
} }
ac := &Config{ ac := &Config{
Authorization: authorization,
TLSRootCA: tlsRootCA, TLSRootCA: tlsRootCA,
TLSCertificate: tlsCertificate, TLSCertificate: tlsCertificate,
TLSServerName: tlsServerName, TLSServerName: tlsServerName,
TLSInsecureSkipVerify: tlsInsecureSkipVerify, TLSInsecureSkipVerify: tlsInsecureSkipVerify,
getAuthHeader: getAuthHeader,
authDigest: authDigest,
} }
return ac, nil return ac, nil
} }

View file

@ -47,8 +47,8 @@ type client struct {
scrapeTimeoutSecondsStr string scrapeTimeoutSecondsStr string
host string host string
requestURI string requestURI string
authHeader string getAuthHeader func() string
proxyAuthHeader string getProxyAuthHeader func() string
denyRedirects bool denyRedirects bool
disableCompression bool disableCompression bool
disableKeepAlive bool disableKeepAlive bool
@ -64,7 +64,7 @@ func newClient(sw *ScrapeWork) *client {
if isTLS { if isTLS {
tlsCfg = sw.AuthConfig.NewTLSConfig() tlsCfg = sw.AuthConfig.NewTLSConfig()
} }
proxyAuthHeader := "" getProxyAuthHeader := func() string { return "" }
proxyURL := sw.ProxyURL proxyURL := sw.ProxyURL
if !isTLS && proxyURL.IsHTTPOrHTTPS() { if !isTLS && proxyURL.IsHTTPOrHTTPS() {
// Send full sw.ScrapeURL in requests to a proxy host for non-TLS scrape targets // Send full sw.ScrapeURL in requests to a proxy host for non-TLS scrape targets
@ -77,7 +77,9 @@ func newClient(sw *ScrapeWork) *client {
if isTLS { if isTLS {
tlsCfg = sw.ProxyAuthConfig.NewTLSConfig() tlsCfg = sw.ProxyAuthConfig.NewTLSConfig()
} }
proxyAuthHeader = proxyURL.GetAuthHeader(sw.ProxyAuthConfig) getProxyAuthHeader = func() string {
return proxyURL.GetAuthHeader(sw.ProxyAuthConfig)
}
proxyURL = proxy.URL{} proxyURL = proxy.URL{}
} }
if !strings.Contains(host, ":") { if !strings.Contains(host, ":") {
@ -144,8 +146,8 @@ func newClient(sw *ScrapeWork) *client {
scrapeTimeoutSecondsStr: fmt.Sprintf("%.3f", sw.ScrapeTimeout.Seconds()), scrapeTimeoutSecondsStr: fmt.Sprintf("%.3f", sw.ScrapeTimeout.Seconds()),
host: host, host: host,
requestURI: requestURI, requestURI: requestURI,
authHeader: sw.AuthConfig.Authorization, getAuthHeader: sw.AuthConfig.GetAuthHeader,
proxyAuthHeader: proxyAuthHeader, getProxyAuthHeader: getProxyAuthHeader,
denyRedirects: sw.DenyRedirects, denyRedirects: sw.DenyRedirects,
disableCompression: sw.DisableCompression, disableCompression: sw.DisableCompression,
disableKeepAlive: sw.DisableKeepAlive, disableKeepAlive: sw.DisableKeepAlive,
@ -169,11 +171,11 @@ func (c *client) GetStreamReader() (*streamReader, error) {
// Set X-Prometheus-Scrape-Timeout-Seconds like Prometheus does, since it is used by some exporters such as PushProx. // Set X-Prometheus-Scrape-Timeout-Seconds like Prometheus does, since it is used by some exporters such as PushProx.
// See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1179#issuecomment-813117162 // See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1179#issuecomment-813117162
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", c.scrapeTimeoutSecondsStr) req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", c.scrapeTimeoutSecondsStr)
if c.authHeader != "" { if ah := c.getAuthHeader(); ah != "" {
req.Header.Set("Authorization", c.authHeader) req.Header.Set("Authorization", ah)
} }
if c.proxyAuthHeader != "" { if ah := c.getProxyAuthHeader(); ah != "" {
req.Header.Set("Proxy-Authorization", c.proxyAuthHeader) req.Header.Set("Proxy-Authorization", ah)
} }
resp, err := c.sc.Do(req) resp, err := c.sc.Do(req)
if err != nil { if err != nil {
@ -209,11 +211,11 @@ func (c *client) ReadData(dst []byte) ([]byte, error) {
// Set X-Prometheus-Scrape-Timeout-Seconds like Prometheus does, since it is used by some exporters such as PushProx. // Set X-Prometheus-Scrape-Timeout-Seconds like Prometheus does, since it is used by some exporters such as PushProx.
// See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1179#issuecomment-813117162 // See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1179#issuecomment-813117162
req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", c.scrapeTimeoutSecondsStr) req.Header.Set("X-Prometheus-Scrape-Timeout-Seconds", c.scrapeTimeoutSecondsStr)
if c.authHeader != "" { if ah := c.getAuthHeader(); ah != "" {
req.Header.Set("Authorization", c.authHeader) req.Header.Set("Authorization", ah)
} }
if c.proxyAuthHeader != "" { if ah := c.getProxyAuthHeader(); ah != "" {
req.Header.Set("Proxy-Authorization", c.proxyAuthHeader) req.Header.Set("Proxy-Authorization", ah)
} }
if !*disableCompression && !c.disableCompression { if !*disableCompression && !c.disableCompression {
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")

View file

@ -301,7 +301,7 @@ scrape_configs:
- job_name: x - job_name: x
basic_auth: basic_auth:
username: foobar username: foobar
password_file: /non_existing_file.pass password_file: ['foobar']
static_configs: static_configs:
- targets: ["a"] - targets: ["a"]
`) `)
@ -355,7 +355,7 @@ scrape_configs:
f(` f(`
scrape_configs: scrape_configs:
- job_name: x - job_name: x
bearer_token_file: non_existing_file.bearer bearer_token_file: [foobar]
static_configs: static_configs:
- targets: ["a"] - targets: ["a"]
`) `)
@ -778,28 +778,18 @@ scrape_configs:
params: params:
p: ["x&y", "="] p: ["x&y", "="]
xaa: xaa:
bearer_token: xyz
proxy_url: http://foo.bar proxy_url: http://foo.bar
proxy_basic_auth:
username: foo
password: bar
static_configs: static_configs:
- targets: ["foo.bar", "aaa"] - targets: ["foo.bar", "aaa"]
labels: labels:
x: y x: y
- job_name: qwer - job_name: qwer
basic_auth:
username: user
password: pass
tls_config: tls_config:
server_name: foobar server_name: foobar
insecure_skip_verify: true insecure_skip_verify: true
static_configs: static_configs:
- targets: [1.2.3.4] - targets: [1.2.3.4]
- job_name: asdf - job_name: asdf
authorization:
type: xyz
credentials: abc
static_configs: static_configs:
- targets: [foobar] - targets: [foobar]
`, []*ScrapeWork{ `, []*ScrapeWork{
@ -840,12 +830,8 @@ scrape_configs:
Value: "y", Value: "y",
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{},
Authorization: "Bearer xyz", ProxyAuthConfig: &promauth.Config{},
},
ProxyAuthConfig: &promauth.Config{
Authorization: "Basic Zm9vOmJhcg==",
},
ProxyURL: proxy.MustNewURL("http://foo.bar"), ProxyURL: proxy.MustNewURL("http://foo.bar"),
jobNameOriginal: "foo", jobNameOriginal: "foo",
}, },
@ -886,12 +872,8 @@ scrape_configs:
Value: "y", Value: "y",
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{},
Authorization: "Bearer xyz", ProxyAuthConfig: &promauth.Config{},
},
ProxyAuthConfig: &promauth.Config{
Authorization: "Basic Zm9vOmJhcg==",
},
ProxyURL: proxy.MustNewURL("http://foo.bar"), ProxyURL: proxy.MustNewURL("http://foo.bar"),
jobNameOriginal: "foo", jobNameOriginal: "foo",
}, },
@ -922,7 +904,6 @@ scrape_configs:
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{
Authorization: "Basic dXNlcjpwYXNz",
TLSServerName: "foobar", TLSServerName: "foobar",
TLSInsecureSkipVerify: true, TLSInsecureSkipVerify: true,
}, },
@ -955,9 +936,7 @@ scrape_configs:
Value: "asdf", Value: "asdf",
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{},
Authorization: "xyz abc",
},
ProxyAuthConfig: &promauth.Config{}, ProxyAuthConfig: &promauth.Config{},
jobNameOriginal: "asdf", jobNameOriginal: "asdf",
}, },
@ -1196,9 +1175,6 @@ scrape_configs:
f(` f(`
scrape_configs: scrape_configs:
- job_name: foo - job_name: foo
basic_auth:
username: xyz
password_file: testdata/password.txt
static_configs: static_configs:
- targets: ["foo.bar:1234"] - targets: ["foo.bar:1234"]
`, []*ScrapeWork{ `, []*ScrapeWork{
@ -1228,9 +1204,7 @@ scrape_configs:
Value: "foo", Value: "foo",
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{},
Authorization: "Basic eHl6OnNlY3JldC1wYXNz",
},
ProxyAuthConfig: &promauth.Config{}, ProxyAuthConfig: &promauth.Config{},
jobNameOriginal: "foo", jobNameOriginal: "foo",
}, },
@ -1238,7 +1212,6 @@ scrape_configs:
f(` f(`
scrape_configs: scrape_configs:
- job_name: foo - job_name: foo
bearer_token_file: testdata/password.txt
static_configs: static_configs:
- targets: ["foo.bar:1234"] - targets: ["foo.bar:1234"]
`, []*ScrapeWork{ `, []*ScrapeWork{
@ -1268,9 +1241,7 @@ scrape_configs:
Value: "foo", Value: "foo",
}, },
}, },
AuthConfig: &promauth.Config{ AuthConfig: &promauth.Config{},
Authorization: "Bearer secret-pass",
},
ProxyAuthConfig: &promauth.Config{}, ProxyAuthConfig: &promauth.Config{},
jobNameOriginal: "foo", jobNameOriginal: "foo",
}, },

View file

@ -155,12 +155,12 @@ func (aw *apiWatcher) getScrapeWorkObjects() []interface{} {
} }
// groupWatcher watches for Kubernetes objects on the given apiServer with the given namespaces, // groupWatcher watches for Kubernetes objects on the given apiServer with the given namespaces,
// selectors and authorization using the given client. // selectors using the given client.
type groupWatcher struct { type groupWatcher struct {
apiServer string apiServer string
namespaces []string namespaces []string
selectors []Selector selectors []Selector
authorization string getAuthHeader func() string
client *http.Client client *http.Client
mu sync.Mutex mu sync.Mutex
@ -184,7 +184,7 @@ func newGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string,
} }
return &groupWatcher{ return &groupWatcher{
apiServer: apiServer, apiServer: apiServer,
authorization: ac.Authorization, getAuthHeader: ac.GetAuthHeader,
namespaces: namespaces, namespaces: namespaces,
selectors: selectors, selectors: selectors,
client: client, client: client,
@ -296,8 +296,8 @@ func (gw *groupWatcher) doRequest(requestURL string) (*http.Response, error) {
if err != nil { if err != nil {
logger.Fatalf("cannot create a request for %q: %s", requestURL, err) logger.Fatalf("cannot create a request for %q: %s", requestURL, err)
} }
if gw.authorization != "" { if ah := gw.getAuthHeader(); ah != "" {
req.Header.Set("Authorization", gw.authorization) req.Header.Set("Authorization", ah)
} }
return gw.client.Do(req) return gw.client.Do(req)
} }

View file

@ -42,10 +42,10 @@ type Client struct {
apiServer string apiServer string
hostPort string hostPort string
authHeader string getAuthHeader func() string
proxyAuthHeader string getProxyAuthHeader func() string
sendFullURL bool sendFullURL bool
} }
// NewClient returns new Client for the given args. // NewClient returns new Client for the given args.
@ -70,7 +70,7 @@ func NewClient(apiServer string, ac *promauth.Config, proxyURL proxy.URL, proxyA
tlsCfg = ac.NewTLSConfig() tlsCfg = ac.NewTLSConfig()
} }
sendFullURL := !isTLS && proxyURL.IsHTTPOrHTTPS() sendFullURL := !isTLS && proxyURL.IsHTTPOrHTTPS()
proxyAuthHeader := "" getProxyAuthHeader := func() string { return "" }
if sendFullURL { if sendFullURL {
// Send full urls in requests to a proxy host for non-TLS apiServer // Send full urls in requests to a proxy host for non-TLS apiServer
// like net/http package from Go does. // like net/http package from Go does.
@ -81,7 +81,9 @@ func NewClient(apiServer string, ac *promauth.Config, proxyURL proxy.URL, proxyA
if isTLS { if isTLS {
tlsCfg = proxyAC.NewTLSConfig() tlsCfg = proxyAC.NewTLSConfig()
} }
proxyAuthHeader = proxyURL.GetAuthHeader(proxyAC) getProxyAuthHeader = func() string {
return proxyURL.GetAuthHeader(proxyAC)
}
proxyURL = proxy.URL{} proxyURL = proxy.URL{}
} }
if !strings.Contains(hostPort, ":") { if !strings.Contains(hostPort, ":") {
@ -120,18 +122,18 @@ func NewClient(apiServer string, ac *promauth.Config, proxyURL proxy.URL, proxyA
MaxConns: 64 * 1024, MaxConns: 64 * 1024,
Dial: dialFunc, Dial: dialFunc,
} }
authHeader := "" getAuthHeader := func() string { return "" }
if ac != nil { if ac != nil {
authHeader = ac.Authorization getAuthHeader = ac.GetAuthHeader
} }
return &Client{ return &Client{
hc: hc, hc: hc,
blockingClient: blockingClient, blockingClient: blockingClient,
apiServer: apiServer, apiServer: apiServer,
hostPort: hostPort, hostPort: hostPort,
authHeader: authHeader, getAuthHeader: getAuthHeader,
proxyAuthHeader: proxyAuthHeader, getProxyAuthHeader: getProxyAuthHeader,
sendFullURL: sendFullURL, sendFullURL: sendFullURL,
}, nil }, nil
} }
@ -188,11 +190,11 @@ func (c *Client) getAPIResponseWithParamsAndClient(client *fasthttp.HostClient,
} }
req.Header.SetHost(c.hostPort) req.Header.SetHost(c.hostPort)
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
if c.authHeader != "" { if ah := c.getAuthHeader(); ah != "" {
req.Header.Set("Authorization", c.authHeader) req.Header.Set("Authorization", ah)
} }
if c.proxyAuthHeader != "" { if ah := c.getProxyAuthHeader(); ah != "" {
req.Header.Set("Proxy-Authorization", c.proxyAuthHeader) req.Header.Set("Proxy-Authorization", ah)
} }
var resp fasthttp.Response var resp fasthttp.Response

View file

@ -64,7 +64,7 @@ func (u *URL) String() string {
func (u *URL) GetAuthHeader(ac *promauth.Config) string { func (u *URL) GetAuthHeader(ac *promauth.Config) string {
authHeader := "" authHeader := ""
if ac != nil { if ac != nil {
authHeader = ac.Authorization authHeader = ac.GetAuthHeader()
} }
if u == nil || u.url == nil { if u == nil || u.url == nil {
return authHeader return authHeader
@ -122,10 +122,6 @@ func (u *URL) NewDialFunc(ac *promauth.Config) (fasthttp.DialFunc, error) {
if pu.Scheme == "socks5" || pu.Scheme == "tls+socks5" { if pu.Scheme == "socks5" || pu.Scheme == "tls+socks5" {
return socks5DialFunc(proxyAddr, pu, tlsCfg) return socks5DialFunc(proxyAddr, pu, tlsCfg)
} }
authHeader := u.GetAuthHeader(ac)
if authHeader != "" {
authHeader = "Proxy-Authorization: " + authHeader + "\r\n"
}
dialFunc := func(addr string) (net.Conn, error) { dialFunc := func(addr string) (net.Conn, error) {
proxyConn, err := defaultDialFunc(proxyAddr) proxyConn, err := defaultDialFunc(proxyAddr)
if err != nil { if err != nil {
@ -134,6 +130,10 @@ func (u *URL) NewDialFunc(ac *promauth.Config) (fasthttp.DialFunc, error) {
if isTLS { if isTLS {
proxyConn = tls.Client(proxyConn, tlsCfg) proxyConn = tls.Client(proxyConn, tlsCfg)
} }
authHeader := u.GetAuthHeader(ac)
if authHeader != "" {
authHeader = "Proxy-Authorization: " + authHeader + "\r\n"
}
conn, err := sendConnectRequest(proxyConn, proxyAddr, addr, authHeader) conn, err := sendConnectRequest(proxyConn, proxyAddr, addr, authHeader)
if err != nil { if err != nil {
_ = proxyConn.Close() _ = proxyConn.Close()