diff --git a/app/vmagent/remotewrite/client.go b/app/vmagent/remotewrite/client.go index 14ef10f58..84e4d46b8 100644 --- a/app/vmagent/remotewrite/client.go +++ b/app/vmagent/remotewrite/client.go @@ -113,17 +113,12 @@ func newHTTPClient(argIdx int, remoteWriteURL, sanitizedURL string, fq *persiste if err != nil { logger.Fatalf("cannot initialize auth config for -remoteWrite.url=%q: %s", remoteWriteURL, err) } - tlsCfg, err := authCfg.NewTLSConfig() - if err != nil { - logger.Fatalf("cannot initialize tls config for -remoteWrite.url=%q: %s", remoteWriteURL, err) - } awsCfg, err := getAWSAPIConfig(argIdx) if err != nil { logger.Fatalf("cannot initialize AWS Config for -remoteWrite.url=%q: %s", remoteWriteURL, err) } tr := &http.Transport{ DialContext: statDial, - TLSClientConfig: tlsCfg, TLSHandshakeTimeout: tlsHandshakeTimeout.GetOptionalArg(argIdx), MaxConnsPerHost: 2 * concurrency, MaxIdleConnsPerHost: 2 * concurrency, @@ -142,7 +137,7 @@ func newHTTPClient(argIdx int, remoteWriteURL, sanitizedURL string, fq *persiste tr.Proxy = http.ProxyURL(pu) } hc := &http.Client{ - Transport: tr, + Transport: authCfg.NewRoundTripper(tr), Timeout: sendTimeout.GetOptionalArg(argIdx), } c := &client{ diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index a30004c3f..f56065883 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -56,6 +56,7 @@ See also [LTS releases](https://docs.victoriametrics.com/lts-releases/). * FEATURE: [vmagent](https://docs.victoriametrics.com/vmagent.html): ability to limit the ingestion rate via `-maxIngestionRate` command-line flag. See [this pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/5900). * FEATURE: [vmagent](https://docs.victoriametrics.com/vmagent.html): use the provided `-remoteWrite.tlsServerName` as `Host` header in requests to `-remoteWrite.url`. This allows sending data to https remote storage by IP address instead of hostname. Thanks to @minor-fixes for initial idea and [the pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/5802). * FEATURE: [vmagent](https://docs.victoriametrics.com/vmagent.html): add `-remoteWrite.shardByURL.ignoreLabels` command-line flag, which can be used for specifying the ignored list of labels when [sharding by `-remoteWrite.url` is enabled](https://docs.victoriametrics.com/vmagent/#sharding-among-remote-storages). Thanks to @edma2 for the idea and [the pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/5938). +* FEATURE: [vmagent](https://docs.victoriametrics.com/vmagent.html): automatically reload updated root CA certificates from files without the need to restart `vmagent`. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5526). * FEATURE: optimize [`/api/v1/labels`](https://docs.victoriametrics.com/url-examples/#apiv1labels) and [`/api/v1/label/.../values`](https://docs.victoriametrics.com/url-examples/#apiv1labelvalues) when `match[]` filters contains metric name. For example, `/api/v1/label/instance/values?match[]=up` now works much faster than before. See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5055). * FEATURE: [vmctl](https://docs.victoriametrics.com/vmctl.html): support client-side TLS configuration for [native protocol](https://docs.victoriametrics.com/vmctl/#migrating-data-from-victoriametrics). See [this feature request](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5748). Thanks to @khushijain21 for the [pull request](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/5824). * FEATURE: [vmctl](https://docs.victoriametrics.com/vmctl.html): support client-side TLS configuration for VictoriaMetrics destination specified via `--vm-*` cmd-line flags used in [InfluxDB](https://docs.victoriametrics.com/vmctl/#migrating-data-from-influxdb-1x), [Remote Read protocol](https://docs.victoriametrics.com/vmctl/#migrating-data-by-remote-read-protocol), [OpenTSDB](https://docs.victoriametrics.com/vmctl/#migrating-data-from-opentsdb), [Prometheus](https://docs.victoriametrics.com/vmctl/#migrating-data-from-prometheus) and [Promscale](https://docs.victoriametrics.com/vmctl/#migrating-data-from-promscale) migration modes. @@ -65,7 +66,6 @@ See also [LTS releases](https://docs.victoriametrics.com/lts-releases/). * FEATURE: [OpenTelemetry](https://docs.victoriametrics.com/#sending-data-via-opentelemetry): add `-opentelemetry.usePrometheusNaming` command-line flag, which can be used for enabling automatic conversion of the ingested metric names and labels into Prometheus-compatible format. See [these docs](https://docs.victoriametrics.com/#sending-data-via-opentelemetry) and [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/6037). * BUGFIX: prevent from automatic deletion of newly registered time series when it is queried immediately after the addition. The probability of this bug has been increased significantly after [v1.99.0](https://github.com/VictoriaMetrics/VictoriaMetrics/releases/tag/v1.99.0) because of optimizations related to registering new time series. See [this](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5948) and [this](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5959) issue. -* BUGFIX: properly reload root CA certificates from files for [service discovery](https://docs.victoriametrics.com/sd_configs/) and [scraping requests in stream parsing mode](https://docs.victoriametrics.com/vmagent/?highlight=stream&highlight=parse#stream-parsing-mode). See [this issue](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5526). * BUGFIX: [vmagent](https://docs.victoriametrics.com/vmagent.html): properly set `Host` header in requests to scrape targets if it is specified via [`headers` option](https://docs.victoriametrics.com/sd_configs/#http-api-client-options). Thanks to @fholzer for [the bugreport](https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5969) and [the fix](https://github.com/VictoriaMetrics/VictoriaMetrics/pull/5970). * BUGFIX: [vmagent](https://docs.victoriametrics.com/vmagent.html): properly set `Host` header in requests to scrape targets when [`server_name` option at `tls_config`](https://prometheus.io/docs/prometheus/latest/configuration/configuration/#tls_config) is set. Previously the `Host` header was set incorrectly to the target hostname in this case. * BUGFIX: do not drop `match[]` filter at [`/api/v1/series`](https://docs.victoriametrics.com/url-examples/#apiv1series) if `-search.ignoreExtraFiltersAtLabelsAPI` command-line flag is set, since missing `match[]` filter breaks `/api/v1/series` requests. diff --git a/lib/promauth/config.go b/lib/promauth/config.go index bdd70ab68..8cba947f3 100644 --- a/lib/promauth/config.go +++ b/lib/promauth/config.go @@ -1,7 +1,6 @@ package promauth import ( - "bytes" "context" "crypto/tls" "crypto/x509" @@ -18,7 +17,6 @@ import ( "github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime" "github.com/VictoriaMetrics/VictoriaMetrics/lib/fs/fscore" - "github.com/VictoriaMetrics/VictoriaMetrics/lib/logger" "github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil" ) @@ -235,15 +233,10 @@ func urlValuesFromMap(m map[string]string) url.Values { } func (oi *oauth2ConfigInternal) initTokenSource() error { - tlsCfg, err := oi.ac.NewTLSConfig() - if err != nil { - return fmt.Errorf("cannot initialize TLS config for OAuth2: %w", err) - } c := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsCfg, - Proxy: oi.proxyURLFunc, - }, + Transport: oi.ac.NewRoundTripper(&http.Transport{ + Proxy: oi.proxyURLFunc, + }), } oi.ctx = context.WithValue(context.Background(), oauth2.HTTPClient, c) oi.tokenSource = oi.cfg.TokenSource(oi.ctx) @@ -281,8 +274,10 @@ type Config struct { tlsInsecureSkipVerify bool tlsMinVersion uint16 - getTLSRootCACached getTLSRootCAFunc - tlsRootCADigest string + getTLSConfigCached getTLSConfigFunc + + getTLSRootCA getTLSRootCAFunc + tlsRootCADigest string getTLSCertCached getTLSCertFunc tlsCertDigest string @@ -292,8 +287,6 @@ type Config struct { headers []keyValue headersDigest string - - tctx *tlsContext } type keyValue struct { @@ -376,8 +369,6 @@ func (ac *Config) String() string { ac.authHeaderDigest, ac.headersDigest, ac.tlsRootCADigest, ac.tlsCertDigest, ac.tlsServerName, ac.tlsInsecureSkipVerify, ac.tlsMinVersion) } -const tlsCertsCacheSeconds = 1 - // getAuthHeaderFunc must return for 'Authorization: ' http request header type getAuthHeaderFunc func() (string, error) @@ -397,7 +388,7 @@ func newGetAuthHeaderCached(getAuthHeader getAuthHeaderFunc) getAuthHeaderFunc { defer mu.Unlock() if fasttime.UnixTimestamp() > deadline { ah, err = getAuthHeader() - deadline = fasttime.UnixTimestamp() + tlsCertsCacheSeconds + deadline = fasttime.UnixTimestamp() + 1 } return ah, err } @@ -405,24 +396,22 @@ func newGetAuthHeaderCached(getAuthHeader getAuthHeaderFunc) getAuthHeaderFunc { type getTLSRootCAFunc func() (*x509.CertPool, error) -func newGetTLSRootCACached(getTLSRootCA getTLSRootCAFunc) getTLSRootCAFunc { - if getTLSRootCA == nil { - return nil - } +type getTLSConfigFunc func() (*tls.Config, error) + +func newGetTLSConfigCached(getTLSConfig getTLSConfigFunc) getTLSConfigFunc { var mu sync.Mutex var deadline uint64 - var rootCA *x509.CertPool + var tlsCfg *tls.Config var err error - return func() (*x509.CertPool, error) { - // Cache the root CA and the error for up to a second in order to save CPU time - // on reading and parsing the root CA from files. + return func() (*tls.Config, error) { + // Cache the tlsCfg and the error for up to a second in order to save CPU time on getTLSConfig() call. mu.Lock() defer mu.Unlock() if fasttime.UnixTimestamp() > deadline { - rootCA, err = getTLSRootCA() - deadline = fasttime.UnixTimestamp() + tlsCertsCacheSeconds + tlsCfg, err = getTLSConfig() + deadline = fasttime.UnixTimestamp() + 1 } - return rootCA, err + return tlsCfg, err } } @@ -443,52 +432,103 @@ func newGetTLSCertCached(getTLSCert getTLSCertFunc) getTLSCertFunc { defer mu.Unlock() if fasttime.UnixTimestamp() > deadline { cert, err = getTLSCert(cri) - deadline = fasttime.UnixTimestamp() + tlsCertsCacheSeconds + deadline = fasttime.UnixTimestamp() + 1 } return cert, err } } -// NewTLSConfig returns new TLS config for the given ac. -func (ac *Config) NewTLSConfig() (*tls.Config, error) { +// NewRoundTripper returns new http.RoundTripper for the given ac, which uses the given trBase as base transport. +// +// The caller shouldn't change the trBase, since the returned RoundTripper owns it. +func (ac *Config) NewRoundTripper(trBase *http.Transport) http.RoundTripper { + rt := &roundTripper{ + trBase: trBase, + } + if ac != nil { + rt.getTLSConfigCached = ac.getTLSConfigCached + } + return rt +} + +type roundTripper struct { + trBase *http.Transport + getTLSConfigCached getTLSConfigFunc + + rootCAPrev *x509.CertPool + trPrev *http.Transport + mu sync.Mutex +} + +// RoundTrip implements http.RoundTripper interface. +func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + tr, err := rt.getTransport() + if err != nil { + return nil, fmt.Errorf("cannot initialize Transport: %w", err) + } + return tr.RoundTrip(req) +} + +func (rt *roundTripper) getTransport() (*http.Transport, error) { + if rt.getTLSConfigCached == nil { + return rt.trBase, nil + } + + tlsCfg, err := rt.getTLSConfigCached() + if err != nil { + return nil, fmt.Errorf("cannot initialize TLS config: %w", err) + } + + rt.mu.Lock() + defer rt.mu.Unlock() + + if rt.trPrev != nil && tlsCfg.RootCAs.Equal(rt.rootCAPrev) { + // Fast path - tlsCfg wasn't changed. Return the previously created transport. + return rt.trPrev, nil + } + + // Slow path - tlsCfg has been changed. + // Close connections for the previous transport and create new transport for the updated tlsCfg. + if rt.trPrev != nil { + rt.trPrev.CloseIdleConnections() + } + + tr := rt.trBase.Clone() + tr.TLSClientConfig = tlsCfg + rt.trPrev = tr + rt.rootCAPrev = tlsCfg.RootCAs + + return rt.trPrev, nil +} + +func (ac *Config) getTLSConfig() (*tls.Config, error) { + if ac.getTLSCertCached == nil && ac.tlsServerName == "" && !ac.tlsInsecureSkipVerify && ac.tlsMinVersion == 0 && ac.getTLSRootCA == nil { + // Re-use zeroTLSConfig when ac doesn't contain tls-specific configs. + // This should reduce memory usage a bit. + return zeroTLSConfig, nil + } + tlsCfg := &tls.Config{ - ClientSessionCache: tls.NewLRUClientSessionCache(0), + ClientSessionCache: tls.NewLRUClientSessionCache(0), + GetClientCertificate: ac.getTLSCertCached, + ServerName: ac.tlsServerName, + InsecureSkipVerify: ac.tlsInsecureSkipVerify, + MinVersion: ac.tlsMinVersion, + // Do not set MaxVersion, since this has no sense from security PoV. + // This can only result in lower security level if improperly configured. } - if ac == nil { - return tlsCfg, nil - } - tlsCfg.GetClientCertificate = ac.getTLSCertCached - if f := ac.getTLSRootCACached; f != nil { + if f := ac.getTLSRootCA; f != nil { rootCA, err := f() if err != nil { return nil, fmt.Errorf("cannot load root CAs: %w", err) } tlsCfg.RootCAs = rootCA } - tlsCfg.ServerName = ac.tlsServerName - tlsCfg.InsecureSkipVerify = ac.tlsInsecureSkipVerify - tlsCfg.MinVersion = ac.tlsMinVersion - // Do not set tlsCfg.MaxVersion, since this has no sense from security PoV. - // This can only result in lower security level if improperly set. return tlsCfg, nil } -// NewRoundTripper returns new http.RoundTripper for the given ac. -func (ac *Config) NewRoundTripper(builder func(*http.Transport)) (http.RoundTripper, error) { - cfg, err := ac.NewTLSConfig() - if err != nil { - return nil, fmt.Errorf("failed to initialize TLS config: %w", err) - } - - if ac == nil { - tr := &http.Transport{ - TLSClientConfig: cfg, - } - builder(tr) - return tr, nil - } - - return ac.tctx.NewTLSRoundTripper(cfg, builder) +var zeroTLSConfig = &tls.Config{ + ClientSessionCache: tls.NewLRUClientSessionCache(0), } // NewConfig creates auth config for the given hcc. @@ -627,8 +667,8 @@ func (opts *Options) NewConfig() (*Config, error) { tlsInsecureSkipVerify: tctx.insecureSkipVerify, tlsMinVersion: tctx.minVersion, - getTLSRootCACached: newGetTLSRootCACached(tctx.getTLSRootCA), - tlsRootCADigest: tctx.tlsRootCADigest, + getTLSRootCA: tctx.getTLSRootCA, + tlsRootCADigest: tctx.tlsRootCADigest, getTLSCertCached: newGetTLSCertCached(tctx.getTLSCert), tlsCertDigest: tctx.tlsCertDigest, @@ -638,9 +678,8 @@ func (opts *Options) NewConfig() (*Config, error) { headers: headers, headersDigest: headersDigest, - - tctx: &tctx, } + ac.getTLSConfigCached = newGetTLSConfigCached(ac.getTLSConfig) return ac, nil } @@ -778,7 +817,6 @@ type tlsContext struct { tlsCertDigest string getTLSRootCA getTLSRootCAFunc - getRootCAPEM func() ([]byte, error) tlsRootCADigest string serverName string @@ -802,26 +840,16 @@ func (tctx *tlsContext) initFromTLSConfig(baseDir string, tc *TLSConfig) error { } else if tc.CertFile != "" || tc.KeyFile != "" { certPath := fscore.GetFilepath(baseDir, tc.CertFile) keyPath := fscore.GetFilepath(baseDir, tc.KeyFile) - getCertsPEM := func() ([]byte, []byte, error) { + tctx.getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { // Re-read TLS certificate from disk. This is needed for https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1420 certData, err := fscore.ReadFileOrHTTP(certPath) if err != nil { - return nil, nil, fmt.Errorf("cannot read TLS certificate from %q: %w", certPath, err) + return nil, fmt.Errorf("cannot read TLS certificate from %q: %w", certPath, err) } keyData, err := fscore.ReadFileOrHTTP(keyPath) if err != nil { - return nil, nil, fmt.Errorf("cannot read TLS key from %q: %w", keyPath, err) + return nil, fmt.Errorf("cannot read TLS key from %q: %w", keyPath, err) } - - return certData, keyData, nil - } - - tctx.getTLSCert = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - certData, keyData, err := getCertsPEM() - if err != nil { - return nil, err - } - cert, err := tls.X509KeyPair(certData, keyData) if err != nil { return nil, fmt.Errorf("cannot load TLS certificate from `cert_file`=%q, `key_file`=%q: %w", tc.CertFile, tc.KeyFile, err) @@ -835,48 +863,24 @@ func (tctx *tlsContext) initFromTLSConfig(baseDir string, tc *TLSConfig) error { if !rootCA.AppendCertsFromPEM([]byte(tc.CA)) { return fmt.Errorf("cannot parse data from `ca` value") } - tctx.getTLSRootCA = func() (*x509.CertPool, error) { return rootCA, nil } - - tctx.getRootCAPEM = func() ([]byte, error) { - return []byte(tc.CA), nil - } - h := xxhash.Sum64([]byte(tc.CA)) tctx.tlsRootCADigest = fmt.Sprintf("digest(CA)=%d", h) } else if tc.CAFile != "" { path := fscore.GetFilepath(baseDir, tc.CAFile) - - getRootCAPEM := func() ([]byte, error) { + tctx.getTLSRootCA = func() (*x509.CertPool, error) { data, err := fscore.ReadFileOrHTTP(path) if err != nil { return nil, fmt.Errorf("cannot read `ca_file`: %w", err) } - return data, nil - } - - tctx.getTLSRootCA = func() (*x509.CertPool, error) { - data, err := getRootCAPEM() - if err != nil { - return nil, err - } rootCA := x509.NewCertPool() if !rootCA.AppendCertsFromPEM(data) { return nil, fmt.Errorf("cannot parse data read from `ca_file` %q", tc.CAFile) } return rootCA, nil } - tctx.getRootCAPEM = func() ([]byte, error) { - data, err := getRootCAPEM() - if err != nil { - return nil, err - } - return data, nil - } - // Does not hash file contents, since they may change at any time. - // TLSRoundTripper must be used to automatically update the root CA. tctx.tlsRootCADigest = fmt.Sprintf("caFile=%q", tc.CAFile) } v, err := netutil.ParseTLSVersion(tc.MinVersion) @@ -886,118 +890,3 @@ func (tctx *tlsContext) initFromTLSConfig(baseDir string, tc *TLSConfig) error { tctx.minVersion = v return nil } - -// NewTLSRoundTripper returns new http.RoundTripper which automatically updates -// RootCA in tls.Config whenever it changes. -func (tctx *tlsContext) NewTLSRoundTripper(cfg *tls.Config, builder func(transport *http.Transport)) (http.RoundTripper, error) { - // TLS context is not initialized so use the provided RoundTripper without wrapper - if tctx == nil || tctx.getRootCAPEM == nil { - tr := &http.Transport{ - TLSClientConfig: cfg, - } - builder(tr) - return tr, nil - } - - var deadline uint64 - var rootCA []byte - var err error - getTLSDigestsLocked := func() []byte { - if fasttime.UnixTimestamp() > deadline { - rootCA, err = tctx.getRootCAPEM() - if err != nil { - logger.Warnf("cannot load root CA: %s", err) - } - deadline = fasttime.UnixTimestamp() + tlsCertsCacheSeconds - } - - return rootCA - } - tr := &http.Transport{ - TLSClientConfig: cfg, - } - builder(tr) - - return &TLSRoundTripper{ - builder: builder, - getRootCABytesLocked: getTLSDigestsLocked, - getTLSRootCA: newGetTLSRootCACached(tctx.getTLSRootCA), - config: cfg, - tr: tr, - }, nil -} - -var _ http.RoundTripper = &TLSRoundTripper{} - -// TLSRoundTripper is an implementation of http.RoundTripper which automatically -// updates RootCA in tls.Config whenever it changes. -type TLSRoundTripper struct { - builder func(*http.Transport) - getRootCABytesLocked func() []byte - getTLSRootCA func() (*x509.CertPool, error) - - config *tls.Config - - tr http.RoundTripper - - m sync.Mutex - rootCABytes []byte -} - -// RoundTrip implements http.RoundTripper.RoundTrip. -// It automatically updates RootCA in tls.Config whenever it changes. -func (t *TLSRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - t.m.Lock() - rootCABytes := t.getRootCABytesLocked() - equal := bytes.Equal(rootCABytes, t.rootCABytes) - - if equal { - t.m.Unlock() - return t.tr.RoundTrip(request) - } - - err := t.recreateRoundTripperLocked(rootCABytes) - t.m.Unlock() - - if err != nil { - return nil, fmt.Errorf("cannot recreate RoundTripper: %w", err) - } - - return t.tr.RoundTrip(request) -} - -func (t *TLSRoundTripper) recreateRoundTripperLocked(rootCABytes []byte) error { - newRootCaPool, err := t.getTLSRootCA() - if err != nil { - return fmt.Errorf("cannot load root CAs: %w", err) - } - newTLSConfig := t.config.Clone() - newTLSConfig.RootCAs = newRootCaPool - // Reset ClientSessionCache, since it may contain sessions for the old root CA. - newTLSConfig.ClientSessionCache = tls.NewLRUClientSessionCache(0) - - tr := &http.Transport{ - TLSClientConfig: newTLSConfig, - } - t.builder(tr) - - oldRt := t.tr - t.tr = tr - t.rootCABytes = rootCABytes - t.config = newTLSConfig - - if oldRt != nil { - closeIdleConnections(oldRt) - } - return nil -} - -type closeIdler interface { - CloseIdleConnections() -} - -func closeIdleConnections(t http.RoundTripper) { - if ci, ok := t.(closeIdler); ok { - ci.CloseIdleConnections() - } -} diff --git a/lib/promauth/config_test.go b/lib/promauth/config_test.go index 58ce0882f..cdeb41e00 100644 --- a/lib/promauth/config_test.go +++ b/lib/promauth/config_test.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "math/big" "net" "net/http" @@ -607,35 +608,30 @@ func TestConfigHeaders(t *testing.T) { func TestTLSConfigWithCertificatesFilesUpdate(t *testing.T) { // Generate and save a self-signed CA certificate and a certificate signed by the CA - caPEM, certPEM, keyPEM := generateCertificates(t) + caPEM, certPEM, keyPEM := mustGenerateCertificates() _ = os.WriteFile("testdata/ca.pem", caPEM, 0644) - _ = os.WriteFile("testdata/cert.pem", certPEM, 0644) - _ = os.WriteFile("testdata/key.pem", keyPEM, 0644) defer func() { - for _, p := range []string{ - "testdata/ca.pem", - "testdata/cert.pem", - "testdata/key.pem", - } { - _ = os.Remove(p) - } + _ = os.Remove("testdata/ca.pem") }() - cert, err := tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem") + cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { t.Fatalf("cannot load generated certificate: %s", err) } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } - tlsConfig := &tls.Config{} - tlsConfig.Certificates = []tls.Certificate{cert} - - s := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) s.TLS = tlsConfig s.StartTLS() - serverURL, _ := url.Parse(s.URL) + serverURL, err := url.Parse(s.URL) + if err != nil { + t.Fatalf("unexpected error when parsing url=%q: %s", s.URL, err) + } opts := Options{ TLSConfig: &TLSConfig{ @@ -646,13 +642,9 @@ func TestTLSConfigWithCertificatesFilesUpdate(t *testing.T) { if err != nil { t.Fatalf("unexpected error when parsing config: %s", err) } - tr, err := ac.NewRoundTripper(func(tr *http.Transport) {}) - if err != nil { - t.Fatalf("unexpected error when creating roundtripper: %s", err) - } client := http.Client{ - Transport: tr, + Transport: ac.NewRoundTripper(&http.Transport{}), } resp, err := client.Do(&http.Request{ @@ -662,17 +654,16 @@ func TestTLSConfigWithCertificatesFilesUpdate(t *testing.T) { if err != nil { t.Fatalf("unexpected error when making request: %s", err) } - if resp.StatusCode != http.StatusOK { t.Fatalf("expected status code %d; got %d", http.StatusOK, resp.StatusCode) } // Update CA file with new CA and get config - ca2PEM, _, _ := generateCertificates(t) + ca2PEM, _, _ := mustGenerateCertificates() _ = os.WriteFile("testdata/ca.pem", ca2PEM, 0644) // Wait for cert cache expiration - time.Sleep(2 * tlsCertsCacheSeconds * time.Second) + time.Sleep(2 * time.Second) _, err = client.Do(&http.Request{ Method: http.MethodGet, @@ -683,7 +674,7 @@ func TestTLSConfigWithCertificatesFilesUpdate(t *testing.T) { } } -func generateCertificates(t *testing.T) ([]byte, []byte, []byte) { +func mustGenerateCertificates() ([]byte, []byte, []byte) { // Small key size for faster tests const testCertificateBits = 1024 @@ -701,11 +692,11 @@ func generateCertificates(t *testing.T) ([]byte, []byte, []byte) { } caPrivKey, err := rsa.GenerateKey(rand.Reader, testCertificateBits) if err != nil { - t.Fatal(err) + panic(fmt.Errorf("cannot generate CA private key: %s", err)) } caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) if err != nil { - t.Fatal(err) + panic(fmt.Errorf("cannot create CA certificate: %s", err)) } caPEM := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", @@ -727,11 +718,11 @@ func generateCertificates(t *testing.T) ([]byte, []byte, []byte) { } key, err := rsa.GenerateKey(rand.Reader, testCertificateBits) if err != nil { - t.Fatal(err) + panic(fmt.Errorf("cannot generate certificate private key: %s", err)) } certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &key.PublicKey, caPrivKey) if err != nil { - t.Fatal(err) + panic(fmt.Errorf("cannot generate certificate: %s", err)) } certPEM := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", diff --git a/lib/promscrape/client.go b/lib/promscrape/client.go index 25d23e41d..e5f7b4a5e 100644 --- a/lib/promscrape/client.go +++ b/lib/promscrape/client.go @@ -2,7 +2,6 @@ package promscrape import ( "context" - "crypto/tls" "flag" "fmt" "io" @@ -43,23 +42,18 @@ type client struct { } func newClient(ctx context.Context, sw *ScrapeWork) (*client, error) { - isTLS := strings.HasPrefix(sw.ScrapeURL, "https://") + ac := sw.AuthConfig setHeaders := func(req *http.Request) error { return sw.AuthConfig.SetHeaders(req, true) } setProxyHeaders := func(_ *http.Request) error { return nil } - var tlsCfg *tls.Config proxyURL := sw.ProxyURL - if !isTLS && proxyURL.IsHTTPOrHTTPS() { + if !strings.HasPrefix(sw.ScrapeURL, "https://") && proxyURL.IsHTTPOrHTTPS() { pu := proxyURL.GetURL() if pu.Scheme == "https" { - var err error - tlsCfg, err = sw.ProxyAuthConfig.NewTLSConfig() - if err != nil { - return nil, fmt.Errorf("cannot initialize proxy tls config: %w", err) - } + ac = sw.ProxyAuthConfig } setProxyHeaders = func(req *http.Request) error { return proxyURL.SetHeaders(sw.ProxyAuthConfig, req) @@ -69,30 +63,19 @@ func newClient(ctx context.Context, sw *ScrapeWork) (*client, error) { if pu := sw.ProxyURL.GetURL(); pu != nil { proxyURLFunc = http.ProxyURL(pu) } - - rt, err := sw.AuthConfig.NewRoundTripper(func(tr *http.Transport) { - if !isTLS && proxyURL.IsHTTPOrHTTPS() { - tr.TLSClientConfig = tlsCfg - } - - tr.Proxy = proxyURLFunc - tr.TLSHandshakeTimeout = 10 * time.Second - tr.IdleConnTimeout = 2 * sw.ScrapeInterval - tr.DisableCompression = *disableCompression || sw.DisableCompression - tr.DisableKeepAlives = *disableKeepAlive || sw.DisableKeepAlive - tr.DialContext = statStdDial - tr.MaxIdleConnsPerHost = 100 - tr.MaxResponseHeaderBytes = int64(maxResponseHeadersSize.N) - }) - if err != nil { - return nil, fmt.Errorf("cannot initialize tls config: %w", err) - } - hc := &http.Client{ - Transport: rt, - Timeout: sw.ScrapeTimeout, + Transport: ac.NewRoundTripper(&http.Transport{ + Proxy: proxyURLFunc, + TLSHandshakeTimeout: 10 * time.Second, + IdleConnTimeout: 2 * sw.ScrapeInterval, + DisableCompression: *disableCompression || sw.DisableCompression, + DisableKeepAlives: *disableKeepAlive || sw.DisableKeepAlive, + DialContext: statStdDial, + MaxIdleConnsPerHost: 100, + MaxResponseHeaderBytes: int64(maxResponseHeadersSize.N), + }), + Timeout: sw.ScrapeTimeout, } - if sw.DenyRedirects { hc.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse diff --git a/lib/promscrape/config.go b/lib/promscrape/config.go index dc4e6169d..9c5b106f4 100644 --- a/lib/promscrape/config.go +++ b/lib/promscrape/config.go @@ -211,19 +211,7 @@ func areEqualGlobalConfigs(a, b *GlobalConfig) bool { func areEqualScrapeConfigs(a, b *ScrapeConfig) bool { sa := a.marshalJSON() sb := b.marshalJSON() - if string(sa) != string(sb) { - return false - } - // Compare auth configs for a and b, since they may differ by TLS CA file contents, - // which is missing in the marshaled JSON of a and b, - // but it existis in the string representation of auth configs. - if a.swc.authConfig.String() != b.swc.authConfig.String() { - return false - } - if a.swc.proxyAuthConfig.String() != b.swc.proxyAuthConfig.String() { - return false - } - return true + return string(sa) == string(sb) } func (sc *ScrapeConfig) unmarshalJSON(data []byte) error { diff --git a/lib/promscrape/discovery/kubernetes/api_watcher.go b/lib/promscrape/discovery/kubernetes/api_watcher.go index fb6a98957..ba92bdab0 100644 --- a/lib/promscrape/discovery/kubernetes/api_watcher.go +++ b/lib/promscrape/discovery/kubernetes/api_watcher.go @@ -88,10 +88,7 @@ func newAPIWatcher(apiServer string, ac *promauth.Config, sdc *SDConfig, swcFunc attachNodeMetadata = sdc.AttachMetadata.Node } proxyURL := sdc.ProxyURL.GetURL() - gw, err := getGroupWatcher(apiServer, ac, namespaces, selectors, attachNodeMetadata, proxyURL) - if err != nil { - return nil, err - } + gw := getGroupWatcher(apiServer, ac, namespaces, selectors, attachNodeMetadata, proxyURL) role := sdc.role() aw := &apiWatcher{ role: role, @@ -246,25 +243,19 @@ type groupWatcher struct { noAPIWatchers bool } -func newGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, selectors []Selector, attachNodeMetadata bool, proxyURL *url.URL) (*groupWatcher, error) { +func newGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, selectors []Selector, attachNodeMetadata bool, proxyURL *url.URL) *groupWatcher { var proxy func(*http.Request) (*url.URL, error) if proxyURL != nil { proxy = http.ProxyURL(proxyURL) } - tr, err := ac.NewRoundTripper(func(tr *http.Transport) { - tr.Proxy = proxy - tr.TLSHandshakeTimeout = 10 * time.Second - tr.IdleConnTimeout = *apiServerTimeout - tr.MaxIdleConnsPerHost = 100 - }) - - if err != nil { - return nil, fmt.Errorf("cannot initialize tls config: %w", err) - } - client := &http.Client{ - Transport: tr, - Timeout: *apiServerTimeout, + Transport: ac.NewRoundTripper(&http.Transport{ + Proxy: proxy, + TLSHandshakeTimeout: 10 * time.Second, + IdleConnTimeout: *apiServerTimeout, + MaxIdleConnsPerHost: 100, + }), + Timeout: *apiServerTimeout, } ctx, cancel := context.WithCancel(context.Background()) gw := &groupWatcher{ @@ -282,10 +273,10 @@ func newGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, ctx: ctx, cancel: cancel, } - return gw, nil + return gw } -func getGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, selectors []Selector, attachNodeMetadata bool, proxyURL *url.URL) (*groupWatcher, error) { +func getGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, selectors []Selector, attachNodeMetadata bool, proxyURL *url.URL) *groupWatcher { proxyURLStr := "" if proxyURL != nil { proxyURLStr = proxyURL.String() @@ -294,17 +285,12 @@ func getGroupWatcher(apiServer string, ac *promauth.Config, namespaces []string, apiServer, namespaces, selectorsKey(selectors), attachNodeMetadata, proxyURLStr, ac.String()) groupWatchersLock.Lock() gw := groupWatchers[key] - var err error if gw == nil { - gw, err = newGroupWatcher(apiServer, ac, namespaces, selectors, attachNodeMetadata, proxyURL) - if err != nil { - err = fmt.Errorf("cannot initialize watcher for key={%s}: %w", key, err) - } else { - groupWatchers[key] = gw - } + gw = newGroupWatcher(apiServer, ac, namespaces, selectors, attachNodeMetadata, proxyURL) + groupWatchers[key] = gw } groupWatchersLock.Unlock() - return gw, err + return gw } func selectorsKey(selectors []Selector) string { diff --git a/lib/promscrape/discovery/openstack/api.go b/lib/promscrape/discovery/openstack/api.go index 7d50e79b9..14b6bb6a4 100644 --- a/lib/promscrape/discovery/openstack/api.go +++ b/lib/promscrape/discovery/openstack/api.go @@ -94,14 +94,9 @@ func newAPIConfig(sdc *SDConfig, baseDir string) (*apiConfig, error) { cfg.client.CloseIdleConnections() return nil, fmt.Errorf("cannot parse TLS config: %w", err) } - tr, err := ac.NewRoundTripper(func(tr *http.Transport) { - tr.MaxIdleConnsPerHost = 100 + cfg.client.Transport = ac.NewRoundTripper(&http.Transport{ + MaxIdleConnsPerHost: 100, }) - if err != nil { - cfg.client.CloseIdleConnections() - return nil, fmt.Errorf("cannot initialize TLS config: %w", err) - } - cfg.client.Transport = tr } // use public compute endpoint by default if len(cfg.availability) == 0 { diff --git a/lib/promscrape/discovery/yandexcloud/api.go b/lib/promscrape/discovery/yandexcloud/api.go index 3d6c88969..149aaa40c 100644 --- a/lib/promscrape/discovery/yandexcloud/api.go +++ b/lib/promscrape/discovery/yandexcloud/api.go @@ -47,9 +47,10 @@ func getAPIConfig(sdc *SDConfig, baseDir string) (*apiConfig, error) { } func newAPIConfig(sdc *SDConfig, baseDir string) (*apiConfig, error) { - var transport http.RoundTripper = &http.Transport{ + tr := &http.Transport{ MaxIdleConnsPerHost: 100, } + rt := http.RoundTripper(tr) if sdc.TLSConfig != nil { opts := &promauth.Options{ BaseDir: baseDir, @@ -59,16 +60,11 @@ func newAPIConfig(sdc *SDConfig, baseDir string) (*apiConfig, error) { if err != nil { return nil, fmt.Errorf("cannot parse TLS config: %w", err) } - transport, err = ac.NewRoundTripper(func(tr *http.Transport) { - tr.MaxIdleConnsPerHost = 100 - }) - if err != nil { - return nil, fmt.Errorf("cannot initialize TLS config: %w", err) - } + rt = ac.NewRoundTripper(tr) } cfg := &apiConfig{ client: &http.Client{ - Transport: transport, + Transport: rt, }, } apiEndpoint := sdc.APIEndpoint diff --git a/lib/promscrape/discoveryutils/client.go b/lib/promscrape/discoveryutils/client.go index c20b8a47c..1ee73ca00 100644 --- a/lib/promscrape/discoveryutils/client.go +++ b/lib/promscrape/discoveryutils/client.go @@ -111,35 +111,25 @@ func NewClient(apiServer string, ac *promauth.Config, proxyURL *proxy.URL, proxy proxyURLFunc = http.ProxyURL(pu) } - tr, err := ac.NewRoundTripper(func(tr *http.Transport) { - tr.Proxy = proxyURLFunc - tr.TLSHandshakeTimeout = 10 * time.Second - tr.MaxIdleConnsPerHost = *maxConcurrency - tr.ResponseHeaderTimeout = DefaultClientReadTimeout - tr.DialContext = dialFunc - }) - if err != nil { - return nil, fmt.Errorf("cannot initialize tls config: %w", err) - } - - blockingTR, err := ac.NewRoundTripper(func(tr *http.Transport) { - tr.Proxy = proxyURLFunc - tr.TLSHandshakeTimeout = 10 * time.Second - tr.MaxIdleConnsPerHost = 1000 - tr.ResponseHeaderTimeout = BlockingClientReadTimeout - tr.DialContext = dialFunc - }) - if err != nil { - return nil, fmt.Errorf("cannot initialize tls config: %w", err) - } - client := &http.Client{ - Timeout: DefaultClientReadTimeout, - Transport: tr, + Timeout: DefaultClientReadTimeout, + Transport: ac.NewRoundTripper(&http.Transport{ + Proxy: proxyURLFunc, + TLSHandshakeTimeout: 10 * time.Second, + MaxIdleConnsPerHost: *maxConcurrency, + ResponseHeaderTimeout: DefaultClientReadTimeout, + DialContext: dialFunc, + }), } blockingClient := &http.Client{ - Timeout: BlockingClientReadTimeout, - Transport: blockingTR, + Timeout: BlockingClientReadTimeout, + Transport: ac.NewRoundTripper(&http.Transport{ + Proxy: proxyURLFunc, + TLSHandshakeTimeout: 10 * time.Second, + MaxIdleConnsPerHost: 1000, + ResponseHeaderTimeout: BlockingClientReadTimeout, + DialContext: dialFunc, + }), } setHTTPHeaders := func(_ *http.Request) error { return nil }