lib/{httputils,netutil}: move httputils.GetStatDialFunc to netutil.NewStatDialFunc

- Rename GetStatDialFunc to NewStatDialFunc, since it returns a new function on every call
- NewStatDialFunc isn't related to http in any way, so it must be moved from lib/httputils to lib/netutil
- Simplify the implementation of NewStatDialFunc by removing sync.Map from there.
- Use netutil.NewStatDialFunc at app/vmauth and lib/promscrape/discoveryutils
- Use the gauge type instead of the counter type for the *_conns metrics

This is a follow-up for d7b5062917
Updates https://github.com/VictoriaMetrics/VictoriaMetrics/pull/6299
This commit is contained in:
Aliaksandr Valialkin 2024-07-15 23:00:14 +02:00
parent ad367c17bf
commit a468a6e985
No known key found for this signature in database
GPG key ID: 52C003EE2BCDB9EB
11 changed files with 141 additions and 158 deletions

View file

@ -15,8 +15,8 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/lib/awsapi" "github.com/VictoriaMetrics/VictoriaMetrics/lib/awsapi"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger" "github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/persistentqueue" "github.com/VictoriaMetrics/VictoriaMetrics/lib/persistentqueue"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/promauth" "github.com/VictoriaMetrics/VictoriaMetrics/lib/promauth"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/protoparser/common" "github.com/VictoriaMetrics/VictoriaMetrics/lib/protoparser/common"
@ -120,7 +120,7 @@ func newHTTPClient(argIdx int, remoteWriteURL, sanitizedURL string, fq *persiste
logger.Fatalf("cannot initialize AWS Config for -remoteWrite.url=%q: %s", remoteWriteURL, err) logger.Fatalf("cannot initialize AWS Config for -remoteWrite.url=%q: %s", remoteWriteURL, err)
} }
tr := &http.Transport{ tr := &http.Transport{
DialContext: httputils.GetStatDialFunc("vmagent_remotewrite"), DialContext: netutil.NewStatDialFunc("vmagent_remotewrite"),
TLSHandshakeTimeout: tlsHandshakeTimeout.GetOptionalArg(argIdx), TLSHandshakeTimeout: tlsHandshakeTimeout.GetOptionalArg(argIdx),
MaxConnsPerHost: 2 * concurrency, MaxConnsPerHost: 2 * concurrency,
MaxIdleConnsPerHost: 2 * concurrency, MaxIdleConnsPerHost: 2 * concurrency,

View file

@ -12,6 +12,7 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils" "github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger" "github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
) )
var ( var (
@ -100,7 +101,7 @@ func Init(extraParams url.Values) (QuerierBuilder, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create transport: %w", err) return nil, fmt.Errorf("failed to create transport: %w", err)
} }
tr.DialContext = httputils.GetStatDialFunc("vmalert_datasource") tr.DialContext = netutil.NewStatDialFunc("vmalert_datasource")
tr.DisableKeepAlives = *disableKeepAlive tr.DisableKeepAlives = *disableKeepAlive
tr.MaxIdleConnsPerHost = *maxIdleConnections tr.MaxIdleConnsPerHost = *maxIdleConnections
if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost { if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost {

View file

@ -10,6 +10,7 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/app/vmalert/utils" "github.com/VictoriaMetrics/VictoriaMetrics/app/vmalert/utils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils" "github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
) )
var ( var (
@ -70,7 +71,7 @@ func Init() (datasource.QuerierBuilder, error) {
return nil, fmt.Errorf("failed to create transport: %w", err) return nil, fmt.Errorf("failed to create transport: %w", err)
} }
tr.IdleConnTimeout = *idleConnectionTimeout tr.IdleConnTimeout = *idleConnectionTimeout
tr.DialContext = httputils.GetStatDialFunc("vmalert_remoteread") tr.DialContext = netutil.NewStatDialFunc("vmalert_remoteread")
endpointParams, err := flagutil.ParseJSONMap(*oauth2EndpointParams) endpointParams, err := flagutil.ParseJSONMap(*oauth2EndpointParams)
if err != nil { if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/app/vmalert/utils" "github.com/VictoriaMetrics/VictoriaMetrics/app/vmalert/utils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils" "github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
) )
var ( var (
@ -74,7 +75,7 @@ func Init(ctx context.Context) (*Client, error) {
return nil, fmt.Errorf("failed to create transport: %w", err) return nil, fmt.Errorf("failed to create transport: %w", err)
} }
t.IdleConnTimeout = *idleConnectionTimeout t.IdleConnTimeout = *idleConnectionTimeout
t.DialContext = httputils.GetStatDialFunc("vmalert_remotewrite") t.DialContext = netutil.NewStatDialFunc("vmalert_remotewrite")
endpointParams, err := flagutil.ParseJSONMap(*oauth2EndpointParams) endpointParams, err := flagutil.ParseJSONMap(*oauth2EndpointParams)
if err != nil { if err != nil {

View file

@ -450,7 +450,7 @@ func newRoundTripper(caFileOpt, certFileOpt, keyFileOpt, serverNameOpt string, i
if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost { if tr.MaxIdleConns != 0 && tr.MaxIdleConns < tr.MaxIdleConnsPerHost {
tr.MaxIdleConns = tr.MaxIdleConnsPerHost tr.MaxIdleConns = tr.MaxIdleConnsPerHost
} }
tr.DialContext = netutil.DialMaybeSRV tr.DialContext = netutil.NewStatDialFunc("vmauth_backend")
rt := cfg.NewRoundTripper(tr) rt := cfg.NewRoundTripper(tr)
return rt, nil return rt, nil

View file

@ -1,146 +0,0 @@
package httputils
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
"github.com/VictoriaMetrics/metrics"
)
var statConnMetricsRegistry sync.Map
// statConnMetrics holds the counters updated by the dial function returned from GetStatDialFunc
// and by the statConn connections created by this dial function.
type statConnMetrics struct {
// dialsTotal is incremented on every dial attempt, while dialErrors is incremented on every failed dial.
dialsTotal *metrics.Counter
dialErrors *metrics.Counter
// conns is incremented after every successful dial and decremented on the first Close call for the statConn.
conns *metrics.Counter
// connReadsTotal and connWritesTotal are incremented on every Read and Write call for the statConn.
connReadsTotal *metrics.Counter
connWritesTotal *metrics.Counter
// connReadErrors and connWriteErrors are incremented when Read and Write calls return non-nil error.
connReadErrors *metrics.Counter
connWriteErrors *metrics.Counter
// connBytesRead and connBytesWritten are increased by the number of bytes returned from Read and Write calls.
connBytesRead *metrics.Counter
connBytesWritten *metrics.Counter
}
// newStatConnMetrics returns dial and connection counters with names starting from metricPrefix.
//
// The counters are obtained via metrics.GetOrCreateCounter instead of metrics.NewCounter,
// since metrics.NewCounter treats registration of an already registered name as a bug and panics.
// This may happen when GetStatDialFunc is called concurrently with the same metricPrefix,
// because it checks statConnMetricsRegistry and stores the new entry there in two separate steps.
// GetOrCreateCounter returns the already registered counter in this case.
func newStatConnMetrics(metricPrefix string) statConnMetrics {
	return statConnMetrics{
		dialsTotal:       metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dials_total`, metricPrefix)),
		dialErrors:       metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dial_errors_total`, metricPrefix)),
		conns:            metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conns`, metricPrefix)),
		connReadsTotal:   metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_reads_total`, metricPrefix)),
		connWritesTotal:  metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_writes_total`, metricPrefix)),
		connReadErrors:   metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_read_errors_total`, metricPrefix)),
		connWriteErrors:  metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_write_errors_total`, metricPrefix)),
		connBytesRead:    metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_read_total`, metricPrefix)),
		connBytesWritten: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_written_total`, metricPrefix)),
	}
}
// GetStatDialFunc returns dial function that supports DNS SRV records,
// and register stats metrics for conns.
func GetStatDialFunc(metricPrefix string) func(ctx context.Context, network, addr string) (net.Conn, error) {
v, ok := statConnMetricsRegistry.Load(metricPrefix)
if !ok {
v = newStatConnMetrics(metricPrefix)
statConnMetricsRegistry.Store(metricPrefix, v)
}
sm := v.(statConnMetrics)
return func(ctx context.Context, _, addr string) (net.Conn, error) {
network := netutil.GetTCPNetwork()
conn, err := netutil.DialMaybeSRV(ctx, network, addr)
sm.dialsTotal.Inc()
if err != nil {
sm.dialErrors.Inc()
if !netutil.TCP6Enabled() && !isTCPv4Addr(addr) {
err = fmt.Errorf("%w; try -enableTCP6 command-line flag for dialing ipv6 addresses", err)
}
return nil, err
}
sm.conns.Inc()
sc := &statConn{
Conn: conn,
statConnMetrics: sm,
}
return sc, nil
}
}
// statConn is a net.Conn wrapper, which updates the embedded statConnMetrics on Read, Write and Close calls.
type statConn struct {
// closed counts Close calls, so the conns metric is decremented only on the first call.
closed atomic.Int32
net.Conn
statConnMetrics
}
// Read reads data from the underlying connection into p and updates read metrics.
//
// It returns the values returned by the underlying Read call as is.
func (sc *statConn) Read(p []byte) (int, error) {
	n, err := sc.Conn.Read(p)
	sc.connReadsTotal.Inc()
	sc.connBytesRead.Add(n)
	if err == nil {
		return n, nil
	}
	sc.connReadErrors.Inc()
	return n, err
}
// Write writes p to the underlying connection and updates write metrics.
//
// It returns the values returned by the underlying Write call as is.
func (sc *statConn) Write(p []byte) (int, error) {
	n, err := sc.Conn.Write(p)
	sc.connWritesTotal.Inc()
	sc.connBytesWritten.Add(n)
	if err == nil {
		return n, nil
	}
	sc.connWriteErrors.Inc()
	return n, err
}
// Close closes the underlying connection and returns its result.
//
// The conns metric is decremented only on the first Close call.
func (sc *statConn) Close() error {
	closeErr := sc.Conn.Close()
	if isFirstClose := sc.closed.Add(1) == 1; isFirstClose {
		sc.conns.Dec()
	}
	return closeErr
}
// isTCPv4Addr reports whether addr has the form "a.b.c.d:port", where a, b, c and d
// are unsigned decimal numbers in the range [0..255] and port is an unsigned decimal number
// in the range [0..65535].
//
// Sign-prefixed numbers such as "+1" or "-0" are rejected, since they cannot be a part
// of a valid IPv4 address or TCP port.
func isTCPv4Addr(addr string) bool {
	s := addr
	// The first three octets must be terminated by '.'.
	for i := 0; i < 3; i++ {
		octet, tail, ok := strings.Cut(s, ".")
		if !ok || !isUint8NumString(octet) {
			return false
		}
		s = tail
	}

	// The last octet must be terminated by ':', which is followed by the TCP port.
	octet, port, ok := strings.Cut(s, ":")
	if !ok || !isUint8NumString(octet) {
		return false
	}

	// Verify TCP port. The bitSize=16 limits the accepted values to [0..65535].
	_, err := strconv.ParseUint(port, 10, 16)
	return err == nil
}

// isUint8NumString reports whether s is an unsigned decimal number in the range [0..255].
//
// strconv.ParseUint is used instead of strconv.Atoi, since the latter accepts
// sign prefixes such as "+1" and "-0".
func isUint8NumString(s string) bool {
	_, err := strconv.ParseUint(s, 10, 8)
	return err == nil
}

View file

@ -24,7 +24,7 @@ type connMetrics struct {
closeErrors *metrics.Counter closeErrors *metrics.Counter
conns *metrics.Counter conns *metrics.Gauge
} }
func (cm *connMetrics) init(ms *metrics.Set, group, name, addr string) { func (cm *connMetrics) init(ms *metrics.Set, group, name, addr string) {
@ -40,7 +40,7 @@ func (cm *connMetrics) init(ms *metrics.Set, group, name, addr string) {
cm.closeErrors = ms.NewCounter(fmt.Sprintf(`%s_errors_total{name=%q, addr=%q, type="close"}`, group, name, addr)) cm.closeErrors = ms.NewCounter(fmt.Sprintf(`%s_errors_total{name=%q, addr=%q, type="close"}`, group, name, addr))
cm.conns = ms.NewCounter(fmt.Sprintf(`%s_conns{name=%q, addr=%q}`, group, name, addr)) cm.conns = ms.NewGauge(fmt.Sprintf(`%s_conns{name=%q, addr=%q}`, group, name, addr), nil)
} }
type statConn struct { type statConn struct {

125
lib/netutil/statdial.go Normal file
View file

@ -0,0 +1,125 @@
package netutil
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"github.com/VictoriaMetrics/metrics"
)
// NewStatDialFunc returns dialer function that supports DNS SRV records and registers stats metrics for conns.
func NewStatDialFunc(metricPrefix string) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, _, addr string) (net.Conn, error) {
sc := &statDialConn{
dialsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dials_total`, metricPrefix)),
dialErrors: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dial_errors_total`, metricPrefix)),
conns: metrics.GetOrCreateGauge(fmt.Sprintf(`%s_conns`, metricPrefix), nil),
readsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_reads_total`, metricPrefix)),
writesTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_writes_total`, metricPrefix)),
readErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_read_errors_total`, metricPrefix)),
writeErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_write_errors_total`, metricPrefix)),
bytesReadTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_read_total`, metricPrefix)),
bytesWrittenTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_written_total`, metricPrefix)),
}
network := GetTCPNetwork()
conn, err := DialMaybeSRV(ctx, network, addr)
sc.dialsTotal.Inc()
if err != nil {
sc.dialErrors.Inc()
if !TCP6Enabled() && !isTCPv4Addr(addr) {
err = fmt.Errorf("%w; try -enableTCP6 command-line flag for dialing ipv6 addresses", err)
}
return nil, err
}
sc.Conn = conn
sc.conns.Inc()
return sc, nil
}
}
// statDialConn is a net.Conn wrapper returned by the dial function created via NewStatDialFunc.
//
// It updates the metrics below on dial, Read, Write and Close calls.
type statDialConn struct {
// closed counts Close calls, so the conns metric is decremented only on the first call.
closed atomic.Int32
net.Conn
// dialsTotal is incremented on every dial attempt, while dialErrors is incremented on every failed dial.
dialsTotal *metrics.Counter
dialErrors *metrics.Counter
// conns is incremented after a successful dial and decremented on the first Close call.
conns *metrics.Gauge
// readsTotal and writesTotal are incremented on every Read and Write call.
readsTotal *metrics.Counter
writesTotal *metrics.Counter
// readErrorsTotal and writeErrorsTotal are incremented when Read and Write calls return non-nil error.
readErrorsTotal *metrics.Counter
writeErrorsTotal *metrics.Counter
// bytesReadTotal and bytesWrittenTotal are increased by the number of bytes returned from Read and Write calls.
bytesReadTotal *metrics.Counter
bytesWrittenTotal *metrics.Counter
}
// Read reads data from the underlying connection into p and updates read metrics.
//
// It returns the values returned by the underlying Read call as is.
func (sc *statDialConn) Read(p []byte) (int, error) {
	n, err := sc.Conn.Read(p)
	sc.readsTotal.Inc()
	sc.bytesReadTotal.Add(n)
	if err == nil {
		return n, nil
	}
	sc.readErrorsTotal.Inc()
	return n, err
}
// Write writes p to the underlying connection and updates write metrics.
//
// It returns the values returned by the underlying Write call as is.
func (sc *statDialConn) Write(p []byte) (int, error) {
	n, err := sc.Conn.Write(p)
	sc.writesTotal.Inc()
	sc.bytesWrittenTotal.Add(n)
	if err == nil {
		return n, nil
	}
	sc.writeErrorsTotal.Inc()
	return n, err
}
// Close closes the underlying connection and returns its result.
//
// The conns metric is decremented only on the first Close call.
func (sc *statDialConn) Close() error {
	closeErr := sc.Conn.Close()
	if isFirstClose := sc.closed.Add(1) == 1; isFirstClose {
		sc.conns.Dec()
	}
	return closeErr
}
// isTCPv4Addr reports whether addr has the form "a.b.c.d:port", where a, b, c and d
// are unsigned decimal numbers in the range [0..255] and port is an unsigned decimal number
// in the range [0..65535].
//
// Sign-prefixed numbers such as "+1" or "-0" are rejected, since they cannot be a part
// of a valid IPv4 address or TCP port.
func isTCPv4Addr(addr string) bool {
	s := addr
	// The first three octets must be terminated by '.'.
	for i := 0; i < 3; i++ {
		octet, tail, ok := strings.Cut(s, ".")
		if !ok || !isUint8NumString(octet) {
			return false
		}
		s = tail
	}

	// The last octet must be terminated by ':', which is followed by the TCP port.
	octet, port, ok := strings.Cut(s, ":")
	if !ok || !isUint8NumString(octet) {
		return false
	}

	// Verify TCP port. The bitSize=16 limits the accepted values to [0..65535].
	_, err := strconv.ParseUint(port, 10, 16)
	return err == nil
}

// isUint8NumString reports whether s is an unsigned decimal number in the range [0..255].
//
// strconv.ParseUint is used instead of strconv.Atoi, since the latter accepts
// sign prefixes such as "+1" and "-0".
func isUint8NumString(s string) bool {
	_, err := strconv.ParseUint(s, 10, 8)
	return err == nil
}

View file

@ -1,4 +1,4 @@
package httputils package netutil
import ( import (
"testing" "testing"
@ -7,6 +7,7 @@ import (
func TestIsTCPv4Addr(t *testing.T) { func TestIsTCPv4Addr(t *testing.T) {
f := func(addr string, resultExpected bool) { f := func(addr string, resultExpected bool) {
t.Helper() t.Helper()
result := isTCPv4Addr(addr) result := isTCPv4Addr(addr)
if result != resultExpected { if result != resultExpected {
t.Fatalf("unexpected result for isIPv4Addr(%q); got %v; want %v", addr, result, resultExpected) t.Fatalf("unexpected result for isIPv4Addr(%q); got %v; want %v", addr, result, resultExpected)

View file

@ -14,7 +14,7 @@ import (
"github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/bytesutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil" "github.com/VictoriaMetrics/VictoriaMetrics/lib/flagutil"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/httputils" "github.com/VictoriaMetrics/VictoriaMetrics/lib/netutil"
) )
var ( var (
@ -70,7 +70,7 @@ func newClient(ctx context.Context, sw *ScrapeWork) (*client, error) {
IdleConnTimeout: 2 * sw.ScrapeInterval, IdleConnTimeout: 2 * sw.ScrapeInterval,
DisableCompression: *disableCompression || sw.DisableCompression, DisableCompression: *disableCompression || sw.DisableCompression,
DisableKeepAlives: *disableKeepAlive || sw.DisableKeepAlive, DisableKeepAlives: *disableKeepAlive || sw.DisableKeepAlive,
DialContext: httputils.GetStatDialFunc("vm_promscrape"), DialContext: netutil.NewStatDialFunc("vm_promscrape"),
MaxIdleConnsPerHost: 100, MaxIdleConnsPerHost: 100,
MaxResponseHeaderBytes: int64(maxResponseHeadersSize.N), MaxResponseHeaderBytes: int64(maxResponseHeadersSize.N),
}), }),

View file

@ -95,7 +95,7 @@ func NewClient(apiServer string, ac *promauth.Config, proxyURL *proxy.URL, proxy
return nil, fmt.Errorf("cannot parse apiServer=%q: %w", apiServer, err) return nil, fmt.Errorf("cannot parse apiServer=%q: %w", apiServer, err)
} }
dialFunc := netutil.DialMaybeSRV dialFunc := netutil.NewStatDialFunc("vm_promscrape_discovery")
if u.Scheme == "unix" { if u.Scheme == "unix" {
// special case for unix socket connection // special case for unix socket connection
dialAddr := u.Path dialAddr := u.Path