package main import ( "flag" "fmt" "io" "net" "net/http" "net/textproto" "net/url" "os" "strings" "sync" "time" "github.com/VictoriaMetrics/VictoriaMetrics/lib/buildinfo" "github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/envflag" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/httpserver" "github.com/VictoriaMetrics/VictoriaMetrics/lib/logger" "github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/procutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/pushmetrics" "github.com/VictoriaMetrics/metrics" ) var ( httpListenAddr = flag.String("httpListenAddr", ":8427", "TCP address to listen for http connections. See also -httpListenAddr.useProxyProtocol") useProxyProtocol = flag.Bool("httpListenAddr.useProxyProtocol", false, "Whether to use proxy protocol for connections accepted at -httpListenAddr . "+ "See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt . "+ "With enabled proxy protocol http server cannot serve regular /metrics endpoint. Use -pushmetrics.url for metrics pushing") maxIdleConnsPerBackend = flag.Int("maxIdleConnsPerBackend", 100, "The maximum number of idle connections vmauth can open per each backend host. "+ "See also -maxConcurrentRequests") responseTimeout = flag.Duration("responseTimeout", 5*time.Minute, "The timeout for receiving a response from backend") maxConcurrentRequests = flag.Int("maxConcurrentRequests", 1000, "The maximum number of concurrent requests vmauth can process. Other requests are rejected with "+ "'429 Too Many Requests' http status code. See also -maxIdleConnsPerBackend") reloadAuthKey = flag.String("reloadAuthKey", "", "Auth key for /-/reload http endpoint. It must be passed as authKey=...") logInvalidAuthTokens = flag.Bool("logInvalidAuthTokens", false, "Whether to log requests with invalid auth tokens. "+ `Such requests are always counted at vmauth_http_request_errors_total{reason="invalid_auth_token"} metric, which is exposed at /metrics page`) ) func main() { // Write flags and help message to stdout, since it is easier to grep or pipe. flag.CommandLine.SetOutput(os.Stdout) flag.Usage = usage envflag.Parse() buildinfo.Init() logger.Init() pushmetrics.Init() logger.Infof("starting vmauth at %q...", *httpListenAddr) startTime := time.Now() initAuthConfig() go httpserver.Serve(*httpListenAddr, *useProxyProtocol, requestHandler) logger.Infof("started vmauth in %.3f seconds", time.Since(startTime).Seconds()) sig := procutil.WaitForSigterm() logger.Infof("received signal %s", sig) startTime = time.Now() logger.Infof("gracefully shutting down webservice at %q", *httpListenAddr) if err := httpserver.Stop(*httpListenAddr); err != nil { logger.Fatalf("cannot stop the webservice: %s", err) } logger.Infof("successfully shut down the webservice in %.3f seconds", time.Since(startTime).Seconds()) stopAuthConfig() logger.Infof("successfully stopped vmauth in %.3f seconds", time.Since(startTime).Seconds()) } func requestHandler(w http.ResponseWriter, r *http.Request) bool { switch r.URL.Path { case "/-/reload": if !httpserver.CheckAuthFlag(w, r, *reloadAuthKey, "reloadAuthKey") { return true } configReloadRequests.Inc() procutil.SelfSIGHUP() w.WriteHeader(http.StatusOK) return true } authToken := r.Header.Get("Authorization") if authToken == "" { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) http.Error(w, "missing `Authorization` request header", http.StatusUnauthorized) return true } if strings.HasPrefix(authToken, "Token ") { // Handle InfluxDB's proprietary token authentication scheme as a bearer token authentication // See https://docs.influxdata.com/influxdb/v2.0/api/ authToken = strings.Replace(authToken, "Token", "Bearer", 1) } ac := authConfig.Load().(map[string]*UserInfo) ui := ac[authToken] if ui == nil { invalidAuthTokenRequests.Inc() if *logInvalidAuthTokens { err := fmt.Errorf("cannot find the provided auth token %q in config", authToken) err = &httpserver.ErrorWithStatusCode{ Err: err, StatusCode: http.StatusUnauthorized, } httpserver.Errorf(w, r, "%s", err) } else { http.Error(w, "Unauthorized", http.StatusUnauthorized) } return true } ui.requests.Inc() targetURL, headers, err := createTargetURL(ui, r.URL) if err != nil { httpserver.Errorf(w, r, "cannot determine targetURL: %s", err) return true } // Limit the concurrency of requests to backends concurrencyLimitOnce.Do(concurrencyLimitInit) select { case concurrencyLimitCh <- struct{}{}: default: concurrentRequestsLimitReachedTotal.Inc() w.Header().Add("Retry-After", "10") err := &httpserver.ErrorWithStatusCode{ Err: fmt.Errorf("cannot serve more than -maxConcurrentRequests=%d concurrent requests", cap(concurrencyLimitCh)), StatusCode: http.StatusTooManyRequests, } httpserver.Errorf(w, r, "%s", err) return true } processRequest(w, r, targetURL, headers) <-concurrencyLimitCh return true } func processRequest(w http.ResponseWriter, r *http.Request, targetURL *url.URL, headers []Header) { // This code has been copied from net/http/httputil/reverseproxy.go req := sanitizeRequestHeaders(r) req.URL = targetURL for _, h := range headers { req.Header.Set(h.Name, h.Value) } transportOnce.Do(transportInit) res, err := transport.RoundTrip(req) if err != nil { err = &httpserver.ErrorWithStatusCode{ Err: fmt.Errorf("error when proxying the request to %q: %s", targetURL, err), StatusCode: http.StatusBadGateway, } httpserver.Errorf(w, r, "%s", err) return } removeHopHeaders(res.Header) copyHeader(w.Header(), res.Header) w.WriteHeader(res.StatusCode) copyBuf := copyBufPool.Get() copyBuf.B = bytesutil.ResizeNoCopyNoOverallocate(copyBuf.B, 16*1024) _, err = io.CopyBuffer(w, res.Body, copyBuf.B) copyBufPool.Put(copyBuf) _ = res.Body.Close() if err != nil && !netutil.IsTrivialNetworkError(err) { remoteAddr := httpserver.GetQuotedRemoteAddr(r) requestURI := httpserver.GetRequestURI(r) logger.Warnf("remoteAddr: %s; requestURI: %s; error when proxying response body from %s: %s", remoteAddr, requestURI, targetURL, err) return } } var copyBufPool bytesutil.ByteBufferPool func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { dst.Add(k, v) } } } func sanitizeRequestHeaders(r *http.Request) *http.Request { // This code has been copied from net/http/httputil/reverseproxy.go req := r.Clone(r.Context()) removeHopHeaders(req.Header) if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. prior := req.Header["X-Forwarded-For"] if len(prior) > 0 { clientIP = strings.Join(prior, ", ") + ", " + clientIP } req.Header.Set("X-Forwarded-For", clientIP) } return req } func removeHopHeaders(h http.Header) { // remove hop-by-hop headers listed in the "Connection" header of h. // See RFC 7230, section 6.1 for _, f := range h["Connection"] { for _, sf := range strings.Split(f, ",") { if sf = textproto.TrimString(sf); sf != "" { h.Del(sf) } } } // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent // connection, regardless of what the client sent to us. for _, key := range hopHeaders { h.Del(key) } } // Hop-by-hop headers. These are removed when sent to the backend. // As of RFC 7230, hop-by-hop headers are required to appear in the // Connection header field. These are the headers defined by the // obsoleted RFC 2616 (section 13.5.1) and are used for backward // compatibility. var hopHeaders = []string{ "Connection", "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "Te", // canonicalized version of "TE" "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 "Transfer-Encoding", "Upgrade", } var ( configReloadRequests = metrics.NewCounter(`vmauth_http_requests_total{path="/-/reload"}`) invalidAuthTokenRequests = metrics.NewCounter(`vmauth_http_request_errors_total{reason="invalid_auth_token"}`) missingRouteRequests = metrics.NewCounter(`vmauth_http_request_errors_total{reason="missing_route"}`) ) var ( transport *http.Transport transportOnce sync.Once ) func transportInit() { tr := http.DefaultTransport.(*http.Transport).Clone() tr.ResponseHeaderTimeout = *responseTimeout // Automatic compression must be disabled in order to fix https://github.com/VictoriaMetrics/VictoriaMetrics/issues/535 tr.DisableCompression = true // Disable HTTP/2.0, since VictoriaMetrics components don't support HTTP/2.0 (because there is no sense in this). tr.ForceAttemptHTTP2 = false tr.MaxIdleConnsPerHost = *maxIdleConnsPerBackend if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost { tr.MaxIdleConns = tr.MaxIdleConnsPerHost } transport = tr } var ( concurrencyLimitCh chan struct{} concurrencyLimitOnce sync.Once ) func concurrencyLimitInit() { concurrencyLimitCh = make(chan struct{}, *maxConcurrentRequests) _ = metrics.NewGauge("vmauth_concurrent_requests_capacity", func() float64 { return float64(*maxConcurrentRequests) }) _ = metrics.NewGauge("vmauth_concurrent_requests_current", func() float64 { return float64(len(concurrencyLimitCh)) }) } var concurrentRequestsLimitReachedTotal = metrics.NewCounter("vmauth_concurrent_requests_limit_reached_total") func usage() { const s = ` vmauth authenticates and authorizes incoming requests and proxies them to VictoriaMetrics. See the docs at https://docs.victoriametrics.com/vmauth.html . ` flagutil.Usage(s) }