VictoriaMetrics/lib/netutil/statdial.go

135 lines
3.7 KiB
Go
Raw Normal View History

package netutil
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"github.com/VictoriaMetrics/metrics"
)
// NewStatDialFuncWithDial returns dialer function that registers stats metrics for conns.
func NewStatDialFuncWithDial(metricPrefix string, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
return newStatDialFunc(metricPrefix, dialFunc)
}
// NewStatDialFunc returns dialer function that supports DNS SRV records and registers stats metrics for conns.
func NewStatDialFunc(metricPrefix string) func(ctx context.Context, network, addr string) (net.Conn, error) {
return newStatDialFunc(metricPrefix, DialMaybeSRV)
}
func newStatDialFunc(metricPrefix string, dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, _, addr string) (net.Conn, error) {
sc := &statDialConn{
dialsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dials_total`, metricPrefix)),
dialErrors: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_dial_errors_total`, metricPrefix)),
conns: metrics.GetOrCreateGauge(fmt.Sprintf(`%s_conns`, metricPrefix), nil),
readsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_reads_total`, metricPrefix)),
writesTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_writes_total`, metricPrefix)),
readErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_read_errors_total`, metricPrefix)),
writeErrorsTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_write_errors_total`, metricPrefix)),
bytesReadTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_read_total`, metricPrefix)),
bytesWrittenTotal: metrics.GetOrCreateCounter(fmt.Sprintf(`%s_conn_bytes_written_total`, metricPrefix)),
}
network := GetTCPNetwork()
conn, err := dialFunc(ctx, network, addr)
sc.dialsTotal.Inc()
if err != nil {
sc.dialErrors.Inc()
if !TCP6Enabled() && !isTCPv4Addr(addr) {
err = fmt.Errorf("%w; try -enableTCP6 command-line flag for dialing ipv6 addresses", err)
}
return nil, err
}
sc.Conn = conn
sc.conns.Inc()
return sc, nil
}
}
type statDialConn struct {
closed atomic.Int32
net.Conn
dialsTotal *metrics.Counter
dialErrors *metrics.Counter
conns *metrics.Gauge
readsTotal *metrics.Counter
writesTotal *metrics.Counter
readErrorsTotal *metrics.Counter
writeErrorsTotal *metrics.Counter
bytesReadTotal *metrics.Counter
bytesWrittenTotal *metrics.Counter
}
func (sc *statDialConn) Read(p []byte) (int, error) {
n, err := sc.Conn.Read(p)
sc.readsTotal.Inc()
if err != nil {
sc.readErrorsTotal.Inc()
}
sc.bytesReadTotal.Add(n)
return n, err
}
func (sc *statDialConn) Write(p []byte) (int, error) {
n, err := sc.Conn.Write(p)
sc.writesTotal.Inc()
if err != nil {
sc.writeErrorsTotal.Inc()
}
sc.bytesWrittenTotal.Add(n)
return n, err
}
func (sc *statDialConn) Close() error {
err := sc.Conn.Close()
if sc.closed.Add(1) == 1 {
sc.conns.Dec()
}
return err
}
func isTCPv4Addr(addr string) bool {
s := addr
for i := 0; i < 3; i++ {
n := strings.IndexByte(s, '.')
if n < 0 {
return false
}
if !isUint8NumString(s[:n]) {
return false
}
s = s[n+1:]
}
n := strings.IndexByte(s, ':')
if n < 0 {
return false
}
if !isUint8NumString(s[:n]) {
return false
}
s = s[n+1:]
// Verify TCP port
n, err := strconv.Atoi(s)
if err != nil {
return false
}
return n >= 0 && n < (1<<16)
}
func isUint8NumString(s string) bool {
n, err := strconv.Atoi(s)
if err != nil {
return false
}
return n >= 0 && n < (1<<8)
}