lib/promauth: refactor NewConfig in order to improve maintainability

1. Split NewConfig into smaller functions
2. Introduce Options struct for simplifying construction of the Config with various options

This commit is based on https://github.com/VictoriaMetrics/VictoriaMetrics/pull/2684
This commit is contained in:
Aliaksandr Valialkin 2022-07-04 14:27:48 +03:00
parent 611434ce81
commit 855436efd2
No known key found for this signature in database
GPG key ID: A72BEC6CD3D0DED1
6 changed files with 369 additions and 233 deletions

View file

@ -212,7 +212,15 @@ func getAuthConfig(argIdx int) (*promauth.Config, error) {
InsecureSkipVerify: tlsInsecureSkipVerify.GetOptionalArg(argIdx), InsecureSkipVerify: tlsInsecureSkipVerify.GetOptionalArg(argIdx),
} }
authCfg, err := promauth.NewConfig(".", nil, basicAuthCfg, token, tokenFile, oauth2Cfg, tlsCfg, hdrs) opts := &promauth.Options{
BasicAuth: basicAuthCfg,
BearerToken: token,
BearerTokenFile: tokenFile,
OAuth2: oauth2Cfg,
TLSConfig: tlsCfg,
Headers: hdrs,
}
authCfg, err := opts.NewConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot populate OAuth2 config for remoteWrite idx: %d, err: %w", argIdx, err) return nil, fmt.Errorf("cannot populate OAuth2 config for remoteWrite idx: %d, err: %w", argIdx, err)
} }

View file

