From 7c4e460513b134a3ff78f2d1a5c504f215aec621 Mon Sep 17 00:00:00 2001
From: Aliaksandr Valialkin <valyala@gmail.com>
Date: Wed, 21 Apr 2021 10:55:29 +0300
Subject: [PATCH] app/vmauth: parse `url_prefix` only once during config load

---
 app/vmauth/auth_config.go      | 60 ++++++++++++++++++++++------------
 app/vmauth/auth_config_test.go | 35 +++++++++++++++++---
 app/vmauth/main.go             |  6 +---
 app/vmauth/target_url.go       | 36 +++++++++-----------
 app/vmauth/target_url_test.go  | 34 +++++++++----------
 5 files changed, 103 insertions(+), 68 deletions(-)

diff --git a/app/vmauth/auth_config.go b/app/vmauth/auth_config.go
index 919ae2712c..79bcde78cd 100644
--- a/app/vmauth/auth_config.go
+++ b/app/vmauth/auth_config.go
@@ -33,7 +33,7 @@ type UserInfo struct {
 	BearerToken string   `yaml:"bearer_token"`
 	Username    string   `yaml:"username"`
 	Password    string   `yaml:"password"`
-	URLPrefix   string   `yaml:"url_prefix"`
+	URLPrefix   *yamlURL `yaml:"url_prefix"`
 	URLMap      []URLMap `yaml:"url_map"`
 
 	requests *metrics.Counter
@@ -42,7 +42,7 @@ type UserInfo struct {
 // URLMap is a mapping from source paths to target urls.
 type URLMap struct {
 	SrcPaths  []*SrcPath `yaml:"src_paths"`
-	URLPrefix string     `yaml:"url_prefix"`
+	URLPrefix *yamlURL   `yaml:"url_prefix"`
 }
 
 // SrcPath represents an src path
@@ -51,6 +51,27 @@ type SrcPath struct {
 	re        *regexp.Regexp
 }
 
+type yamlURL struct {
+	u *url.URL
+}
+
+func (yu *yamlURL) UnmarshalYAML(f func(interface{}) error) error {
+	var s string
+	if err := f(&s); err != nil {
+		return err
+	}
+	u, err := url.Parse(s)
+	if err != nil {
+		return fmt.Errorf("cannot unmarshal %q into url: %w", s, err)
+	}
+	yu.u = u
+	return nil
+}
+
+func (yu *yamlURL) MarshalYAML() (interface{}, error) {
+	return yu.u.String(), nil
+}
+
 func (sp *SrcPath) match(s string) bool {
 	prefix, ok := sp.re.LiteralPrefix()
 	if ok {
@@ -173,24 +194,27 @@ func parseAuthConfig(data []byte) (map[string]*UserInfo, error) {
 		if byAuthToken[authToken] != nil {
 			return nil, fmt.Errorf("duplicate auth token found for bearer_token=%q, username=%q: %q", authToken, ui.BearerToken, ui.Username)
 		}
-		if len(ui.URLPrefix) > 0 {
-			urlPrefix, err := sanitizeURLPrefix(ui.URLPrefix)
+		if ui.URLPrefix != nil {
+			urlPrefix, err := sanitizeURLPrefix(ui.URLPrefix.u)
 			if err != nil {
 				return nil, err
 			}
-			ui.URLPrefix = urlPrefix
+			ui.URLPrefix.u = urlPrefix
 		}
 		for _, e := range ui.URLMap {
 			if len(e.SrcPaths) == 0 {
-				return nil, fmt.Errorf("missing `src_paths`")
+				return nil, fmt.Errorf("missing `src_paths` in `url_map`")
 			}
-			urlPrefix, err := sanitizeURLPrefix(e.URLPrefix)
+			if e.URLPrefix == nil {
+				return nil, fmt.Errorf("missing `url_prefix` in `url_map`")
+			}
+			urlPrefix, err := sanitizeURLPrefix(e.URLPrefix.u)
 			if err != nil {
 				return nil, err
 			}
-			e.URLPrefix = urlPrefix
+			e.URLPrefix.u = urlPrefix
 		}
-		if len(ui.URLMap) == 0 && len(ui.URLPrefix) == 0 {
+		if len(ui.URLMap) == 0 && ui.URLPrefix == nil {
 			return nil, fmt.Errorf("missing `url_prefix`")
 		}
 		if ui.BearerToken != "" {
@@ -218,21 +242,17 @@ func getAuthToken(bearerToken, username, password string) string {
 	return "Basic " + token64
 }
 
-func sanitizeURLPrefix(urlPrefix string) (string, error) {
+func sanitizeURLPrefix(urlPrefix *url.URL) (*url.URL, error) {
 	// Remove trailing '/' from urlPrefix
-	for strings.HasSuffix(urlPrefix, "/") {
-		urlPrefix = urlPrefix[:len(urlPrefix)-1]
+	for strings.HasSuffix(urlPrefix.Path, "/") {
+		urlPrefix.Path = urlPrefix.Path[:len(urlPrefix.Path)-1]
 	}
 	// Validate urlPrefix
-	target, err := url.Parse(urlPrefix)
-	if err != nil {
-		return "", fmt.Errorf("invalid `url_prefix: %q`: %w", urlPrefix, err)
+	if urlPrefix.Scheme != "http" && urlPrefix.Scheme != "https" {
+		return nil, fmt.Errorf("unsupported scheme for `url_prefix: %q`: %q; must be `http` or `https`", urlPrefix, urlPrefix.Scheme)
 	}
-	if target.Scheme != "http" && target.Scheme != "https" {
-		return "", fmt.Errorf("unsupported scheme for `url_prefix: %q`: %q; must be `http` or `https`", urlPrefix, target.Scheme)
-	}
-	if target.Host == "" {
-		return "", fmt.Errorf("missing hostname in `url_prefix %q`", urlPrefix)
+	if urlPrefix.Host == "" {
+		return nil, fmt.Errorf("missing hostname in `url_prefix %q`", urlPrefix.Host)
 	}
 	return urlPrefix, nil
 }
diff --git a/app/vmauth/auth_config_test.go b/app/vmauth/auth_config_test.go
index 759136e5fe..e192189c5b 100644
--- a/app/vmauth/auth_config_test.go
+++ b/app/vmauth/auth_config_test.go
@@ -3,6 +3,7 @@ package main
 import (
 	"bytes"
 	"fmt"
+	"net/url"
 	"regexp"
 	"testing"
 
@@ -55,6 +56,11 @@ users:
 - username: foo
   url_prefix: http:///bar
 `)
+	f(`
+users:
+- username: foo
+  url_prefix: [bar]
+`)
 
 	// Username and bearer_token in a single config
 	f(`
@@ -102,6 +108,15 @@ users:
   - src_paths: ["/foo/bar"]
 `)
 
+	// Invalid url_prefix in url_map
+	f(`
+users:
+- username: a
+  url_map:
+  - src_paths: ["/foo/bar"]
+    url_prefix: foo.bar
+`)
+
 	// Missing src_paths in url_map
 	f(`
 users:
@@ -143,7 +158,7 @@ users:
 		getAuthToken("", "foo", "bar"): {
 			Username:  "foo",
 			Password:  "bar",
-			URLPrefix: "http://aaa:343/bbb",
+			URLPrefix: mustParseURL("http://aaa:343/bbb"),
 		},
 	})
 
@@ -157,11 +172,11 @@ users:
 `, map[string]*UserInfo{
 		getAuthToken("", "foo", ""): {
 			Username:  "foo",
-			URLPrefix: "http://foo",
+			URLPrefix: mustParseURL("http://foo"),
 		},
 		getAuthToken("", "bar", ""): {
 			Username:  "bar",
-			URLPrefix: "https://bar/x",
+			URLPrefix: mustParseURL("https://bar/x"),
 		},
 	})
 
@@ -180,11 +195,11 @@ users:
 			URLMap: []URLMap{
 				{
 					SrcPaths:  getSrcPaths([]string{"/api/v1/query", "/api/v1/query_range", "/api/v1/label/[^./]+/.+"}),
-					URLPrefix: "http://vmselect/select/0/prometheus",
+					URLPrefix: mustParseURL("http://vmselect/select/0/prometheus"),
 				},
 				{
 					SrcPaths:  getSrcPaths([]string{"/api/v1/write"}),
-					URLPrefix: "http://vminsert/insert/0/prometheus",
+					URLPrefix: mustParseURL("http://vminsert/insert/0/prometheus"),
 				},
 			},
 		},
@@ -222,3 +237,13 @@ func areEqualConfigs(a, b map[string]*UserInfo) error {
 	}
 	return nil
 }
+
+func mustParseURL(u string) *yamlURL {
+	pu, err := url.Parse(u)
+	if err != nil {
+		panic(fmt.Errorf("BUG: cannot parse %q: %w", u, err))
+	}
+	return &yamlURL{
+		u: pu,
+	}
+}
diff --git a/app/vmauth/main.go b/app/vmauth/main.go
index a8f59f03cc..bc6ae38ba4 100644
--- a/app/vmauth/main.go
+++ b/app/vmauth/main.go
@@ -65,11 +65,7 @@ func requestHandler(w http.ResponseWriter, r *http.Request) bool {
 		httpserver.Errorf(w, r, "cannot determine targetURL: %s", err)
 		return true
 	}
-	if _, err := url.Parse(targetURL); err != nil {
-		httpserver.Errorf(w, r, "invalid targetURL=%q: %s", targetURL, err)
-		return true
-	}
-	r.Header.Set("vm-target-url", targetURL)
+	r.Header.Set("vm-target-url", targetURL.String())
 	reverseProxy.ServeHTTP(w, r)
 	return true
 }
diff --git a/app/vmauth/target_url.go b/app/vmauth/target_url.go
index 28fbdb47cb..a361b561f2 100644
--- a/app/vmauth/target_url.go
+++ b/app/vmauth/target_url.go
@@ -7,37 +7,31 @@ import (
 	"strings"
 )
 
-func mergeURLs(uiURL string, requestURI *url.URL) (string, error) {
-	prefixURL, err := url.Parse(uiURL)
-	if err != nil {
-		return "", fmt.Errorf("BUG - cannot parse userInfo url: %q, err: %w", uiURL, err)
-	}
-	prefixURL.Path += requestURI.Path
+func mergeURLs(uiURL, requestURI *url.URL) *url.URL {
+	targetURL := *uiURL
+	targetURL.Path += requestURI.Path
 	requestParams := requestURI.Query()
 	// fast path
 	if len(requestParams) == 0 {
-		return prefixURL.String(), nil
+		return &targetURL
 	}
 	// merge query parameters from requests.
-	userInfoParams := prefixURL.Query()
+	uiParams := targetURL.Query()
 	for k, v := range requestParams {
 		// skip clashed query params from original request
-		if exist := userInfoParams.Get(k); len(exist) > 0 {
+		if exist := uiParams.Get(k); len(exist) > 0 {
 			continue
 		}
 		for i := range v {
-			userInfoParams.Add(k, v[i])
+			uiParams.Add(k, v[i])
 		}
 	}
-	prefixURL.RawQuery = userInfoParams.Encode()
-	return prefixURL.String(), nil
+	targetURL.RawQuery = uiParams.Encode()
+	return &targetURL
 }
 
-func createTargetURL(ui *UserInfo, uOrig *url.URL) (string, error) {
-	u, err := url.Parse(uOrig.String())
-	if err != nil {
-		return "", fmt.Errorf("cannot make a copy of %q: %w", u, err)
-	}
+func createTargetURL(ui *UserInfo, uOrig *url.URL) (*url.URL, error) {
+	u := *uOrig
 	// Prevent from attacks with using `..` in r.URL.Path
 	u.Path = path.Clean(u.Path)
 	if !strings.HasPrefix(u.Path, "/") {
@@ -46,12 +40,12 @@ func createTargetURL(ui *UserInfo, uOrig *url.URL) (string, error) {
 	for _, e := range ui.URLMap {
 		for _, sp := range e.SrcPaths {
 			if sp.match(u.Path) {
-				return mergeURLs(e.URLPrefix, u)
+				return mergeURLs(e.URLPrefix.u, &u), nil
 			}
 		}
 	}
-	if len(ui.URLPrefix) > 0 {
-		return mergeURLs(ui.URLPrefix, u)
+	if ui.URLPrefix != nil {
+		return mergeURLs(ui.URLPrefix.u, &u), nil
 	}
-	return "", fmt.Errorf("missing route for %q", u)
+	return nil, fmt.Errorf("missing route for %q", u.String())
 }
diff --git a/app/vmauth/target_url_test.go b/app/vmauth/target_url_test.go
index 9c8d454a53..18d7e53613 100644
--- a/app/vmauth/target_url_test.go
+++ b/app/vmauth/target_url_test.go
@@ -16,28 +16,28 @@ func TestCreateTargetURLSuccess(t *testing.T) {
 		if err != nil {
 			t.Fatalf("unexpected error: %s", err)
 		}
-		if target != expectedTarget {
+		if target.String() != expectedTarget {
 			t.Fatalf("unexpected target; got %q; want %q", target, expectedTarget)
 		}
 	}
 	// Simple routing with `url_prefix`
 	f(&UserInfo{
-		URLPrefix: "http://foo.bar",
+		URLPrefix: mustParseURL("http://foo.bar"),
 	}, "", "http://foo.bar/.")
 	f(&UserInfo{
-		URLPrefix: "http://foo.bar",
+		URLPrefix: mustParseURL("http://foo.bar"),
 	}, "/", "http://foo.bar/")
 	f(&UserInfo{
-		URLPrefix: "http://foo.bar",
+		URLPrefix: mustParseURL("http://foo.bar"),
 	}, "a/b?c=d", "http://foo.bar/a/b?c=d")
 	f(&UserInfo{
-		URLPrefix: "https://sss:3894/x/y",
+		URLPrefix: mustParseURL("https://sss:3894/x/y"),
 	}, "/z", "https://sss:3894/x/y/z")
 	f(&UserInfo{
-		URLPrefix: "https://sss:3894/x/y",
+		URLPrefix: mustParseURL("https://sss:3894/x/y"),
 	}, "/../../aaa", "https://sss:3894/x/y/aaa")
 	f(&UserInfo{
-		URLPrefix: "https://sss:3894/x/y",
+		URLPrefix: mustParseURL("https://sss:3894/x/y"),
 	}, "/./asd/../../aaa?a=d&s=s/../d", "https://sss:3894/x/y/aaa?a=d&s=s%2F..%2Fd")
 
 	// Complex routing with `url_map`
@@ -45,14 +45,14 @@ func TestCreateTargetURLSuccess(t *testing.T) {
 		URLMap: []URLMap{
 			{
 				SrcPaths:  getSrcPaths([]string{"/api/v1/query"}),
-				URLPrefix: "http://vmselect/0/prometheus",
+				URLPrefix: mustParseURL("http://vmselect/0/prometheus"),
 			},
 			{
 				SrcPaths:  getSrcPaths([]string{"/api/v1/write"}),
-				URLPrefix: "http://vminsert/0/prometheus",
+				URLPrefix: mustParseURL("http://vminsert/0/prometheus"),
 			},
 		},
-		URLPrefix: "http://default-server",
+		URLPrefix: mustParseURL("http://default-server"),
 	}
 	f(ui, "/api/v1/query?query=up", "http://vmselect/0/prometheus/api/v1/query?query=up")
 	f(ui, "/api/v1/write", "http://vminsert/0/prometheus/api/v1/write")
@@ -63,14 +63,14 @@ func TestCreateTargetURLSuccess(t *testing.T) {
 		URLMap: []URLMap{
 			{
 				SrcPaths:  getSrcPaths([]string{"/api/v1/query(_range)?", "/api/v1/label/[^/]+/values"}),
-				URLPrefix: "http://vmselect/0/prometheus",
+				URLPrefix: mustParseURL("http://vmselect/0/prometheus"),
 			},
 			{
 				SrcPaths:  getSrcPaths([]string{"/api/v1/write"}),
-				URLPrefix: "http://vminsert/0/prometheus",
+				URLPrefix: mustParseURL("http://vminsert/0/prometheus"),
 			},
 		},
-		URLPrefix: "http://default-server",
+		URLPrefix: mustParseURL("http://default-server"),
 	}
 	f(ui, "/api/v1/query?query=up", "http://vmselect/0/prometheus/api/v1/query?query=up")
 	f(ui, "/api/v1/query_range?query=up", "http://vmselect/0/prometheus/api/v1/query_range?query=up")
@@ -78,10 +78,10 @@ func TestCreateTargetURLSuccess(t *testing.T) {
 	f(ui, "/api/v1/write", "http://vminsert/0/prometheus/api/v1/write")
 	f(ui, "/api/v1/foo/bar", "http://default-server/api/v1/foo/bar")
 	f(&UserInfo{
-		URLPrefix: "http://foo.bar?extra_label=team=dev",
+		URLPrefix: mustParseURL("http://foo.bar?extra_label=team=dev"),
 	}, "/api/v1/query", "http://foo.bar/api/v1/query?extra_label=team=dev")
 	f(&UserInfo{
-		URLPrefix: "http://foo.bar?extra_label=team=mobile",
+		URLPrefix: mustParseURL("http://foo.bar?extra_label=team=mobile"),
 	}, "/api/v1/query?extra_label=team=dev", "http://foo.bar/api/v1/query?extra_label=team%3Dmobile")
 
 }
@@ -97,7 +97,7 @@ func TestCreateTargetURLFailure(t *testing.T) {
 		if err == nil {
 			t.Fatalf("expecting non-nil error")
 		}
-		if target != "" {
+		if target != nil {
 			t.Fatalf("unexpected target=%q; want empty string", target)
 		}
 	}
@@ -106,7 +106,7 @@ func TestCreateTargetURLFailure(t *testing.T) {
 		URLMap: []URLMap{
 			{
 				SrcPaths:  getSrcPaths([]string{"/api/v1/query"}),
-				URLPrefix: "http://foobar/baz",
+				URLPrefix: mustParseURL("http://foobar/baz"),
 			},
 		},
 	}, "/api/v1/write")