mirror of
https://github.com/VictoriaMetrics/VictoriaMetrics.git
synced 2024-11-21 14:44:00 +00:00
add2db12b2
The Request.Body for requests without body can be nil. This could break readTrackingBody.Read() logic, which could incorrectly return "cannot read data after closing the reader" error in this case. Fix this by initializing the readTrackingBody.r with zeroReader. While at it, properly set Host header if it is specified in 'headers' section. It must be set net/http.Request.Host instead of net/http.Request.Header.Set(), since the net/http.Client overwrites the Host header with the value from req.Host before sending the request. While at it, add tests for requestHandler(). Additional tests for various requestHandler() cases will be added in future commits. Updates https://github.com/VictoriaMetrics/VictoriaMetrics/issues/6445 Updates https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5707 Updates https://github.com/VictoriaMetrics/VictoriaMetrics/issues/5240 Updates https://github.com/VictoriaMetrics/VictoriaMetrics/pull/6525
1049 lines
31 KiB
Go
1049 lines
31 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"flag"
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/VictoriaMetrics/metrics"
|
|
"github.com/cespare/xxhash/v2"
|
|
"gopkg.in/yaml.v2"
|
|
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/envtemplate"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/fs/fscore"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
|
|
"github.com/VictoriaMetrics/VictoriaMetrics/lib/procutil"
|
|
)
|
|
|
|
// Command-line flags controlling how the auth config is loaded and how backends are selected.
var (
	// authConfigPath is the value of -auth.config: a local file path or an http url.
	authConfigPath = flag.String("auth.config", "", "Path to auth config. It can point either to local file or to http url. "+
		"See https://docs.victoriametrics.com/vmauth/ for details on the format of this auth config")
	// configCheckInterval is the period of periodic -auth.config re-reading; zero disables it (see authConfigReloader).
	configCheckInterval = flag.Duration("configCheckInterval", 0, "interval for config file re-read. "+
		"Zero value disables config re-reading. By default, refreshing is disabled, send SIGHUP for config refresh.")
	// defaultRetryStatusCodes is used by initURLs unless retry_status_codes is set in the config.
	defaultRetryStatusCodes = flagutil.NewArrayInt("retryStatusCodes", 0, "Comma-separated list of default HTTP response status codes when vmauth re-tries the request on other backends. "+
		"See https://docs.victoriametrics.com/vmauth/#load-balancing for details")
	// defaultLoadBalancingPolicy is used by initURLs unless load_balancing_policy is set in the config.
	defaultLoadBalancingPolicy = flag.String("loadBalancingPolicy", "least_loaded", "The default load balancing policy to use for backend urls specified inside url_prefix section. "+
		"Supported policies: least_loaded, first_available. See https://docs.victoriametrics.com/vmauth/#load-balancing")
	// discoverBackendIPsGlobal is used by initURLs unless discover_backend_ips is set in the config.
	discoverBackendIPsGlobal = flag.Bool("discoverBackendIPs", false, "Whether to discover backend IPs via periodic DNS queries to hostnames specified in url_prefix. "+
		"This may be useful when url_prefix points to a hostname with dynamically scaled instances behind it. See https://docs.victoriametrics.com/vmauth/#discovering-backend-ips")
	// discoverBackendIPsInterval is rounded up to whole seconds (minimum 1s) in discoverBackendAddrsIfNeeded.
	discoverBackendIPsInterval = flag.Duration("discoverBackendIPsInterval", 10*time.Second, "The interval for re-discovering backend IPs if -discoverBackendIPs command-line flag is set. "+
		"Too low value may lead to DNS errors")
	// httpAuthHeader lists the request headers inspected by getAuthTokensFromRequest; Authorization is used if empty.
	httpAuthHeader = flagutil.NewArrayString("httpAuthHeader", "HTTP request header to use for obtaining authorization tokens. By default auth tokens are read from Authorization request header")
)
|
|
|
|
// AuthConfig represents auth config.
//
// It is unmarshaled from the -auth.config contents by parseAuthConfig.
type AuthConfig struct {
	// Users is the list of users identified by their credentials (see parseAuthConfigUsers).
	Users []UserInfo `yaml:"users,omitempty"`
	// UnauthorizedUser is the optional `unauthorized_user` section.
	// parseAuthConfig rejects it if it contains username, password, bearer_token, auth_token or name.
	UnauthorizedUser *UserInfo `yaml:"unauthorized_user,omitempty"`

	// ms holds all the metrics for the given AuthConfig
	//
	// reloadAuthConfigData registers it when the config is applied and unregisters
	// the previous config's set at the same time.
	ms *metrics.Set
}
|
|
|
|
// UserInfo is user information read from authConfigPath
//
// Exported fields are unmarshaled from an entry of the `users` list or from the
// `unauthorized_user` section. Unexported fields are initialized by parseAuthConfig
// and parseAuthConfigUsers.
type UserInfo struct {
	// Name is an optional user name; it takes precedence in name(), which is used for metric labels and errors.
	Name string `yaml:"name,omitempty"`

	// BearerToken is accepted both as `Bearer <token>` and as a Basic auth username
	// with empty password (see getAuthTokens).
	BearerToken string `yaml:"bearer_token,omitempty"`
	// AuthToken is compared against the whole auth header value (see getAuthTokens and getAuthTokensFromRequest).
	AuthToken string `yaml:"auth_token,omitempty"`
	// Username and Password are Basic auth credentials.
	Username string `yaml:"username,omitempty"`
	Password string `yaml:"password,omitempty"`

	// URLPrefix holds backend urls for this user; it is initialized in initURLs.
	URLPrefix *URLPrefix `yaml:"url_prefix,omitempty"`
	// DiscoverBackendIPs is the user-level override of -discoverBackendIPs (see initURLs).
	DiscoverBackendIPs *bool `yaml:"discover_backend_ips,omitempty"`
	// URLMaps holds routing rules matched by request path, host, query args and headers.
	URLMaps []URLMap `yaml:"url_map,omitempty"`
	// HeadersConf holds request/response header settings; its fields are inlined in yaml.
	HeadersConf HeadersConf `yaml:",inline"`
	// MaxConcurrentRequests is the concurrency limit for this user;
	// non-positive values fall back to *maxConcurrentPerUserRequests (see getMaxConcurrentRequests).
	MaxConcurrentRequests int `yaml:"max_concurrent_requests,omitempty"`
	// DefaultURL is sanitized in initURLs.
	// NOTE(review): presumably used when no url_map entry matches the request — confirm in the request handler.
	DefaultURL *URLPrefix `yaml:"default_url,omitempty"`
	// RetryStatusCodes is the user-level override of -retryStatusCodes (see initURLs).
	RetryStatusCodes []int `yaml:"retry_status_codes,omitempty"`
	// LoadBalancingPolicy is the user-level override of -loadBalancingPolicy (see initURLs).
	LoadBalancingPolicy string `yaml:"load_balancing_policy,omitempty"`
	// DropSrcPathPrefixParts is the user-level number of path prefix parts to drop (see initURLs).
	DropSrcPathPrefixParts *int `yaml:"drop_src_path_prefix_parts,omitempty"`
	// TLS* fields are passed to newRoundTripper when building rt.
	TLSCAFile             string `yaml:"tls_ca_file,omitempty"`
	TLSCertFile           string `yaml:"tls_cert_file,omitempty"`
	TLSKeyFile            string `yaml:"tls_key_file,omitempty"`
	TLSServerName         string `yaml:"tls_server_name,omitempty"`
	TLSInsecureSkipVerify *bool  `yaml:"tls_insecure_skip_verify,omitempty"`

	// MetricLabels are extra labels added to per-user metrics (see getMetricLabels).
	MetricLabels map[string]string `yaml:"metric_labels,omitempty"`

	// concurrencyLimitCh is a buffered channel with capacity getMaxConcurrentRequests();
	// beginConcurrencyLimit sends to it and endConcurrencyLimit receives from it.
	concurrencyLimitCh chan struct{}
	// concurrencyLimitReached is incremented when beginConcurrencyLimit rejects a request.
	concurrencyLimitReached *metrics.Counter

	// rt is the HTTP RoundTripper built from the TLS* settings above.
	rt http.RoundTripper

	// Per-user metrics created in parseAuthConfig / parseAuthConfigUsers.
	requests         *metrics.Counter
	backendErrors    *metrics.Counter
	requestsDuration *metrics.Summary
}
|
|
|
|
// HeadersConf represents config for request and response headers.
//
// It is inlined into both UserInfo and URLMap yaml sections.
type HeadersConf struct {
	// RequestHeaders are `Name: Value` headers which must be added to the proxied request (see Header).
	RequestHeaders []*Header `yaml:"headers,omitempty"`
	// ResponseHeaders are `Name: Value` headers.
	// NOTE(review): presumably applied to the response sent to the client — confirm in the request handler.
	ResponseHeaders []*Header `yaml:"response_headers,omitempty"`
	// KeepOriginalHost is an optional flag.
	// NOTE(review): presumably keeps the client's Host header when proxying — confirm in the request handler.
	KeepOriginalHost *bool `yaml:"keep_original_host,omitempty"`
}
|
|
|
|
func (ui *UserInfo) beginConcurrencyLimit() error {
|
|
select {
|
|
case ui.concurrencyLimitCh <- struct{}{}:
|
|
return nil
|
|
default:
|
|
ui.concurrencyLimitReached.Inc()
|
|
return fmt.Errorf("cannot handle more than %d concurrent requests from user %s", ui.getMaxConcurrentRequests(), ui.name())
|
|
}
|
|
}
|
|
|
|
func (ui *UserInfo) endConcurrencyLimit() {
|
|
<-ui.concurrencyLimitCh
|
|
}
|
|
|
|
func (ui *UserInfo) getMaxConcurrentRequests() int {
|
|
mcr := ui.MaxConcurrentRequests
|
|
if mcr <= 0 {
|
|
mcr = *maxConcurrentPerUserRequests
|
|
}
|
|
return mcr
|
|
}
|
|
|
|
// Header is `Name: Value` http header, which must be added to the proxied request.
//
// It is unmarshaled from a single YAML string in the `Name: Value` form.
type Header struct {
	// Name is the header name with surrounding whitespace trimmed. It is never empty after a successful UnmarshalYAML.
	Name string
	// Value is the header value with surrounding whitespace trimmed. It may be empty.
	Value string

	// sOriginal holds the original YAML string, so MarshalYAML can return it unchanged.
	sOriginal string
}

