From 7810375c5f40de2f47d14487d2b82ad7abc9b400 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Sat, 16 Apr 2022 15:51:34 +0300 Subject: [PATCH] lib/httpserver: move the code, which creates tls.Config, into lib/netutil/tls.go This syncs the corresponding code with cluster branch --- lib/httpserver/httpserver.go | 61 ++------------ lib/ingestserver/graphite/server.go | 2 +- lib/ingestserver/influx/server.go | 2 +- lib/ingestserver/opentsdb/server.go | 2 +- lib/ingestserver/opentsdbhttp/server.go | 2 +- lib/netutil/tcplistener.go | 8 +- lib/netutil/tls.go | 80 +++++++++++++++++++ .../tls_test.go} | 2 +- 8 files changed, 98 insertions(+), 61 deletions(-) create mode 100644 lib/netutil/tls.go rename lib/{httpserver/httpserver_test.go => netutil/tls_test.go} (99%) diff --git a/lib/httpserver/httpserver.go b/lib/httpserver/httpserver.go index 6497057e7..27b9d9286 100644 --- a/lib/httpserver/httpserver.go +++ b/lib/httpserver/httpserver.go @@ -91,44 +91,17 @@ func Serve(addr string, rh RequestHandler) { } logger.Infof("starting http server at %s://%s/", scheme, hostAddr) logger.Infof("pprof handlers are exposed at %s://%s/debug/pprof/", scheme, hostAddr) - lnTmp, err := netutil.NewTCPListener(scheme, addr) - if err != nil { - logger.Fatalf("cannot start http server at %s: %s", addr, err) - } - ln := net.Listener(lnTmp) - + var tlsConfig *tls.Config if *tlsEnable { - var certLock sync.Mutex - var certDeadline uint64 - var cert *tls.Certificate - c, err := tls.LoadX509KeyPair(*tlsCertFile, *tlsKeyFile) + tc, err := netutil.GetServerTLSConfig("", *tlsCertFile, *tlsKeyFile, *tlsCipherSuites) if err != nil { logger.Fatalf("cannot load TLS cert from -tlsCertFile=%q, -tlsKeyFile=%q: %s", *tlsCertFile, *tlsKeyFile, err) } - cipherSuites, err := cipherSuitesFromNames(*tlsCipherSuites) - if err != nil { - logger.Fatalf("cannot use TLS cipher suites from -tlsCipherSuites=%q: %s", *tlsCipherSuites, err) - } - cert = &c - cfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - PreferServerCipherSuites: true, - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - certLock.Lock() - defer certLock.Unlock() - if fasttime.UnixTimestamp() > certDeadline { - c, err = tls.LoadX509KeyPair(*tlsCertFile, *tlsKeyFile) - if err != nil { - return nil, fmt.Errorf("cannot load TLS cert from -tlsCertFile=%q, -tlsKeyFile=%q: %w", *tlsCertFile, *tlsKeyFile, err) - } - certDeadline = fasttime.UnixTimestamp() + 1 - cert = &c - } - return cert, nil - }, - CipherSuites: cipherSuites, - } - ln = tls.NewListener(ln, cfg) + tlsConfig = tc + } + ln, err := netutil.NewTCPListener(scheme, addr, tlsConfig) + if err != nil { + logger.Fatalf("cannot start http server at %s: %s", addr, err) } serveWithListener(addr, ln, rh) } @@ -693,23 +666,3 @@ func GetRequestURI(r *http.Request) string { } return requestURI + delimiter + queryArgs } - -func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) { - if len(cipherSuiteNames) == 0 { - return nil, nil - } - css := tls.CipherSuites() - cssMap := make(map[string]uint16, len(css)) - for _, cs := range css { - cssMap[strings.ToLower(cs.Name)] = cs.ID - } - cipherSuites := make([]uint16, 0, len(cipherSuiteNames)) - for _, name := range cipherSuiteNames { - id, ok := cssMap[strings.ToLower(name)] - if !ok { - return nil, fmt.Errorf("unsupported TLS cipher suite name: %s", name) - } - cipherSuites = append(cipherSuites, id) - } - return cipherSuites, nil -} diff --git a/lib/ingestserver/graphite/server.go b/lib/ingestserver/graphite/server.go index 3474200ee..4608ba635 100644 --- a/lib/ingestserver/graphite/server.go +++ b/lib/ingestserver/graphite/server.go @@ -40,7 +40,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r io.Reader) error) *Server { logger.Infof("starting TCP Graphite server at %q", addr) - lnTCP, err := netutil.NewTCPListener("graphite", addr) + lnTCP, err := netutil.NewTCPListener("graphite", addr, nil) if err != nil { logger.Fatalf("cannot start TCP Graphite server at %q: %s", addr, err) } diff --git a/lib/ingestserver/influx/server.go b/lib/ingestserver/influx/server.go index 9fab5cb02..bb172c68e 100644 --- a/lib/ingestserver/influx/server.go +++ b/lib/ingestserver/influx/server.go @@ -40,7 +40,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r io.Reader) error) *Server { logger.Infof("starting TCP InfluxDB server at %q", addr) - lnTCP, err := netutil.NewTCPListener("influx", addr) + lnTCP, err := netutil.NewTCPListener("influx", addr, nil) if err != nil { logger.Fatalf("cannot start TCP InfluxDB server at %q: %s", addr, err) } diff --git a/lib/ingestserver/opentsdb/server.go b/lib/ingestserver/opentsdb/server.go index a00a32267..46518e3f9 100644 --- a/lib/ingestserver/opentsdb/server.go +++ b/lib/ingestserver/opentsdb/server.go @@ -43,7 +43,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, telnetInsertHandler func(r io.Reader) error, httpInsertHandler func(req *http.Request) error) *Server { logger.Infof("starting TCP OpenTSDB collector at %q", addr) - lnTCP, err := netutil.NewTCPListener("opentsdb", addr) + lnTCP, err := netutil.NewTCPListener("opentsdb", addr, nil) if err != nil { logger.Fatalf("cannot start TCP OpenTSDB collector at %q: %s", addr, err) } diff --git a/lib/ingestserver/opentsdbhttp/server.go b/lib/ingestserver/opentsdbhttp/server.go index 6eec2f5f1..99e2e924d 100644 --- a/lib/ingestserver/opentsdbhttp/server.go +++ b/lib/ingestserver/opentsdbhttp/server.go @@ -30,7 +30,7 @@ type Server struct { // MustStop must be called on the returned server when it is no longer needed. func MustStart(addr string, insertHandler func(r *http.Request) error) *Server { logger.Infof("starting HTTP OpenTSDB server at %q", addr) - lnTCP, err := netutil.NewTCPListener("opentsdbhttp", addr) + lnTCP, err := netutil.NewTCPListener("opentsdbhttp", addr, nil) if err != nil { logger.Fatalf("cannot start HTTP OpenTSDB collector at %q: %s", addr, err) } diff --git a/lib/netutil/tcplistener.go b/lib/netutil/tcplistener.go index fcc5ce26b..238f755ce 100644 --- a/lib/netutil/tcplistener.go +++ b/lib/netutil/tcplistener.go @@ -1,6 +1,7 @@ package netutil import ( + "crypto/tls" "errors" "flag" "fmt" @@ -13,16 +14,19 @@ import ( var enableTCP6 = flag.Bool("enableTCP6", false, "Whether to enable IPv6 for listening and dialing. By default only IPv4 TCP and UDP is used") -// NewTCPListener returns new TCP listener for the given addr. +// NewTCPListener returns new TCP listener for the given addr and optional tlsConfig. // // name is used for exported metrics. Each listener in the program must have // distinct name. -func NewTCPListener(name, addr string) (*TCPListener, error) { +func NewTCPListener(name, addr string, tlsConfig *tls.Config) (*TCPListener, error) { network := GetTCPNetwork() ln, err := net.Listen(network, addr) if err != nil { return nil, err } + if tlsConfig != nil { + ln = tls.NewListener(ln, tlsConfig) + } tln := &TCPListener{ Listener: ln, diff --git a/lib/netutil/tls.go b/lib/netutil/tls.go new file mode 100644 index 000000000..3547fbe75 --- /dev/null +++ b/lib/netutil/tls.go @@ -0,0 +1,80 @@ +package netutil + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "strings" + "sync" + + "github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime" + "github.com/VictoriaMetrics/VictoriaMetrics/lib/fs" +) + +// GetServerTLSConfig returns TLS config for the server with possible client verification (mTLS) if tlsCAFile isn't empty. +func GetServerTLSConfig(tlsCAFile, tlsCertFile, tlsKeyFile string, tlsCipherSuites []string) (*tls.Config, error) { + var certLock sync.Mutex + var certDeadline uint64 + var cert *tls.Certificate + c, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err) + } + cipherSuites, err := cipherSuitesFromNames(tlsCipherSuites) + if err != nil { + return nil, fmt.Errorf("cannot use TLS cipher suites from tlsCipherSuites=%q: %w", tlsCipherSuites, err) + } + cert = &c + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + certLock.Lock() + defer certLock.Unlock() + if fasttime.UnixTimestamp() > certDeadline { + c, err = tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + return nil, fmt.Errorf("cannot load TLS cert from certFile=%q, keyFile=%q: %w", tlsCertFile, tlsKeyFile, err) + } + certDeadline = fasttime.UnixTimestamp() + 1 + cert = &c + } + return cert, nil + }, + CipherSuites: cipherSuites, + } + if tlsCAFile != "" { + // Enable mTLS ( https://en.wikipedia.org/wiki/Mutual_authentication#mTLS ) + cfg.ClientAuth = tls.RequireAndVerifyClientCert + cp := x509.NewCertPool() + caPEM, err := fs.ReadFileOrHTTP(tlsCAFile) + if err != nil { + return nil, fmt.Errorf("cannot read tlsCAFile=%q: %w", tlsCAFile, err) + } + if !cp.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("cannot parse data for tlsCAFile=%q: %s", tlsCAFile, caPEM) + } + cfg.ClientCAs = cp + } + return cfg, nil +} + +func cipherSuitesFromNames(cipherSuiteNames []string) ([]uint16, error) { + if len(cipherSuiteNames) == 0 { + return nil, nil + } + css := tls.CipherSuites() + cssMap := make(map[string]uint16, len(css)) + for _, cs := range css { + cssMap[strings.ToLower(cs.Name)] = cs.ID + } + cipherSuites := make([]uint16, 0, len(cipherSuiteNames)) + for _, name := range cipherSuiteNames { + id, ok := cssMap[strings.ToLower(name)] + if !ok { + return nil, fmt.Errorf("unsupported TLS cipher suite name: %s", name) + } + cipherSuites = append(cipherSuites, id) + } + return cipherSuites, nil +} diff --git a/lib/httpserver/httpserver_test.go b/lib/netutil/tls_test.go similarity index 99% rename from lib/httpserver/httpserver_test.go rename to lib/netutil/tls_test.go index 46d10cda8..3e636a151 100644 --- a/lib/httpserver/httpserver_test.go +++ b/lib/netutil/tls_test.go @@ -1,4 +1,4 @@ -package httpserver +package netutil import ( "reflect"