From 698458b742ab417ff9da8f3dc90288215318fd74 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Thu, 17 Mar 2022 19:42:10 +0200 Subject: [PATCH] lib/httpserver: extract the code responsible for initializing server-side TLS config into netutil.GetServerTLSConfig --- app/vmstorage/transport/server.go | 4 +- lib/httpserver/httpserver.go | 40 ++++-------------- lib/ingestserver/clusternative/server.go | 2 +- lib/ingestserver/graphite/server.go | 2 +- lib/ingestserver/influx/server.go | 2 +- lib/ingestserver/opentsdb/server.go | 2 +- lib/ingestserver/opentsdbhttp/server.go | 2 +- lib/netutil/tcplistener.go | 8 +++- lib/netutil/tls.go | 54 ++++++++++++++++++++++++ 9 files changed, 76 insertions(+), 40 deletions(-) create mode 100644 lib/netutil/tls.go diff --git a/app/vmstorage/transport/server.go b/app/vmstorage/transport/server.go index 48eef1a89..c3720aa47 100644 --- a/app/vmstorage/transport/server.go +++ b/app/vmstorage/transport/server.go @@ -58,11 +58,11 @@ type Server struct { // NewServer returns new Server. func NewServer(vminsertAddr, vmselectAddr string, storage *storage.Storage) (*Server, error) { - vminsertLN, err := netutil.NewTCPListener("vminsert", vminsertAddr) + vminsertLN, err := netutil.NewTCPListener("vminsert", vminsertAddr, nil) if err != nil { return nil, fmt.Errorf("unable to listen vminsertAddr %s: %w", vminsertAddr, err) } - vmselectLN, err := netutil.NewTCPListener("vmselect", vmselectAddr) + vmselectLN, err := netutil.NewTCPListener("vmselect", vmselectAddr, nil) if err != nil { return nil, fmt.Errorf("unable to listen vmselectAddr %s: %w", vmselectAddr, err) } diff --git a/lib/httpserver/httpserver.go b/lib/httpserver/httpserver.go index f0ddcdf2d..35b59094d 100644 --- a/lib/httpserver/httpserver.go +++ b/lib/httpserver/httpserver.go @@ -86,40 +86,18 @@ func Serve(addr string, rh RequestHandler) { } logger.Infof("starting http server at %s://%s/", scheme, hostAddr) logger.Infof("pprof handlers are exposed at %s://%s/debug/pprof/", scheme, hostAddr) - lnTmp, err := netutil.NewTCPListener(scheme, addr) + var tlsConfig *tls.Config + if *tlsEnable { + tc, err := netutil.GetServerTLSConfig("", *tlsCertFile, *tlsKeyFile) + if err != nil { + logger.Fatalf("cannot load TLS cert from -tlsCertFile=%q, -tlsKeyFile=%q: %s", *tlsCertFile, *tlsKeyFile, err) + } + tlsConfig = tc + } + ln, err := netutil.NewTCPListener(scheme, addr, tlsConfig) if err != nil { logger.Fatalf("cannot start http server at %s: %s", addr, err) } - ln := net.Listener(lnTmp) - - if *tlsEnable { - var certLock sync.Mutex - var certDeadline uint64 - var cert *tls.Certificate - c, err := tls.LoadX509KeyPair(*tlsCertFile, *tlsKeyFile) - if err != nil { - logger.Fatalf("cannot load TLS cert from tlsCertFile=%q, tlsKeyFile=%q: %s", *tlsCertFile, *tlsKeyFile, err) - } - cert = &c - cfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - PreferServerCipherSuites: true, - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - certLock.Lock() - defer certLock.Unlock() - if fasttime.UnixTimestamp() > certDeadline { - c, err = tls.LoadX509KeyPair(*tlsCertFile, *tlsKeyFile) - if err != nil { - return nil, fmt.Errorf("cannot load TLS cert from tlsCertFile=%q, tlsKeyFile=%q: %w", *tlsCertFile, *tlsKeyFile, err) - } - certDeadline = fasttime.UnixTimestamp() + 1 - cert = &c - } - return cert, nil - }, - } - ln = tls.NewListener(ln, cfg) - } serveWithListener(addr, ln, rh) } diff --git a/lib/ingestserver/clusternative/server.go b/lib/ingestserver/clusternative/server.go index db1c5d263..8827e7ee5 100644 --- a/lib/ingestserver/clusternative/server.go +++ b/lib/ingestserver/clusternative/server.go @@ -33,7 +33,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(c net.Conn) error) *Server { logger.Infof("starting TCP clusternative server at %q", addr) - lnTCP, err := netutil.NewTCPListener("clusternative", addr) + lnTCP, err := netutil.NewTCPListener("clusternative", addr, nil) if err != nil { logger.Fatalf("cannot start TCP clusternative server at %q: %s", addr, err) } diff --git a/lib/ingestserver/graphite/server.go b/lib/ingestserver/graphite/server.go index 3474200ee..4608ba635 100644 --- a/lib/ingestserver/graphite/server.go +++ b/lib/ingestserver/graphite/server.go @@ -40,7 +40,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r io.Reader) error) *Server { logger.Infof("starting TCP Graphite server at %q", addr) - lnTCP, err := netutil.NewTCPListener("graphite", addr) + lnTCP, err := netutil.NewTCPListener("graphite", addr, nil) if err != nil { logger.Fatalf("cannot start TCP Graphite server at %q: %s", addr, err) } diff --git a/lib/ingestserver/influx/server.go b/lib/ingestserver/influx/server.go index 9fab5cb02..bb172c68e 100644 --- a/lib/ingestserver/influx/server.go +++ b/lib/ingestserver/influx/server.go @@ -40,7 +40,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r io.Reader) error) *Server { logger.Infof("starting TCP InfluxDB server at %q", addr) - lnTCP, err := netutil.NewTCPListener("influx", addr) + lnTCP, err := netutil.NewTCPListener("influx", addr, nil) if err != nil { logger.Fatalf("cannot start TCP InfluxDB server at %q: %s", addr, err) } diff --git a/lib/ingestserver/opentsdb/server.go b/lib/ingestserver/opentsdb/server.go index a00a32267..46518e3f9 100644 --- a/lib/ingestserver/opentsdb/server.go +++ b/lib/ingestserver/opentsdb/server.go @@ -43,7 +43,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, telnetInsertHandler func(r io.Reader) error, httpInsertHandler func(req *http.Request) error) *Server { logger.Infof("starting TCP OpenTSDB collector at %q", addr) - lnTCP, err := netutil.NewTCPListener("opentsdb", addr) + lnTCP, err := netutil.NewTCPListener("opentsdb", addr, nil) if err != nil { logger.Fatalf("cannot start TCP OpenTSDB collector at %q: %s", addr, err) } diff --git a/lib/ingestserver/opentsdbhttp/server.go b/lib/ingestserver/opentsdbhttp/server.go index 6eec2f5f1..99e2e924d 100644 --- a/lib/ingestserver/opentsdbhttp/server.go +++ b/lib/ingestserver/opentsdbhttp/server.go @@ -30,7 +30,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r *http.Request) error) *Server { logger.Infof("starting HTTP OpenTSDB server at %q", addr) - lnTCP, err := netutil.NewTCPListener("opentsdbhttp", addr) + lnTCP, err := netutil.NewTCPListener("opentsdbhttp", addr, nil) if err != nil { logger.Fatalf("cannot start HTTP OpenTSDB collector at %q: %s", addr, err) } diff --git a/lib/netutil/tcplistener.go b/lib/netutil/tcplistener.go index fcc5ce26b..238f755ce 100644 --- a/lib/netutil/tcplistener.go +++ b/lib/netutil/tcplistener.go @@ -1,6 +1,7 @@ package netutil import ( + "crypto/tls" "errors" "flag" "fmt" @@ -13,16 +14,19 @@ import ( var enableTCP6 = flag.Bool("enableTCP6", false, "Whether to enable IPv6 for listening and dialing. By default only IPv4 TCP and UDP is used") -// NewTCPListener returns new TCP listener for the given addr. +// NewTCPListener returns new TCP listener for the given addr and optional tlsConfig. // // name is used for exported metrics. Each listener in the program must have // distinct name. -func NewTCPListener(name, addr string) (*TCPListener, error) { +func NewTCPListener(name, addr string, tlsConfig *tls.Config) (*TCPListener, error) { network := GetTCPNetwork() ln, err := net.Listen(network, addr) if err != nil { return nil, err } + if tlsConfig != nil { + ln = tls.NewListener(ln, tlsConfig) + } tln := &TCPListener{ Listener: ln, diff --git a/lib/netutil/tls.go b/lib/netutil/tls.go new file mode 100644 index 000000000..018154b5e --- /dev/null +++ b/lib/netutil/tls.go @@ -0,0 +1,54 @@ +package netutil + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "sync" + + "github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime" + "github.com/VictoriaMetrics/VictoriaMetrics/lib/fs" +) + +// GetServerTLSConfig returns TLS config for the server with possible client verification (mTLS) if tlsCAFile isn't empty. +func GetServerTLSConfig(tlsCAFile, tlsCertFile, tlsKeyFile string) (*tls.Config, error) { + var certLock sync.Mutex + var certDeadline uint64 + var cert *tls.Certificate + c, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err) + } + cert = &c + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + certLock.Lock() + defer certLock.Unlock() + if fasttime.UnixTimestamp() > certDeadline { + c, err = tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err) + } + certDeadline = fasttime.UnixTimestamp() + 1 + cert = &c + } + return cert, nil + }, + } + if tlsCAFile != "" { + // Enable mTLS ( https://en.wikipedia.org/wiki/Mutual_authentication#mTLS ) + cfg.ClientAuth = tls.RequireAndVerifyClientCert + cp := x509.NewCertPool() + caPEM, err := fs.ReadFileOrHTTP(tlsCAFile) + if err != nil { + return nil, fmt.Errorf("cannot read tlsCAFile=%q: %w", tlsCAFile, err) + } + if !cp.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("cannot parse data for tlsCAFile=%q: %s", tlsCAFile, caPEM) + } + cfg.ClientCAs = cp + } + return cfg, nil +}