add unit test

This commit is contained in:
leiwingqueen 2024-11-13 22:23:55 +08:00
parent 5ff81cceda
commit f8468e2872
5 changed files with 85 additions and 34 deletions

View file

@ -3,11 +3,10 @@ package main
import (
"bytes"
"fmt"
"gopkg.in/yaml.v2"
"net/url"
"testing"
"gopkg.in/yaml.v2"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
)

View file

@ -180,7 +180,7 @@ func processUserRequest(w http.ResponseWriter, r *http.Request, ui *UserInfo) {
func processRequest(w http.ResponseWriter, r *http.Request, ui *UserInfo) {
u := normalizeURL(r.URL)
up, hc := ui.getURLPrefixAndHeaders(u, r.Header, r)
up, hc := ui.getURLPrefixAndHeaders(u, r.Header, getQueryRangeTime(r))
isDefault := false
if up == nil {
if ui.DefaultURL == nil {

View file

@ -53,16 +53,7 @@ func dropPrefixParts(path string, parts int) string {
return path
}
func (ui *UserInfo) getURLPrefixAndHeaders(u *url.URL, h http.Header, r *http.Request) (*URLPrefix, HeadersConf) {
var start, end, time int64
if r != nil {
req, err := httputils.DumpRequest(r)
if err != nil {
logger.Errorf("cannot dump request: %s", err)
return nil, HeadersConf{}
}
start, end, time = getQueryRangeTime(req)
}
func (ui *UserInfo) getURLPrefixAndHeaders(u *url.URL, h http.Header, queryTimeParams map[string]int64) (*URLPrefix, HeadersConf) {
for _, e := range ui.URLMaps {
if !matchAnyRegex(e.SrcHosts, u.Host) {
continue
@ -76,7 +67,7 @@ func (ui *UserInfo) getURLPrefixAndHeaders(u *url.URL, h http.Header, r *http.Re
if !matchAnyHeader(e.SrcHeaders, h) {
continue
}
if !matchRelativeTimeRange(e.RelativeTimeRangeConfig, start, end, time) {
if !matchRelativeTimeRange(e.RelativeTimeRangeConfig, queryTimeParams) {
continue
}
return e.URLPrefix, e.HeadersConf
@ -87,7 +78,7 @@ func (ui *UserInfo) getURLPrefixAndHeaders(u *url.URL, h http.Header, r *http.Re
return nil, HeadersConf{}
}
func matchRelativeTimeRange(tr *RelativeTimeRangeConfig, startTime int64, endTime int64, time int64) bool {
func matchRelativeTimeRange(tr *RelativeTimeRangeConfig, queryTimeParams map[string]int64) bool {
if tr == nil {
return true
}
@ -95,26 +86,38 @@ func matchRelativeTimeRange(tr *RelativeTimeRangeConfig, startTime int64, endTim
if trStart.IsZero() || trEnd.IsZero() {
return false
}
var start, end, time int64
if queryTimeParams != nil {
for paramName, value := range queryTimeParams {
switch paramName {
case "start":
start = value
case "end":
end = value
case "time":
time = value
}
}
}
// support instant query and query range
return (time >= 0 && time >= trStart.UnixMilli() && time <= trEnd.UnixMilli()) ||
(startTime >= 0 && endTime >= 0 && startTime >= trStart.UnixMilli() && endTime <= trEnd.UnixMilli())
(start >= 0 && end >= 0 && start >= trStart.UnixMilli() && end <= trEnd.UnixMilli())
}
func getQueryRangeTime(r *http.Request) (int64, int64, int64) {
start, err := httputils.GetTime(r, "start", 0)
func getQueryRangeTime(r *http.Request) map[string]int64 {
paramMap := make(map[string]int64)
req, err := httputils.DumpRequest(r)
if err != nil {
return 0, 0, 0
logger.Errorf("cannot dump request: %s", err)
return paramMap
}
end, err := httputils.GetTime(r, "end", 0)
if err != nil {
return 0, 0, 0
for _, param := range queryTimeParams {
time, err := httputils.GetTime(req, param, 0)
if err == nil {
paramMap[param] = time
}
}
// for instant query
t, err := httputils.GetTime(r, "time", 0)
if err != nil {
return 0, 0, 0
}
return start, end, t
return paramMap
}
func matchAnyRegex(rs []*Regex, s string) bool {
@ -178,3 +181,5 @@ func normalizeURL(uOrig *url.URL) *url.URL {
}
return &u
}
var queryTimeParams = []string{"start", "end", "time"}

View file

@ -8,6 +8,7 @@ import (
"reflect"
"strings"
"testing"
"time"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
)
@ -421,6 +422,46 @@ func TestCreateTargetURLFailure(t *testing.T) {
}, "/api/v1/write")
}
func TestMatchRelativeTimeRange(t *testing.T) {
f := func(ui *UserInfo, requestURI string, start int64, end int64, time int64, expectedTarget string) {
t.Helper()
u, err := url.Parse(requestURI)
if err != nil {
t.Fatalf("cannot parse %q: %s", requestURI, err)
}
u = normalizeURL(u)
params := make(map[string]int64)
params["start"] = start
params["end"] = end
params["time"] = time
up, _ := ui.getURLPrefixAndHeaders(u, nil, params)
if up == nil {
t.Fatalf("cannot match available backend: %s", err)
}
bu := up.getBackendURL()
target := mergeURLs(bu.url, u, up.dropSrcPathPrefixParts)
bu.put()
gotTarget := target.String()
if gotTarget != expectedTarget {
t.Fatalf("unexpected target\ngot:\n%q\nwant\n%q", gotTarget, expectedTarget)
}
}
ui := &UserInfo{
URLMaps: []URLMap{
{
RelativeTimeRangeConfig: getRelativeTimeRange(-8*time.Hour, +1*time.Hour),
URLPrefix: mustParseURL("http://srv+vmselect"),
},
},
URLPrefix: mustParseURL("http://non-exist-dns-addr"),
}
f(ui, `/select/0/prometheus/api/v1/query?query=up`, time.Now().UnixMilli()-7*3600*1000, time.Now().UnixMilli(), 0, "http://srv+vmselect/select/0/prometheus/api/v1/query?query=up")
f(ui, `/select/0/prometheus/api/v1/query?query=up`, time.Now().UnixMilli()-11*3600*1000, time.Now().UnixMilli(), 0, "http://non-exist-dns-addr/select/0/prometheus/api/v1/query?query=up")
f(ui, `/select/0/prometheus/api/v1/query?query=up`, 0, 0, time.Now().UnixMilli()-4*3600*1000, "http://srv+vmselect/select/0/prometheus/api/v1/query?query=up")
f(ui, `/select/0/prometheus/api/v1/query?query=up`, 0, 0, time.Now().UnixMilli()-9*3600*1000, "http://non-exist-dns-addr/select/0/prometheus/api/v1/query?query=up")
}
func headersToString(hs []*Header) string {
a := make([]string, len(hs))
for i, h := range hs {
@ -454,3 +495,7 @@ func (r *fakeResolver) LookupIPAddr(_ context.Context, host string) ([]net.IPAdd
func (r *fakeResolver) LookupMX(_ context.Context, _ string) ([]*net.MX, error) {
return nil, nil
}
func getRelativeTimeRange(start time.Duration, end time.Duration) *RelativeTimeRangeConfig {
return &RelativeTimeRangeConfig{&start, &end}
}

View file

@ -46,13 +46,15 @@ func GetTime(r *http.Request, argKey string, defaultMs int64) (int64, error) {
// DumpRequest returns a copy of r with Body replaced with a new io.ReadCloser, as deep copy of the original body.
func DumpRequest(r *http.Request) (*http.Request, error) {
dump := r.Clone(r.Context())
var b bytes.Buffer
_, err := b.ReadFrom(r.Body)
if err != nil {
return nil, err
if r.Body != nil {
var b bytes.Buffer
_, err := b.ReadFrom(r.Body)
if err != nil {
return nil, err
}
r.Body = io.NopCloser(&b)
dump.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
}
r.Body = io.NopCloser(&b)
dump.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
return dump, nil
}