// UnmarshalYAML unmarshals h from f.
//
// The input must be a string in the `Name: Value` form. It is split at the first ':'
// char, so the value itself may contain ':' chars (for example, `Host: example.com:8080`).
// An error is returned if the ':' separator is missing or if the header name is empty,
// since an empty header name cannot be sent over HTTP.
func (h *Header) UnmarshalYAML(f func(any) error) error {
	var s string
	if err := f(&s); err != nil {
		return err
	}
	h.sOriginal = s

	name, value, ok := strings.Cut(s, ":")
	if !ok {
		return fmt.Errorf("missing separator char ':' between Name and Value in the header %q; expected format - 'Name: Value'", s)
	}
	name = strings.TrimSpace(name)
	if name == "" {
		return fmt.Errorf("missing Name in the header %q; expected format - 'Name: Value'", s)
	}
	h.Name = name
	h.Value = strings.TrimSpace(value)
	return nil
}

// MarshalYAML marshals h to yaml.
//
// It returns the original string h was unmarshaled from.
func (h *Header) MarshalYAML() (any, error) {
	return h.sOriginal, nil
}
|
|
|
|
// URLMap is a mapping from source paths to target urls.
//
// At least one of the src_* match conditions and url_prefix must be set (see initURLs).
type URLMap struct {
	// SrcPaths is an optional list of regular expressions, which must match the request path.
	SrcPaths []*Regex `yaml:"src_paths,omitempty"`

	// SrcHosts is an optional list of regular expressions, which must match the request hostname.
	SrcHosts []*Regex `yaml:"src_hosts,omitempty"`

	// SrcQueryArgs is an optional list of query args, which must match request URL query args.
	SrcQueryArgs []*QueryArg `yaml:"src_query_args,omitempty"`

	// SrcHeaders is an optional list of headers, which must match request headers.
	SrcHeaders []*Header `yaml:"src_headers,omitempty"`

	// URLPrefix contains backend url prefixes for the proxied request url.
	URLPrefix *URLPrefix `yaml:"url_prefix,omitempty"`

	// DiscoverBackendIPs instructs discovering URLPrefix backend IPs via DNS.
	DiscoverBackendIPs *bool `yaml:"discover_backend_ips,omitempty"`

	// HeadersConf is the config for augmenting request and response headers.
	HeadersConf HeadersConf `yaml:",inline"`

	// RetryStatusCodes is the list of response status codes used for retrying requests.
	RetryStatusCodes []int `yaml:"retry_status_codes,omitempty"`

	// LoadBalancingPolicy is load balancing policy among URLPrefix backends.
	LoadBalancingPolicy string `yaml:"load_balancing_policy,omitempty"`

	// DropSrcPathPrefixParts is the number of `/`-delimited request path prefix parts to drop before proxying the request to backend.
	DropSrcPathPrefixParts *int `yaml:"drop_src_path_prefix_parts,omitempty"`
}
|
|
|
|
// QueryArg represents HTTP query arg
|
|
type QueryArg struct {
|
|
Name string
|
|
Value *Regex
|
|
|
|
sOriginal string
|
|
}
|
|
|
|
// UnmarshalYAML unmarshals qa from yaml.
|
|
func (qa *QueryArg) UnmarshalYAML(f func(any) error) error {
|
|
var s string
|
|
if err := f(&s); err != nil {
|
|
return err
|
|
}
|
|
qa.sOriginal = s
|
|
|
|
n := strings.IndexByte(s, '=')
|
|
if n < 0 {
|
|
return nil
|
|
}
|
|
|
|
qa.Name = s[:n]
|
|
expr := s[n+1:]
|
|
if !strings.HasPrefix(expr, "~") {
|
|
expr = regexp.QuoteMeta(expr)
|
|
} else {
|
|
expr = expr[1:]
|
|
}
|
|
|
|
var re Regex
|
|
if err := yaml.Unmarshal([]byte(expr), &re); err != nil {
|
|
return fmt.Errorf("cannot unmarshal regex for %q query arg: %w", qa.Name, err)
|
|
}
|
|
qa.Value = &re
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML marshals qa to yaml.
|
|
func (qa *QueryArg) MarshalYAML() (any, error) {
|
|
return qa.sOriginal, nil
|
|
}
|
|
|
|
// URLPrefix represents passed `url_prefix`
//
// It is unmarshaled from either a single url string or a list of url strings,
// and is initialized by sanitizeAndInitialize plus the settings applied in initURLs.
type URLPrefix struct {
	// requests are re-tried on other backend urls for these http response status codes
	retryStatusCodes []int

	// load balancing policy used; one of "", "least_loaded" or "first_available" (see setLoadBalancingPolicy)
	loadBalancingPolicy string

	// how many request path prefix parts to drop before routing the request to backendURL
	dropSrcPathPrefixParts int

	// busOriginal contains the original list of backends specified in yaml config.
	busOriginal []*url.URL

	// n is an atomic counter, which is used for balancing load among available backends.
	n atomic.Uint32

	// the list of backend urls
	//
	// the list can be dynamically updated if `discover_backend_ips` option is set.
	bus atomic.Pointer[[]*backendURL]

	// if this option is set, then backend ips for busOriginal are periodically re-discovered and put to bus.
	discoverBackendIPs bool

	// The next deadline for DNS-based discovery of backend IPs, as a unix timestamp in seconds
	nextDiscoveryDeadline atomic.Uint64

	// vOriginal contains the original yaml value for URLPrefix; MarshalYAML returns it.
	vOriginal any
}
|
|
|
|
func (up *URLPrefix) setLoadBalancingPolicy(loadBalancingPolicy string) error {
|
|
switch loadBalancingPolicy {
|
|
case "", // empty string is equivalent to least_loaded
|
|
"least_loaded",
|
|
"first_available":
|
|
up.loadBalancingPolicy = loadBalancingPolicy
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unexpected load_balancing_policy: %q; want least_loaded or first_available", loadBalancingPolicy)
|
|
}
|
|
}
|
|
|
|
type backendURL struct {
|
|
brokenDeadline atomic.Uint64
|
|
concurrentRequests atomic.Int32
|
|
|
|
url *url.URL
|
|
}
|
|
|
|
func (bu *backendURL) isBroken() bool {
|
|
ct := fasttime.UnixTimestamp()
|
|
return ct < bu.brokenDeadline.Load()
|
|
}
|
|
|
|
func (bu *backendURL) setBroken() {
|
|
deadline := fasttime.UnixTimestamp() + uint64((*failTimeout).Seconds())
|
|
bu.brokenDeadline.Store(deadline)
|
|
}
|
|
|
|
func (bu *backendURL) get() {
|
|
bu.concurrentRequests.Add(1)
|
|
}
|
|
|
|
func (bu *backendURL) put() {
|
|
bu.concurrentRequests.Add(-1)
|
|
}
|
|
|
|
func (up *URLPrefix) getBackendsCount() int {
|
|
pbus := up.bus.Load()
|
|
return len(*pbus)
|
|
}
|
|
|
|
// getBackendURL returns the backendURL depending on the load balance policy.
|
|
//
|
|
// It can return nil if there are no backend urls available at the moment.
|
|
//
|
|
// backendURL.put() must be called on the returned backendURL after the request is complete.
|
|
func (up *URLPrefix) getBackendURL() *backendURL {
|
|
up.discoverBackendAddrsIfNeeded()
|
|
|
|
pbus := up.bus.Load()
|
|
bus := *pbus
|
|
if len(bus) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if up.loadBalancingPolicy == "first_available" {
|
|
return getFirstAvailableBackendURL(bus)
|
|
}
|
|
return getLeastLoadedBackendURL(bus, &up.n)
|
|
}
|
|
|
|
func (up *URLPrefix) discoverBackendAddrsIfNeeded() {
|
|
if !up.discoverBackendIPs {
|
|
// The discovery is disabled.
|
|
return
|
|
}
|
|
|
|
ct := fasttime.UnixTimestamp()
|
|
deadline := up.nextDiscoveryDeadline.Load()
|
|
if ct < deadline {
|
|
// There is no need in discovering backends.
|
|
return
|
|
}
|
|
|
|
intervalSec := math.Ceil(discoverBackendIPsInterval.Seconds())
|
|
if intervalSec <= 0 {
|
|
intervalSec = 1
|
|
}
|
|
nextDeadline := ct + uint64(intervalSec)
|
|
if !up.nextDiscoveryDeadline.CompareAndSwap(deadline, nextDeadline) {
|
|
// Concurrent goroutine already started the discovery.
|
|
return
|
|
}
|
|
|
|
// Discover ips for all the backendURLs
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(intervalSec))
|
|
hostToAddrs := make(map[string][]string)
|
|
for _, bu := range up.busOriginal {
|
|
host := bu.Hostname()
|
|
if hostToAddrs[host] != nil {
|
|
// ips for the given host have been already discovered
|
|
continue
|
|
}
|
|
|
|
var resolvedAddrs []string
|
|
if strings.HasPrefix(host, "srv+") {
|
|
// The host has the format 'srv+realhost'. Strip 'srv+' prefix before performing the lookup.
|
|
srvHost := strings.TrimPrefix(host, "srv+")
|
|
_, addrs, err := netutil.Resolver.LookupSRV(ctx, "", "", srvHost)
|
|
if err != nil {
|
|
logger.Warnf("cannot discover backend SRV records for %s: %s; use it literally", bu, err)
|
|
resolvedAddrs = []string{host}
|
|
} else {
|
|
resolvedAddrs = make([]string, len(addrs))
|
|
for i, addr := range addrs {
|
|
resolvedAddrs[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
|
|
}
|
|
}
|
|
} else {
|
|
addrs, err := netutil.Resolver.LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
logger.Warnf("cannot discover backend IPs for %s: %s; use it literally", bu, err)
|
|
resolvedAddrs = []string{host}
|
|
} else {
|
|
resolvedAddrs = make([]string, len(addrs))
|
|
for i, addr := range addrs {
|
|
resolvedAddrs[i] = addr.String()
|
|
}
|
|
}
|
|
}
|
|
// sort resolvedAddrs, so they could be compared below in areEqualBackendURLs()
|
|
sort.Strings(resolvedAddrs)
|
|
hostToAddrs[host] = resolvedAddrs
|
|
}
|
|
cancel()
|
|
|
|
// generate new backendURLs for the resolved IPs
|
|
var busNew []*backendURL
|
|
for _, bu := range up.busOriginal {
|
|
host := bu.Hostname()
|
|
port := bu.Port()
|
|
for _, addr := range hostToAddrs[host] {
|
|
buCopy := *bu
|
|
buCopy.Host = addr
|
|
if port != "" {
|
|
if n := strings.IndexByte(buCopy.Host, ':'); n >= 0 {
|
|
// Drop the discovered port and substitute it the port specified in bu.
|
|
buCopy.Host = buCopy.Host[:n]
|
|
}
|
|
buCopy.Host += ":" + port
|
|
}
|
|
busNew = append(busNew, &backendURL{
|
|
url: &buCopy,
|
|
})
|
|
}
|
|
}
|
|
|
|
pbus := up.bus.Load()
|
|
if areEqualBackendURLs(*pbus, busNew) {
|
|
return
|
|
}
|
|
|
|
// Store new backend urls
|
|
up.bus.Store(&busNew)
|
|
}
|
|
|
|
func areEqualBackendURLs(a, b []*backendURL) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i, aURL := range a {
|
|
bURL := b[i]
|
|
if aURL.url.String() != bURL.url.String() {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// getFirstAvailableBackendURL returns the first available backendURL, which isn't broken.
|
|
//
|
|
// backendURL.put() must be called on the returned backendURL after the request is complete.
|
|
func getFirstAvailableBackendURL(bus []*backendURL) *backendURL {
|
|
bu := bus[0]
|
|
if !bu.isBroken() {
|
|
// Fast path - send the request to the first url.
|
|
bu.get()
|
|
return bu
|
|
}
|
|
|
|
// Slow path - the first url is temporarily unavailabel. Fall back to the remaining urls.
|
|
for i := 1; i < len(bus); i++ {
|
|
if !bus[i].isBroken() {
|
|
bu = bus[i]
|
|
break
|
|
}
|
|
}
|
|
bu.get()
|
|
return bu
|
|
}
|
|
|
|
// getLeastLoadedBackendURL returns the backendURL with the minimum number of concurrent requests.
|
|
//
|
|
// backendURL.put() must be called on the returned backendURL after the request is complete.
|
|
func getLeastLoadedBackendURL(bus []*backendURL, atomicCounter *atomic.Uint32) *backendURL {
|
|
if len(bus) == 1 {
|
|
// Fast path - return the only backend url.
|
|
bu := bus[0]
|
|
bu.get()
|
|
return bu
|
|
}
|
|
|
|
// Slow path - select other backend urls.
|
|
n := atomicCounter.Add(1)
|
|
|
|
for i := uint32(0); i < uint32(len(bus)); i++ {
|
|
idx := (n + i) % uint32(len(bus))
|
|
bu := bus[idx]
|
|
if bu.isBroken() {
|
|
continue
|
|
}
|
|
if bu.concurrentRequests.Load() == 0 {
|
|
// Fast path - return the backend with zero concurrently executed requests.
|
|
// Do not use CompareAndSwap() instead of Load(), since it is much slower on systems with many CPU cores.
|
|
bu.concurrentRequests.Add(1)
|
|
return bu
|
|
}
|
|
}
|
|
|
|
// Slow path - return the backend with the minimum number of concurrently executed requests.
|
|
buMin := bus[n%uint32(len(bus))]
|
|
minRequests := buMin.concurrentRequests.Load()
|
|
for _, bu := range bus {
|
|
if bu.isBroken() {
|
|
continue
|
|
}
|
|
if n := bu.concurrentRequests.Load(); n < minRequests {
|
|
buMin = bu
|
|
minRequests = n
|
|
}
|
|
}
|
|
buMin.get()
|
|
return buMin
|
|
}
|
|
|
|
// UnmarshalYAML unmarshals up from yaml.
|
|
func (up *URLPrefix) UnmarshalYAML(f func(any) error) error {
|
|
var v any
|
|
if err := f(&v); err != nil {
|
|
return err
|
|
}
|
|
up.vOriginal = v
|
|
|
|
var urls []string
|
|
switch x := v.(type) {
|
|
case string:
|
|
urls = []string{x}
|
|
case []any:
|
|
if len(x) == 0 {
|
|
return fmt.Errorf("`url_prefix` must contain at least a single url")
|
|
}
|
|
us := make([]string, len(x))
|
|
for i, xx := range x {
|
|
s, ok := xx.(string)
|
|
if !ok {
|
|
return fmt.Errorf("`url_prefix` must contain array of strings; got %T", xx)
|
|
}
|
|
us[i] = s
|
|
}
|
|
urls = us
|
|
default:
|
|
return fmt.Errorf("unexpected type for `url_prefix`: %T; want string or []string", v)
|
|
}
|
|
|
|
bus := make([]*url.URL, len(urls))
|
|
for i, u := range urls {
|
|
pu, err := url.Parse(u)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot unmarshal %q into url: %w", u, err)
|
|
}
|
|
bus[i] = pu
|
|
}
|
|
up.busOriginal = bus
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML marshals up to yaml.
|
|
func (up *URLPrefix) MarshalYAML() (any, error) {
|
|
return up.vOriginal, nil
|
|
}
|
|
|
|
// Regex represents a regex
//
// It is unmarshaled from a YAML string. The expression is anchored, so it must
// match the whole input string rather than a part of it.
type Regex struct {
	// re is the compiled anchored expression `^(?:<sOriginal>)$`.
	re *regexp.Regexp

	// sOriginal holds the original unanchored expression, so MarshalYAML can return it unchanged.
	sOriginal string
}

// match reports whether the whole s matches r.
//
// Purely literal expressions are matched by string comparison, and inputs without
// the required literal prefix are rejected without running the regex engine.
func (r *Regex) match(s string) bool {
	literal, complete := r.re.LiteralPrefix()
	switch {
	case complete:
		// Fast path - literal match
		return s == literal
	case !strings.HasPrefix(s, literal):
		return false
	default:
		return r.re.MatchString(s)
	}
}

// UnmarshalYAML implements yaml.Unmarshaler
//
// It reads a string from f, anchors it and compiles it.
func (r *Regex) UnmarshalYAML(f func(any) error) error {
	var expr string
	if err := f(&expr); err != nil {
		return err
	}
	r.sOriginal = expr

	compiled, err := regexp.Compile("^(?:" + expr + ")$")
	if err != nil {
		return fmt.Errorf("cannot build regexp from %q: %w", expr, err)
	}
	r.re = compiled
	return nil
}

// MarshalYAML implements yaml.Marshaler.
//
// It returns the original unanchored expression.
func (r *Regex) MarshalYAML() (any, error) {
	return r.sOriginal, nil
}
|
|
|
|
// Metrics describing -auth.config reloads; they are updated in initAuthConfig and authConfigReloader.
var (
	// configReloads counts reload attempts made by authConfigReloader.
	configReloads = metrics.NewCounter(`vmauth_config_last_reload_total`)
	// configReloadErrors counts failed reload attempts.
	configReloadErrors = metrics.NewCounter(`vmauth_config_last_reload_errors_total`)
	// configSuccess is set to 1 after a successful load and to 0 after a failed reload.
	configSuccess = metrics.NewGauge(`vmauth_config_last_reload_successful`, nil)
	// configTimestamp holds the unix timestamp of the last applied config change.
	configTimestamp = metrics.NewCounter(`vmauth_config_last_reload_success_timestamp_seconds`)
)
|
|
|
|
func initAuthConfig() {
|
|
if len(*authConfigPath) == 0 {
|
|
logger.Fatalf("missing required `-auth.config` command-line flag")
|
|
}
|
|
|
|
// Register SIGHUP handler for config re-read just before readAuthConfig call.
|
|
// This guarantees that the config will be re-read if the signal arrives during readAuthConfig call.
|
|
// See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/1240
|
|
sighupCh := procutil.NewSighupChan()
|
|
|
|
_, err := reloadAuthConfig()
|
|
if err != nil {
|
|
logger.Fatalf("cannot load auth config: %s", err)
|
|
}
|
|
|
|
configSuccess.Set(1)
|
|
configTimestamp.Set(fasttime.UnixTimestamp())
|
|
|
|
stopCh = make(chan struct{})
|
|
authConfigWG.Add(1)
|
|
go func() {
|
|
defer authConfigWG.Done()
|
|
authConfigReloader(sighupCh)
|
|
}()
|
|
}
|
|
|
|
func stopAuthConfig() {
|
|
close(stopCh)
|
|
authConfigWG.Wait()
|
|
}
|
|
|
|
func authConfigReloader(sighupCh <-chan os.Signal) {
|
|
var refreshCh <-chan time.Time
|
|
// initialize auth refresh interval
|
|
if *configCheckInterval > 0 {
|
|
ticker := time.NewTicker(*configCheckInterval)
|
|
defer ticker.Stop()
|
|
refreshCh = ticker.C
|
|
}
|
|
|
|
updateFn := func() {
|
|
configReloads.Inc()
|
|
updated, err := reloadAuthConfig()
|
|
if err != nil {
|
|
logger.Errorf("failed to load auth config; using the last successfully loaded config; error: %s", err)
|
|
configSuccess.Set(0)
|
|
configReloadErrors.Inc()
|
|
return
|
|
}
|
|
configSuccess.Set(1)
|
|
if updated {
|
|
configTimestamp.Set(fasttime.UnixTimestamp())
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-stopCh:
|
|
return
|
|
case <-refreshCh:
|
|
updateFn()
|
|
case <-sighupCh:
|
|
logger.Infof("SIGHUP received; loading -auth.config=%q", *authConfigPath)
|
|
updateFn()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Currently applied auth config state. It is replaced atomically by reloadAuthConfigData.
var (
	// authConfigData stores the yaml definition for this config.
	// authConfigData needs to be updated each time authConfig is updated.
	// reloadAuthConfigData compares new data against it to skip unchanged configs.
	authConfigData atomic.Pointer[[]byte]

	// authConfig contains the currently loaded auth config
	authConfig atomic.Pointer[AuthConfig]

	// authUsers contains the currently loaded auth users, keyed by auth token (see parseAuthConfigUsers)
	authUsers atomic.Pointer[map[string]*UserInfo]

	// authConfigWG waits for the reloader goroutine started in initAuthConfig.
	authConfigWG sync.WaitGroup
	// stopCh is closed by stopAuthConfig to stop the reloader goroutine.
	stopCh chan struct{}
)
|
|
|
|
// reloadAuthConfig loads and applies the config from *authConfigPath.
|
|
// It returns bool value to identify if new config was applied.
|
|
// The config can be not applied if there is a parsing error
|
|
// or if there are no changes to the current authConfig.
|
|
func reloadAuthConfig() (bool, error) {
|
|
data, err := fscore.ReadFileOrHTTP(*authConfigPath)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to read -auth.config=%q: %w", *authConfigPath, err)
|
|
}
|
|
|
|
ok, err := reloadAuthConfigData(data)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to pars -auth.config=%q: %w", *authConfigPath, err)
|
|
}
|
|
if !ok {
|
|
return false, nil
|
|
}
|
|
|
|
mp := authUsers.Load()
|
|
logger.Infof("loaded information about %d users from -auth.config=%q", len(*mp), *authConfigPath)
|
|
return true, nil
|
|
}
|
|
|
|
func reloadAuthConfigData(data []byte) (bool, error) {
|
|
oldData := authConfigData.Load()
|
|
if oldData != nil && bytes.Equal(data, *oldData) {
|
|
// there are no updates in the config - skip reloading.
|
|
return false, nil
|
|
}
|
|
|
|
ac, err := parseAuthConfig(data)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to parse auth config: %w", err)
|
|
}
|
|
|
|
m, err := parseAuthConfigUsers(ac)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to parse users from auth config: %w", err)
|
|
}
|
|
|
|
acPrev := authConfig.Load()
|
|
if acPrev != nil {
|
|
metrics.UnregisterSet(acPrev.ms, true)
|
|
}
|
|
metrics.RegisterSet(ac.ms)
|
|
|
|
authConfig.Store(ac)
|
|
authConfigData.Store(&data)
|
|
authUsers.Store(&m)
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func parseAuthConfig(data []byte) (*AuthConfig, error) {
|
|
data, err := envtemplate.ReplaceBytes(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot expand environment vars: %w", err)
|
|
}
|
|
ac := &AuthConfig{
|
|
ms: metrics.NewSet(),
|
|
}
|
|
if err = yaml.UnmarshalStrict(data, ac); err != nil {
|
|
return nil, fmt.Errorf("cannot unmarshal AuthConfig data: %w", err)
|
|
}
|
|
|
|
ui := ac.UnauthorizedUser
|
|
if ui != nil {
|
|
if ui.Username != "" {
|
|
return nil, fmt.Errorf("field username can't be specified for unauthorized_user section")
|
|
}
|
|
if ui.Password != "" {
|
|
return nil, fmt.Errorf("field password can't be specified for unauthorized_user section")
|
|
}
|
|
if ui.BearerToken != "" {
|
|
return nil, fmt.Errorf("field bearer_token can't be specified for unauthorized_user section")
|
|
}
|
|
if ui.AuthToken != "" {
|
|
return nil, fmt.Errorf("field auth_token can't be specified for unauthorized_user section")
|
|
}
|
|
if ui.Name != "" {
|
|
return nil, fmt.Errorf("field name can't be specified for unauthorized_user section")
|
|
}
|
|
if err := ui.initURLs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
metricLabels, err := ui.getMetricLabels()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot parse metric_labels for unauthorized_user: %w", err)
|
|
}
|
|
ui.requests = ac.ms.NewCounter(`vmauth_unauthorized_user_requests_total` + metricLabels)
|
|
ui.backendErrors = ac.ms.NewCounter(`vmauth_unauthorized_user_request_backend_errors_total` + metricLabels)
|
|
ui.requestsDuration = ac.ms.NewSummary(`vmauth_unauthorized_user_request_duration_seconds` + metricLabels)
|
|
ui.concurrencyLimitCh = make(chan struct{}, ui.getMaxConcurrentRequests())
|
|
ui.concurrencyLimitReached = ac.ms.NewCounter(`vmauth_unauthorized_user_concurrent_requests_limit_reached_total` + metricLabels)
|
|
_ = ac.ms.NewGauge(`vmauth_unauthorized_user_concurrent_requests_capacity`+metricLabels, func() float64 {
|
|
return float64(cap(ui.concurrencyLimitCh))
|
|
})
|
|
_ = ac.ms.NewGauge(`vmauth_unauthorized_user_concurrent_requests_current`+metricLabels, func() float64 {
|
|
return float64(len(ui.concurrencyLimitCh))
|
|
})
|
|
|
|
rt, err := newRoundTripper(ui.TLSCAFile, ui.TLSCertFile, ui.TLSKeyFile, ui.TLSServerName, ui.TLSInsecureSkipVerify)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
|
|
}
|
|
ui.rt = rt
|
|
}
|
|
return ac, nil
|
|
}
|
|
|
|
func parseAuthConfigUsers(ac *AuthConfig) (map[string]*UserInfo, error) {
|
|
uis := ac.Users
|
|
if len(uis) == 0 && ac.UnauthorizedUser == nil {
|
|
return nil, fmt.Errorf("Missing `users` or `unauthorized_user` sections")
|
|
}
|
|
byAuthToken := make(map[string]*UserInfo, len(uis))
|
|
for i := range uis {
|
|
ui := &uis[i]
|
|
ats, err := getAuthTokens(ui.AuthToken, ui.BearerToken, ui.Username, ui.Password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, at := range ats {
|
|
if uiOld := byAuthToken[at]; uiOld != nil {
|
|
return nil, fmt.Errorf("duplicate auth token=%q found for username=%q, name=%q; the previous one is set for username=%q, name=%q",
|
|
at, ui.Username, ui.Name, uiOld.Username, uiOld.Name)
|
|
}
|
|
}
|
|
if err := ui.initURLs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
metricLabels, err := ui.getMetricLabels()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot parse metric_labels: %w", err)
|
|
}
|
|
ui.requests = ac.ms.GetOrCreateCounter(`vmauth_user_requests_total` + metricLabels)
|
|
ui.backendErrors = ac.ms.GetOrCreateCounter(`vmauth_user_request_backend_errors_total` + metricLabels)
|
|
ui.requestsDuration = ac.ms.GetOrCreateSummary(`vmauth_user_request_duration_seconds` + metricLabels)
|
|
mcr := ui.getMaxConcurrentRequests()
|
|
ui.concurrencyLimitCh = make(chan struct{}, mcr)
|
|
ui.concurrencyLimitReached = ac.ms.GetOrCreateCounter(`vmauth_user_concurrent_requests_limit_reached_total` + metricLabels)
|
|
_ = ac.ms.GetOrCreateGauge(`vmauth_user_concurrent_requests_capacity`+metricLabels, func() float64 {
|
|
return float64(cap(ui.concurrencyLimitCh))
|
|
})
|
|
_ = ac.ms.GetOrCreateGauge(`vmauth_user_concurrent_requests_current`+metricLabels, func() float64 {
|
|
return float64(len(ui.concurrencyLimitCh))
|
|
})
|
|
|
|
rt, err := newRoundTripper(ui.TLSCAFile, ui.TLSCertFile, ui.TLSKeyFile, ui.TLSServerName, ui.TLSInsecureSkipVerify)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
|
|
}
|
|
ui.rt = rt
|
|
|
|
for _, at := range ats {
|
|
byAuthToken[at] = ui
|
|
}
|
|
}
|
|
return byAuthToken, nil
|
|
}
|
|
|
|
var labelNameRegexp = regexp.MustCompile("^[a-zA-Z_:.][a-zA-Z0-9_:.]*$")
|
|
|
|
func (ui *UserInfo) getMetricLabels() (string, error) {
|
|
name := ui.name()
|
|
if len(name) == 0 && len(ui.MetricLabels) == 0 {
|
|
// fast path
|
|
return "", nil
|
|
}
|
|
labels := make([]string, 0, len(ui.MetricLabels)+1)
|
|
if len(name) > 0 {
|
|
labels = append(labels, fmt.Sprintf(`username=%q`, name))
|
|
}
|
|
for k, v := range ui.MetricLabels {
|
|
if !labelNameRegexp.MatchString(k) {
|
|
return "", fmt.Errorf("incorrect label name=%q, it must match regex=%q for user=%q", k, labelNameRegexp, name)
|
|
}
|
|
labels = append(labels, fmt.Sprintf(`%s=%q`, k, v))
|
|
}
|
|
sort.Strings(labels)
|
|
labelsStr := "{" + strings.Join(labels, ",") + "}"
|
|
return labelsStr, nil
|
|
}
|
|
|
|
func (ui *UserInfo) initURLs() error {
|
|
retryStatusCodes := defaultRetryStatusCodes.Values()
|
|
loadBalancingPolicy := *defaultLoadBalancingPolicy
|
|
dropSrcPathPrefixParts := 0
|
|
discoverBackendIPs := *discoverBackendIPsGlobal
|
|
if ui.URLPrefix != nil {
|
|
if err := ui.URLPrefix.sanitizeAndInitialize(); err != nil {
|
|
return err
|
|
}
|
|
if ui.RetryStatusCodes != nil {
|
|
retryStatusCodes = ui.RetryStatusCodes
|
|
}
|
|
if ui.LoadBalancingPolicy != "" {
|
|
loadBalancingPolicy = ui.LoadBalancingPolicy
|
|
}
|
|
if ui.DropSrcPathPrefixParts != nil {
|
|
dropSrcPathPrefixParts = *ui.DropSrcPathPrefixParts
|
|
}
|
|
if ui.DiscoverBackendIPs != nil {
|
|
discoverBackendIPs = *ui.DiscoverBackendIPs
|
|
}
|
|
ui.URLPrefix.retryStatusCodes = retryStatusCodes
|
|
ui.URLPrefix.dropSrcPathPrefixParts = dropSrcPathPrefixParts
|
|
ui.URLPrefix.discoverBackendIPs = discoverBackendIPs
|
|
if err := ui.URLPrefix.setLoadBalancingPolicy(loadBalancingPolicy); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if ui.DefaultURL != nil {
|
|
if err := ui.DefaultURL.sanitizeAndInitialize(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, e := range ui.URLMaps {
|
|
if len(e.SrcPaths) == 0 && len(e.SrcHosts) == 0 && len(e.SrcQueryArgs) == 0 && len(e.SrcHeaders) == 0 {
|
|
return fmt.Errorf("missing `src_paths`, `src_hosts`, `src_query_args` and `src_headers` in `url_map`")
|
|
}
|
|
if e.URLPrefix == nil {
|
|
return fmt.Errorf("missing `url_prefix` in `url_map`")
|
|
}
|
|
if err := e.URLPrefix.sanitizeAndInitialize(); err != nil {
|
|
return err
|
|
}
|
|
rscs := retryStatusCodes
|
|
lbp := loadBalancingPolicy
|
|
dsp := dropSrcPathPrefixParts
|
|
dbd := discoverBackendIPs
|
|
if e.RetryStatusCodes != nil {
|
|
rscs = e.RetryStatusCodes
|
|
}
|
|
if e.LoadBalancingPolicy != "" {
|
|
lbp = e.LoadBalancingPolicy
|
|
}
|
|
if e.DropSrcPathPrefixParts != nil {
|
|
dsp = *e.DropSrcPathPrefixParts
|
|
}
|
|
if e.DiscoverBackendIPs != nil {
|
|
dbd = *e.DiscoverBackendIPs
|
|
}
|
|
e.URLPrefix.retryStatusCodes = rscs
|
|
if err := e.URLPrefix.setLoadBalancingPolicy(lbp); err != nil {
|
|
return err
|
|
}
|
|
e.URLPrefix.dropSrcPathPrefixParts = dsp
|
|
e.URLPrefix.discoverBackendIPs = dbd
|
|
}
|
|
if len(ui.URLMaps) == 0 && ui.URLPrefix == nil {
|
|
return fmt.Errorf("missing `url_prefix` or `url_map`")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ui *UserInfo) name() string {
|
|
if ui.Name != "" {
|
|
return ui.Name
|
|
}
|
|
if ui.Username != "" {
|
|
return ui.Username
|
|
}
|
|
if ui.BearerToken != "" {
|
|
h := xxhash.Sum64([]byte(ui.BearerToken))
|
|
return fmt.Sprintf("bearer_token:hash:%016X", h)
|
|
}
|
|
if ui.AuthToken != "" {
|
|
h := xxhash.Sum64([]byte(ui.AuthToken))
|
|
return fmt.Sprintf("auth_token:hash:%016X", h)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// getAuthTokens returns the auth tokens under which a user with the given credentials is reachable.
//
// Exactly one credential kind may be set:
//   - authToken yields a single token matching the raw auth header value;
//   - bearerToken yields two tokens: `Bearer <token>` and Basic auth with the token
//     as username and an empty password;
//   - username (with optional password) yields a single Basic auth token.
//
// An error is returned if credential kinds are mixed or if none is set.
func getAuthTokens(authToken, bearerToken, username, password string) ([]string, error) {
	hasBasicAuth := username != "" || password != ""
	switch {
	case authToken != "" && bearerToken != "":
		return nil, fmt.Errorf("bearer_token cannot be specified if auth_token is set")
	case authToken != "" && hasBasicAuth:
		return nil, fmt.Errorf("username and password cannot be specified if auth_token is set")
	case authToken != "":
		return []string{getHTTPAuthToken(authToken)}, nil
	case bearerToken != "" && hasBasicAuth:
		return nil, fmt.Errorf("username and password cannot be specified if bearer_token is set")
	case bearerToken != "":
		// Accept the bearerToken as Basic Auth username with empty password
		return []string{
			getHTTPAuthBearerToken(bearerToken),
			getHTTPAuthBasicToken(bearerToken, ""),
		}, nil
	case username != "":
		return []string{getHTTPAuthBasicToken(username, password)}, nil
	default:
		return nil, fmt.Errorf("missing authorization options; bearer_token or username must be set")
	}
}

// getHTTPAuthToken returns the token for the raw auth header value authToken.
func getHTTPAuthToken(authToken string) string {
	return "http_auth:" + authToken
}

// getHTTPAuthBearerToken returns the token for the `Bearer <bearerToken>` auth header value.
func getHTTPAuthBearerToken(bearerToken string) string {
	return getHTTPAuthToken("Bearer " + bearerToken)
}

// getHTTPAuthBasicToken returns the token for the Basic auth header value built from username and password.
func getHTTPAuthBasicToken(username, password string) string {
	credentials := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
	return getHTTPAuthToken("Basic " + credentials)
}
|
|
|
|
var defaultHeaderNames = []string{"Authorization"}
|
|
|
|
func getAuthTokensFromRequest(r *http.Request) []string {
|
|
var ats []string
|
|
|
|
// Obtain possible auth tokens from one of the allowed auth headers
|
|
headerNames := *httpAuthHeader
|
|
if len(headerNames) == 0 {
|
|
headerNames = defaultHeaderNames
|
|
}
|
|
for _, headerName := range headerNames {
|
|
if ah := r.Header.Get(headerName); ah != "" {
|
|
if strings.HasPrefix(ah, "Token ") {
|
|
// Handle InfluxDB's proprietary token authentication scheme as a bearer token authentication
|
|
// See https://docs.influxdata.com/influxdb/v2.0/api/
|
|
ah = strings.Replace(ah, "Token", "Bearer", 1)
|
|
}
|
|
at := "http_auth:" + ah
|
|
ats = append(ats, at)
|
|
}
|
|
}
|
|
|
|
return ats
|
|
}
|
|
|
|
func (up *URLPrefix) sanitizeAndInitialize() error {
|
|
for i, bu := range up.busOriginal {
|
|
puNew, err := sanitizeURLPrefix(bu)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
up.busOriginal[i] = puNew
|
|
}
|
|
|
|
// Initialize up.bus
|
|
bus := make([]*backendURL, len(up.busOriginal))
|
|
for i, bu := range up.busOriginal {
|
|
bus[i] = &backendURL{
|
|
url: bu,
|
|
}
|
|
}
|
|
up.bus.Store(&bus)
|
|
|
|
return nil
|
|
}
|
|
|
|
// sanitizeURLPrefix strips trailing '/' chars from the urlPrefix path and validates it.
//
// urlPrefix is modified in place and returned on success. An error is returned if the
// scheme isn't http or https, or if the host is missing.
func sanitizeURLPrefix(urlPrefix *url.URL) (*url.URL, error) {
	// Remove trailing '/' from urlPrefix.
	// RawPath must be trimmed together with Path: otherwise it no longer matches Path,
	// and URL.String() falls back to re-escaping Path, losing escapings such as %2F.
	urlPrefix.Path = strings.TrimRight(urlPrefix.Path, "/")
	urlPrefix.RawPath = strings.TrimRight(urlPrefix.RawPath, "/")

	// Validate urlPrefix
	if urlPrefix.Scheme != "http" && urlPrefix.Scheme != "https" {
		return nil, fmt.Errorf("unsupported scheme for `url_prefix: %q`: %q; must be `http` or `https`", urlPrefix, urlPrefix.Scheme)
	}
	if urlPrefix.Host == "" {
		// Report the whole url, since the host is empty at this point.
		return nil, fmt.Errorf("missing hostname in `url_prefix %q`", urlPrefix)
	}
	return urlPrefix, nil
}
|