@ -81,16 +81,16 @@ type TLSConfig struct {
MinVersion string `yaml:"min_version,omitempty"` MinVersion string `yaml:"min_version,omitempty"`
} }
// String returns human-readable representation of tlsConfig // String returns human-readable representation of tc
func (tlsConfig *TLSConfig) String() string { func (tc *TLSConfig) String() string {
if tlsConfig == nil { if tc == nil {
return "" return ""
} }
caHash := xxhash.Sum64(tlsConfig.CA) caHash := xxhash.Sum64(tc.CA)
certHash := xxhash.Sum64(tlsConfig.Cert) certHash := xxhash.Sum64(tc.Cert)
keyHash := xxhash.Sum64(tlsConfig.Key) keyHash := xxhash.Sum64(tc.Key)
return fmt.Sprintf("hash(ca)=%d, ca_file=%q, hash(cert)=%d, cert_file=%q, hash(key)=%d, key_file=%q, server_name=%q, insecure_skip_verify=%v, min_version=%q", return fmt.Sprintf("hash(ca)=%d, ca_file=%q, hash(cert)=%d, cert_file=%q, hash(key)=%d, key_file=%q, server_name=%q, insecure_skip_verify=%v, min_version=%q",
caHash, tlsConfig.CAFile, certHash, tlsConfig.CertFile, keyHash, tlsConfig.KeyFile, tlsConfig.ServerName, tlsConfig.InsecureSkipVerify, tlsConfig.MinVersion) caHash, tc.CAFile, certHash, tc.CertFile, keyHash, tc.KeyFile, tc.ServerName, tc.InsecureSkipVerify, tc.MinVersion)
} }
// Authorization represents generic authorization config. // Authorization represents generic authorization config.
@ -198,7 +198,11 @@ func newOAuth2ConfigInternal(baseDir string, o *OAuth2Config) (*oauth2ConfigInte
} }
oi.cfg.ClientSecret = secret oi.cfg.ClientSecret = secret
} }
ac, err := o.NewConfig(baseDir) opts := &Options{
BaseDir: baseDir,
TLSConfig: o.TLSConfig,
}
ac, err := opts.NewConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot initialize TLS config for OAuth2: %w", err) return nil, fmt.Errorf("cannot initialize TLS config for OAuth2: %w", err)
} }
@ -400,220 +404,325 @@ func (ac *Config) NewTLSConfig() *tls.Config {
// NewConfig creates auth config for the given hcc. // NewConfig creates auth config for the given hcc.
func (hcc *HTTPClientConfig) NewConfig(baseDir string) (*Config, error) { func (hcc *HTTPClientConfig) NewConfig(baseDir string) (*Config, error) {
return NewConfig(baseDir, hcc.Authorization, hcc.BasicAuth, hcc.BearerToken.String(), hcc.BearerTokenFile, hcc.OAuth2, hcc.TLSConfig, hcc.Headers) opts := &Options{
BaseDir: baseDir,
Authorization: hcc.Authorization,
BasicAuth: hcc.BasicAuth,
BearerToken: hcc.BearerToken.String(),
BearerTokenFile: hcc.BearerTokenFile,
OAuth2: hcc.OAuth2,
TLSConfig: hcc.TLSConfig,
Headers: hcc.Headers,
}
return opts.NewConfig()
} }
// NewConfig creates auth config for the given pcc. // NewConfig creates auth config for the given pcc.
func (pcc *ProxyClientConfig) NewConfig(baseDir string) (*Config, error) { func (pcc *ProxyClientConfig) NewConfig(baseDir string) (*Config, error) {
return NewConfig(baseDir, pcc.Authorization, pcc.BasicAuth, pcc.BearerToken.String(), pcc.BearerTokenFile, pcc.OAuth2, pcc.TLSConfig, pcc.Headers) opts := &Options{
} BaseDir: baseDir,
Authorization: pcc.Authorization,
// NewConfig creates auth config for the given o. BasicAuth: pcc.BasicAuth,
func (o *OAuth2Config) NewConfig(baseDir string) (*Config, error) { BearerToken: pcc.BearerToken.String(),
return NewConfig(baseDir, nil, nil, "", "", nil, o.TLSConfig, nil) BearerTokenFile: pcc.BearerTokenFile,
OAuth2: pcc.OAuth2,
TLSConfig: pcc.TLSConfig,
Headers: pcc.Headers,
}
return opts.NewConfig()
} }
// NewConfig creates auth config for the given ba. // NewConfig creates auth config for the given ba.
func (ba *BasicAuthConfig) NewConfig(baseDir string) (*Config, error) { func (ba *BasicAuthConfig) NewConfig(baseDir string) (*Config, error) {
return NewConfig(baseDir, nil, ba, "", "", nil, nil, nil) opts := &Options{
BaseDir: baseDir,
BasicAuth: ba,
}
return opts.NewConfig()
} }
// NewConfig creates auth config from the given args. // Options contain options, which must be passed to NewConfig.
// type Options struct {
// headers must be in the form 'HeaderName: header value' // BaseDir is an optional path to a base directory for resolving
func NewConfig(baseDir string, az *Authorization, basicAuth *BasicAuthConfig, bearerToken, bearerTokenFile string, o *OAuth2Config, tlsConfig *TLSConfig, headers []string) (*Config, error) { // relative filepaths in various config options.
var getAuthHeader func() string //
authDigest := "" // It is set to the current directory by default.
if az != nil { BaseDir string
azType := "Bearer"
if az.Type != "" { // Authorization contains optional Authorization.
azType = az.Type Authorization *Authorization
}
if az.CredentialsFile != "" { // BasicAuth contains optional BasicAuthConfig.
if az.Credentials != nil { BasicAuth *BasicAuthConfig
return nil, fmt.Errorf("both `credentials`=%q and `credentials_file`=%q are set", az.Credentials, az.CredentialsFile)
} // BearerToken contains optional bearer token.
filePath := fs.GetFilepath(baseDir, az.CredentialsFile) BearerToken string
getAuthHeader = func() string {
token, err := readPasswordFromFile(filePath) // BearerTokenFile contains optional path to a file with bearer token.
if err != nil { BearerTokenFile string
logger.Errorf("cannot read credentials from `credentials_file`=%q: %s", az.CredentialsFile, err)
return "" // OAuth2 contains optional OAuth2Config.
} OAuth2 *OAuth2Config
return azType + " " + token
} // TLSconfig contains optional TLSConfig.
authDigest = fmt.Sprintf("custom(type=%q, credsFile=%q)", az.Type, filePath) TLSConfig *TLSConfig
} else {
getAuthHeader = func() string { // Headers contains optional http request headers in the form 'Foo: bar'.
return azType + " " + az.Credentials.String() Headers []string
} }
authDigest = fmt.Sprintf("custom(type=%q, creds=%q)", az.Type, az.Credentials)
} // NewConfig creates auth config from the given opts.
func (opts *Options) NewConfig() (*Config, error) {
baseDir := opts.BaseDir
if baseDir == "" {
baseDir = "."
} }
if basicAuth != nil { var actx authContext
if getAuthHeader != nil { if opts.Authorization != nil {
return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`") if err := actx.initFromAuthorization(baseDir, opts.Authorization); err != nil {
}
if basicAuth.Username == "" {
return nil, fmt.Errorf("missing `username` in `basic_auth` section")
}
if basicAuth.PasswordFile != "" {
if basicAuth.Password != nil {
return nil, fmt.Errorf("both `password`=%q and `password_file`=%q are set in `basic_auth` section", basicAuth.Password, basicAuth.PasswordFile)
}
filePath := fs.GetFilepath(baseDir, basicAuth.PasswordFile)
getAuthHeader = func() string {
password, err := readPasswordFromFile(filePath)
if err != nil {
logger.Errorf("cannot read password from `password_file`=%q set in `basic_auth` section: %s", basicAuth.PasswordFile, err)
return ""
}
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := basicAuth.Username + ":" + password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
authDigest = fmt.Sprintf("basic(username=%q, passwordFile=%q)", basicAuth.Username, filePath)
} else {
getAuthHeader = func() string {
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := basicAuth.Username + ":" + basicAuth.Password.String()
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
authDigest = fmt.Sprintf("basic(username=%q, password=%q)", basicAuth.Username, basicAuth.Password)
}
}
if bearerTokenFile != "" {
if getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token_file`")
}
if bearerToken != "" {
return nil, fmt.Errorf("both `bearer_token`=%q and `bearer_token_file`=%q are set", bearerToken, bearerTokenFile)
}
filePath := fs.GetFilepath(baseDir, bearerTokenFile)
getAuthHeader = func() string {
token, err := readPasswordFromFile(filePath)
if err != nil {
logger.Errorf("cannot read bearer token from `bearer_token_file`=%q: %s", bearerTokenFile, err)
return ""
}
return "Bearer " + token
}
authDigest = fmt.Sprintf("bearer(tokenFile=%q)", filePath)
}
if bearerToken != "" {
if getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
}
getAuthHeader = func() string {
return "Bearer " + bearerToken
}
authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
}
if o != nil {
if getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth, `bearer_token` and `ouath2`")
}
oi, err := newOAuth2ConfigInternal(baseDir, o)
if err != nil {
return nil, err return nil, err
} }
getAuthHeader = func() string {
ts, err := oi.getTokenSource()
if err != nil {
logger.Errorf("cannot get OAuth2 tokenSource: %s", err)
return ""
}
t, err := ts.Token()
if err != nil {
logger.Errorf("cannot get OAuth2 token: %s", err)
return ""
}
return t.Type() + " " + t.AccessToken
}
authDigest = fmt.Sprintf("oauth2(%s)", o.String())
} }
var tlsRootCA *x509.CertPool if opts.BasicAuth != nil {
var getTLSCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error) if actx.getAuthHeader != nil {
tlsCertDigest := "" return nil, fmt.Errorf("cannot use both `authorization` and `basic_auth`")
tlsServerName := ""
tlsInsecureSkipVerify := false
tlsMinVersion := uint16(0)
if tlsConfig != nil {
tlsServerName = tlsConfig.ServerName
tlsInsecureSkipVerify = tlsConfig.InsecureSkipVerify
if len(tlsConfig.Key) != 0 || len(tlsConfig.Cert) != 0 {
cert, err := tls.X509KeyPair(tlsConfig.Cert, tlsConfig.Key)
if err != nil {
return nil, fmt.Errorf("cannot load TLS certificate from the provided `cert` and `key` values: %w", err)
}
getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &cert, nil
}
h := xxhash.Sum64(tlsConfig.Key) ^ xxhash.Sum64(tlsConfig.Cert)
tlsCertDigest = fmt.Sprintf("digest(key+cert)=%d", h)
} else if tlsConfig.CertFile != "" || tlsConfig.KeyFile != "" {
getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Re-read TLS certificate from disk. This is needed for https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1420
certPath := fs.GetFilepath(baseDir, tlsConfig.CertFile)
keyPath := fs.GetFilepath(baseDir, tlsConfig.KeyFile)
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("cannot load TLS certificate from `cert_file`=%q, `key_file`=%q: %w", tlsConfig.CertFile, tlsConfig.KeyFile, err)
}
return &cert, nil
}
// Check whether the configured TLS cert can be loaded.
if _, err := getTLSCert(nil); err != nil {
return nil, err
}
tlsCertDigest = fmt.Sprintf("certFile=%q, keyFile=%q", tlsConfig.CertFile, tlsConfig.KeyFile)
} }
if len(tlsConfig.CA) != 0 { if err := actx.initFromBasicAuthConfig(baseDir, opts.BasicAuth); err != nil {
tlsRootCA = x509.NewCertPool() return nil, err
if !tlsRootCA.AppendCertsFromPEM(tlsConfig.CA) {
return nil, fmt.Errorf("cannot parse data from `ca` value")
}
} else if tlsConfig.CAFile != "" {
path := fs.GetFilepath(baseDir, tlsConfig.CAFile)
data, err := fs.ReadFileOrHTTP(path)
if err != nil {
return nil, fmt.Errorf("cannot read `ca_file` %q: %w", tlsConfig.CAFile, err)
}
tlsRootCA = x509.NewCertPool()
if !tlsRootCA.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("cannot parse data from `ca_file` %q", tlsConfig.CAFile)
}
}
if tlsConfig.MinVersion != "" {
v, err := parseTLSVersion(tlsConfig.MinVersion)
if err != nil {
return nil, fmt.Errorf("cannot parse `min_version`: %w", err)
}
tlsMinVersion = v
} }
} }
parsedHeaders, err := parseHeaders(headers) if opts.BearerTokenFile != "" {
if actx.getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token_file`")
}
if opts.BearerToken != "" {
return nil, fmt.Errorf("both `bearer_token`=%q and `bearer_token_file`=%q are set", opts.BearerToken, opts.BearerTokenFile)
}
if err := actx.initFromBearerTokenFile(baseDir, opts.BearerTokenFile); err != nil {
return nil, err
}
}
if opts.BearerToken != "" {
if actx.getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth` and `bearer_token`")
}
if err := actx.initFromBearerToken(opts.BearerToken); err != nil {
return nil, err
}
}
if opts.OAuth2 != nil {
if actx.getAuthHeader != nil {
return nil, fmt.Errorf("cannot simultaneously use `authorization`, `basic_auth, `bearer_token` and `ouath2`")
}
if err := actx.initFromOAuth2Config(baseDir, opts.OAuth2); err != nil {
return nil, err
}
}
var tctx tlsContext
if opts.TLSConfig != nil {
if err := tctx.initFromTLSConfig(baseDir, opts.TLSConfig); err != nil {
return nil, err
}
}
headers, err := parseHeaders(opts.Headers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ac := &Config{ ac := &Config{
TLSRootCA: tlsRootCA, TLSRootCA: tctx.rootCA,
TLSServerName: tlsServerName, TLSServerName: tctx.serverName,
TLSInsecureSkipVerify: tlsInsecureSkipVerify, TLSInsecureSkipVerify: tctx.insecureSkipVerify,
TLSMinVersion: tlsMinVersion, TLSMinVersion: tctx.minVersion,
getTLSCert: getTLSCert, getTLSCert: tctx.getTLSCert,
tlsCertDigest: tlsCertDigest, tlsCertDigest: tctx.tlsCertDigest,
getAuthHeader: getAuthHeader, getAuthHeader: actx.getAuthHeader,
headers: parsedHeaders, headers: headers,
authDigest: authDigest, authDigest: actx.authDigest,
} }
return ac, nil return ac, nil
} }
type authContext struct {
// getAuthHeader must return <value> for 'Authorization: <value>' http request header
getAuthHeader func() string
// authDigest must contain the digest for the used authorization
// The digest must be changed whenever the original config changes.
authDigest string
}
func (actx *authContext) initFromAuthorization(baseDir string, az *Authorization) error {
azType := "Bearer"
if az.Type != "" {
azType = az.Type
}
if az.CredentialsFile == "" {
actx.getAuthHeader = func() string {
return azType + " " + az.Credentials.String()
}
actx.authDigest = fmt.Sprintf("custom(type=%q, creds=%q)", az.Type, az.Credentials)
return nil
}
if az.Credentials != nil {
return fmt.Errorf("both `credentials`=%q and `credentials_file`=%q are set", az.Credentials, az.CredentialsFile)
}
filePath := fs.GetFilepath(baseDir, az.CredentialsFile)
actx.getAuthHeader = func() string {
token, err := readPasswordFromFile(filePath)
if err != nil {
logger.Errorf("cannot read credentials from `credentials_file`=%q: %s", az.CredentialsFile, err)
return ""
}
return azType + " " + token
}
actx.authDigest = fmt.Sprintf("custom(type=%q, credsFile=%q)", az.Type, filePath)
return nil
}
func (actx *authContext) initFromBasicAuthConfig(baseDir string, ba *BasicAuthConfig) error {
if ba.Username == "" {
return fmt.Errorf("missing `username` in `basic_auth` section")
}
if ba.PasswordFile == "" {
actx.getAuthHeader = func() string {
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := ba.Username + ":" + ba.Password.String()
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
actx.authDigest = fmt.Sprintf("basic(username=%q, password=%q)", ba.Username, ba.Password)
return nil
}
if ba.Password != nil {
return fmt.Errorf("both `password`=%q and `password_file`=%q are set in `basic_auth` section", ba.Password, ba.PasswordFile)
}
filePath := fs.GetFilepath(baseDir, ba.PasswordFile)
actx.getAuthHeader = func() string {
password, err := readPasswordFromFile(filePath)
if err != nil {
logger.Errorf("cannot read password from `password_file`=%q set in `basic_auth` section: %s", ba.PasswordFile, err)
return ""
}
// See https://en.wikipedia.org/wiki/Basic_access_authentication
token := ba.Username + ":" + password
token64 := base64.StdEncoding.EncodeToString([]byte(token))
return "Basic " + token64
}
actx.authDigest = fmt.Sprintf("basic(username=%q, passwordFile=%q)", ba.Username, filePath)
return nil
}
func (actx *authContext) initFromBearerTokenFile(baseDir string, bearerTokenFile string) error {
filePath := fs.GetFilepath(baseDir, bearerTokenFile)
actx.getAuthHeader = func() string {
token, err := readPasswordFromFile(filePath)
if err != nil {
logger.Errorf("cannot read bearer token from `bearer_token_file`=%q: %s", bearerTokenFile, err)
return ""
}
return "Bearer " + token
}
actx.authDigest = fmt.Sprintf("bearer(tokenFile=%q)", filePath)
return nil
}
func (actx *authContext) initFromBearerToken(bearerToken string) error {
actx.getAuthHeader = func() string {
return "Bearer " + bearerToken
}
actx.authDigest = fmt.Sprintf("bearer(token=%q)", bearerToken)
return nil
}
func (actx *authContext) initFromOAuth2Config(baseDir string, o *OAuth2Config) error {
oi, err := newOAuth2ConfigInternal(baseDir, o)
if err != nil {
return err
}
actx.getAuthHeader = func() string {
ts, err := oi.getTokenSource()
if err != nil {
logger.Errorf("cannot get OAuth2 tokenSource: %s", err)
return ""
}
t, err := ts.Token()
if err != nil {
logger.Errorf("cannot get OAuth2 token: %s", err)
return ""
}
return t.Type() + " " + t.AccessToken
}
actx.authDigest = fmt.Sprintf("oauth2(%s)", o.String())
return nil
}
type tlsContext struct {
getTLSCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
tlsCertDigest string
rootCA *x509.CertPool
serverName string
insecureSkipVerify bool
minVersion uint16
}
func (tctx *tlsContext) initFromTLSConfig(baseDir string, tc *TLSConfig) error {
tctx.serverName = tc.ServerName
tctx.insecureSkipVerify = tc.InsecureSkipVerify
if len(tc.Key) != 0 || len(tc.Cert) != 0 {
cert, err := tls.X509KeyPair(tc.Cert, tc.Key)
if err != nil {
return fmt.Errorf("cannot load TLS certificate from the provided `cert` and `key` values: %w", err)
}
tctx.getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &cert, nil
}
h := xxhash.Sum64(tc.Key) ^ xxhash.Sum64(tc.Cert)
tctx.tlsCertDigest = fmt.Sprintf("digest(key+cert)=%d", h)
} else if tc.CertFile != "" || tc.KeyFile != "" {
tctx.getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Re-read TLS certificate from disk. This is needed for https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1420
certPath := fs.GetFilepath(baseDir, tc.CertFile)
keyPath := fs.GetFilepath(baseDir, tc.KeyFile)
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("cannot load TLS certificate from `cert_file`=%q, `key_file`=%q: %w", tc.CertFile, tc.KeyFile, err)
}
return &cert, nil
}
// Check whether the configured TLS cert can be loaded.
if _, err := tctx.getTLSCert(nil); err != nil {
return err
}
tctx.tlsCertDigest = fmt.Sprintf("certFile=%q, keyFile=%q", tc.CertFile, tc.KeyFile)
}
if len(tc.CA) != 0 {
tctx.rootCA = x509.NewCertPool()
if !tctx.rootCA.AppendCertsFromPEM(tc.CA) {
return fmt.Errorf("cannot parse data from `ca` value")
}
} else if tc.CAFile != "" {
path := fs.GetFilepath(baseDir, tc.CAFile)
data, err := fs.ReadFileOrHTTP(path)
if err != nil {
return fmt.Errorf("cannot read `ca_file` %q: %w", tc.CAFile, err)
}
tctx.rootCA = x509.NewCertPool()
if !tctx.rootCA.AppendCertsFromPEM(data) {
return fmt.Errorf("cannot parse data from `ca_file` %q", tc.CAFile)
}
}
if tc.MinVersion != "" {
v, err := parseTLSVersion(tc.MinVersion)
if err != nil {
return fmt.Errorf("cannot parse `min_version`: %w", err)
}
tctx.minVersion = v
}
return nil
}
func parseTLSVersion(s string) (uint16, error) { func parseTLSVersion(s string) (uint16, error) {
switch strings.ToUpper(s) { switch strings.ToUpper(s) {
case "TLS13": case "TLS13":

View file

@ -9,25 +9,16 @@ import (
) )
func TestNewConfig(t *testing.T) { func TestNewConfig(t *testing.T) {
type args struct {
baseDir string
az *Authorization
basicAuth *BasicAuthConfig
bearerToken string
bearerTokenFile string
oauth *OAuth2Config
tlsConfig *TLSConfig
}
tests := []struct { tests := []struct {
name string name string
args args opts Options
wantErr bool wantErr bool
expectHeader string expectHeader string
}{ }{
{ {
name: "OAuth2 config", name: "OAuth2 config",
args: args{ opts: Options{
oauth: &OAuth2Config{ OAuth2: &OAuth2Config{
ClientID: "some-id", ClientID: "some-id",
ClientSecret: NewSecret("some-secret"), ClientSecret: NewSecret("some-secret"),
TokenURL: "http://localhost:8511", TokenURL: "http://localhost:8511",
@ -37,8 +28,8 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "OAuth2 config with file", name: "OAuth2 config with file",
args: args{ opts: Options{
oauth: &OAuth2Config{ OAuth2: &OAuth2Config{
ClientID: "some-id", ClientID: "some-id",
ClientSecretFile: "testdata/test_secretfile.txt", ClientSecretFile: "testdata/test_secretfile.txt",
TokenURL: "http://localhost:8511", TokenURL: "http://localhost:8511",
@ -48,8 +39,8 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "OAuth2 want err", name: "OAuth2 want err",
args: args{ opts: Options{
oauth: &OAuth2Config{ OAuth2: &OAuth2Config{
ClientID: "some-id", ClientID: "some-id",
ClientSecret: NewSecret("some-secret"), ClientSecret: NewSecret("some-secret"),
ClientSecretFile: "testdata/test_secretfile.txt", ClientSecretFile: "testdata/test_secretfile.txt",
@ -60,8 +51,8 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "basic Auth config", name: "basic Auth config",
args: args{ opts: Options{
basicAuth: &BasicAuthConfig{ BasicAuth: &BasicAuthConfig{
Username: "user", Username: "user",
Password: NewSecret("password"), Password: NewSecret("password"),
}, },
@ -70,8 +61,8 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "basic Auth config with file", name: "basic Auth config with file",
args: args{ opts: Options{
basicAuth: &BasicAuthConfig{ BasicAuth: &BasicAuthConfig{
Username: "user", Username: "user",
PasswordFile: "testdata/test_secretfile.txt", PasswordFile: "testdata/test_secretfile.txt",
}, },
@ -80,8 +71,8 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "want Authorization", name: "want Authorization",
args: args{ opts: Options{
az: &Authorization{ Authorization: &Authorization{
Type: "Bearer", Type: "Bearer",
Credentials: NewSecret("Value"), Credentials: NewSecret("Value"),
}, },
@ -90,16 +81,16 @@ func TestNewConfig(t *testing.T) {
}, },
{ {
name: "token file", name: "token file",
args: args{ opts: Options{
bearerTokenFile: "testdata/test_secretfile.txt", BearerTokenFile: "testdata/test_secretfile.txt",
}, },
expectHeader: "Bearer secret-content", expectHeader: "Bearer secret-content",
}, },
{ {
name: "token with tls", name: "token with tls",
args: args{ opts: Options{
bearerToken: "some-token", BearerToken: "some-token",
tlsConfig: &TLSConfig{ TLSConfig: &TLSConfig{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}, },
}, },
@ -108,7 +99,7 @@ func TestNewConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if tt.args.oauth != nil { if tt.opts.OAuth2 != nil {
r := http.NewServeMux() r := http.NewServeMux()
r.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { r.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -116,9 +107,9 @@ func TestNewConfig(t *testing.T) {
}) })
mock := httptest.NewServer(r) mock := httptest.NewServer(r)
tt.args.oauth.TokenURL = mock.URL tt.opts.OAuth2.TokenURL = mock.URL
} }
got, err := NewConfig(tt.args.baseDir, tt.args.az, tt.args.basicAuth, tt.args.bearerToken, tt.args.bearerTokenFile, tt.args.oauth, tt.args.tlsConfig, nil) got, err := tt.opts.NewConfig()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewConfig() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewConfig() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -140,7 +131,6 @@ func TestNewConfig(t *testing.T) {
t.Fatalf("unexpected auth header from fasthttp request; got %q; want %q", ahb, tt.expectHeader) t.Fatalf("unexpected auth header from fasthttp request; got %q; want %q", ahb, tt.expectHeader)
} }
} }
}) })
} }
} }
@ -186,7 +176,10 @@ func TestConfigHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("cannot parse headers: %s", err) t.Fatalf("cannot parse headers: %s", err)
} }
c, err := NewConfig("", nil, nil, "", "", nil, nil, headers) opts := Options{
Headers: headers,
}
c, err := opts.NewConfig()
if err != nil { if err != nil {
t.Fatalf("cannot create config: %s", err) t.Fatalf("cannot create config: %s", err)
} }

View file

@ -1656,11 +1656,17 @@ scrape_configs:
}, },
}) })
ac, err := promauth.NewConfig(".", nil, nil, "", "", nil, nil, []string{"My-Auth: foo-Bar"}) opts := &promauth.Options{
Headers: []string{"My-Auth: foo-Bar"},
}
ac, err := opts.NewConfig()
if err != nil { if err != nil {
t.Fatalf("unexpected error when creating promauth.Config: %s", err) t.Fatalf("unexpected error when creating promauth.Config: %s", err)
} }
proxyAC, err := promauth.NewConfig(".", nil, nil, "", "", nil, nil, []string{"Foo:bar"}) opts = &promauth.Options{
Headers: []string{"Foo:bar"},
}
proxyAC, err := opts.NewConfig()
if err != nil { if err != nil {
t.Fatalf("unexpected error when creating promauth.Config for proxy: %s", err) t.Fatalf("unexpected error when creating promauth.Config for proxy: %s", err)
} }

View file

@ -31,7 +31,16 @@ func newAPIConfig(sdc *SDConfig, baseDir string, swcFunc ScrapeWorkConstructorFu
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot build kube config from the specified `kubeconfig_file` config option: %w", err) return nil, fmt.Errorf("cannot build kube config from the specified `kubeconfig_file` config option: %w", err)
} }
acNew, err := promauth.NewConfig(".", nil, kc.basicAuth, kc.token, kc.tokenFile, cc.OAuth2, kc.tlsConfig, cc.Headers) opts := &promauth.Options{
BaseDir: baseDir,
BasicAuth: kc.basicAuth,
BearerToken: kc.token,
BearerTokenFile: kc.tokenFile,
OAuth2: cc.OAuth2,
TLSConfig: kc.tlsConfig,
Headers: cc.Headers,
}
acNew, err := opts.NewConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot initialize auth config from `kubeconfig_file: %q`: %w", sdc.KubeConfigFile, err) return nil, fmt.Errorf("cannot initialize auth config from `kubeconfig_file: %q`: %w", sdc.KubeConfigFile, err)
} }
@ -58,7 +67,14 @@ func newAPIConfig(sdc *SDConfig, baseDir string, swcFunc ScrapeWorkConstructorFu
tlsConfig := promauth.TLSConfig{ tlsConfig := promauth.TLSConfig{
CAFile: "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt", CAFile: "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt",
} }
acNew, err := promauth.NewConfig(".", nil, nil, "", "/var/run/secrets/kubernetes.io/serviceaccount/token", cc.OAuth2, &tlsConfig, cc.Headers) opts := &promauth.Options{
BaseDir: baseDir,
BearerTokenFile: "/var/run/secrets/kubernetes.io/serviceaccount/token",
OAuth2: cc.OAuth2,
TLSConfig: &tlsConfig,
Headers: cc.Headers,
}
acNew, err := opts.NewConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot initialize service account auth: %w; probably, `kubernetes_sd_config->api_server` is missing in Prometheus configs?", err) return nil, fmt.Errorf("cannot initialize service account auth: %w; probably, `kubernetes_sd_config->api_server` is missing in Prometheus configs?", err)
} }

View file

@ -81,7 +81,11 @@ func newAPIConfig(sdc *SDConfig, baseDir string) (*apiConfig, error) {
port: sdc.Port, port: sdc.Port,
} }
if sdc.TLSConfig != nil { if sdc.TLSConfig != nil {
ac, err := promauth.NewConfig(baseDir, nil, nil, "", "", nil, sdc.TLSConfig, nil) opts := &promauth.Options{
BaseDir: baseDir,
TLSConfig: sdc.TLSConfig,
}
ac, err := opts.NewConfig()
if err != nil { if err != nil {
return nil, err return nil, err
} }