/*
 *
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package s2a provides the S2A transport credentials used by a gRPC
// application.
package s2a

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/google/s2a-go/fallback"
	"github.com/google/s2a-go/internal/handshaker"
	"github.com/google/s2a-go/internal/handshaker/service"
	"github.com/google/s2a-go/internal/tokenmanager"
	"github.com/google/s2a-go/internal/v2"
	"github.com/google/s2a-go/retry"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/grpclog"

	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)

const (
	s2aSecurityProtocol = "tls"
	// defaultTimeout specifies the default server handshake timeout.
	defaultTimeout = 30.0 * time.Second
)

// s2aTransportCreds are the transport credentials required for establishing
// a secure connection using the S2A. They implement the
// credentials.TransportCredentials interface.
type s2aTransportCreds struct {
	info          *credentials.ProtocolInfo
	minTLSVersion commonpb.TLSVersion
	maxTLSVersion commonpb.TLSVersion
	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
	// Note that these are currently unconfigurable.
	tlsCiphersuites []commonpb.Ciphersuite
	// localIdentity should only be used by the client.
	localIdentity *commonpb.Identity
	// localIdentities should only be used by the server.
	localIdentities []*commonpb.Identity
	// targetIdentities should only be used by the client.
	targetIdentities            []*commonpb.Identity
	isClient                    bool
	s2aAddr                     string
	ensureProcessSessionTickets *sync.WaitGroup
}

// NewClientCreds returns a client-side transport credentials object that uses
// the S2A to establish a secure connection with a server.
func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
	if opts == nil {
		return nil, errors.New("nil client options")
	}
	var targetIdentities []*commonpb.Identity
	for _, targetIdentity := range opts.TargetIdentities {
		protoTargetIdentity, err := toProtoIdentity(targetIdentity)
		if err != nil {
			return nil, err
		}
		targetIdentities = append(targetIdentities, protoTargetIdentity)
	}
	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
	if err != nil {
		return nil, err
	}
	if opts.EnableLegacyMode {
		return &s2aTransportCreds{
			info: &credentials.ProtocolInfo{
				SecurityProtocol: s2aSecurityProtocol,
			},
			minTLSVersion: commonpb.TLSVersion_TLS1_3,
			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
			tlsCiphersuites: []commonpb.Ciphersuite{
				commonpb.Ciphersuite_AES_128_GCM_SHA256,
				commonpb.Ciphersuite_AES_256_GCM_SHA384,
				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
			},
			localIdentity:               localIdentity,
			targetIdentities:            targetIdentities,
			isClient:                    true,
			s2aAddr:                     opts.S2AAddress,
			ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
		}, nil
	}
	verificationMode := getVerificationMode(opts.VerificationMode)
	var fallbackFunc fallback.ClientHandshake
	if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
		fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
	}
	return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
}

// NewServerCreds returns a server-side transport credentials object that uses
// the S2A to establish a secure connection with a client.
func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
	if opts == nil {
		return nil, errors.New("nil server options")
	}
	var localIdentities []*commonpb.Identity
	for _, localIdentity := range opts.LocalIdentities {
		protoLocalIdentity, err := toProtoIdentity(localIdentity)
		if err != nil {
			return nil, err
		}
		localIdentities = append(localIdentities, protoLocalIdentity)
	}
	if opts.EnableLegacyMode {
		return &s2aTransportCreds{
			info: &credentials.ProtocolInfo{
				SecurityProtocol: s2aSecurityProtocol,
			},
			minTLSVersion: commonpb.TLSVersion_TLS1_3,
			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
			tlsCiphersuites: []commonpb.Ciphersuite{
				commonpb.Ciphersuite_AES_128_GCM_SHA256,
				commonpb.Ciphersuite_AES_256_GCM_SHA384,
				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
			},
			localIdentities: localIdentities,
			isClient:        false,
			s2aAddr:         opts.S2AAddress,
		}, nil
	}
	verificationMode := getVerificationMode(opts.VerificationMode)
	return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, localIdentities, verificationMode, opts.getS2AStream)
}

// ClientHandshake initiates a client-side TLS handshake using the S2A.
func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	if !c.isClient {
		return nil, nil, errors.New("client handshake called using server transport credentials")
	}

	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(ctx)
	defer cancel()

	// Connect to the S2A.
	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
	if err != nil {
		grpclog.Infof("Failed to connect to S2A: %v", err)
		return nil, nil, err
	}

	opts := &handshaker.ClientHandshakerOptions{
		MinTLSVersion:               c.minTLSVersion,
		MaxTLSVersion:               c.maxTLSVersion,
		TLSCiphersuites:             c.tlsCiphersuites,
		TargetIdentities:            c.targetIdentities,
		LocalIdentity:               c.localIdentity,
		TargetName:                  serverAuthority,
		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
	}
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
	if err != nil {
		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
		return nil, nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := chs.Close(); closeErr != nil {
				grpclog.Infof("Close failed unexpectedly: %v", err)
				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
			}
		}
	}()

	secConn, authInfo, err := chs.ClientHandshake(context.Background())
	if err != nil {
		grpclog.Infof("Handshake failed: %v", err)
		return nil, nil, err
	}
	return secConn, authInfo, nil
}

// ServerHandshake initiates a server-side TLS handshake using the S2A.
func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	if c.isClient {
		return nil, nil, errors.New("server handshake called using client transport credentials")
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()

	// Connect to the S2A.
	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
	if err != nil {
		grpclog.Infof("Failed to connect to S2A: %v", err)
		return nil, nil, err
	}

	opts := &handshaker.ServerHandshakerOptions{
		MinTLSVersion:   c.minTLSVersion,
		MaxTLSVersion:   c.maxTLSVersion,
		TLSCiphersuites: c.tlsCiphersuites,
		LocalIdentities: c.localIdentities,
	}
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
	if err != nil {
		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
		return nil, nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := shs.Close(); closeErr != nil {
				grpclog.Infof("Close failed unexpectedly: %v", err)
				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
			}
		}
	}()

	secConn, authInfo, err := shs.ServerHandshake(context.Background())
	if err != nil {
		grpclog.Infof("Handshake failed: %v", err)
		return nil, nil, err
	}
	return secConn, authInfo, nil
}

func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
	return *c.info
}

func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
	info := *c.info
	var localIdentity *commonpb.Identity
	if c.localIdentity != nil {
		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
	}
	var localIdentities []*commonpb.Identity
	if c.localIdentities != nil {
		localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
		for i, localIdentity := range c.localIdentities {
			localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
		}
	}
	var targetIdentities []*commonpb.Identity
	if c.targetIdentities != nil {
		targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
		for i, targetIdentity := range c.targetIdentities {
			targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
		}
	}
	return &s2aTransportCreds{
		info:             &info,
		minTLSVersion:    c.minTLSVersion,
		maxTLSVersion:    c.maxTLSVersion,
		tlsCiphersuites:  c.tlsCiphersuites,
		localIdentity:    localIdentity,
		localIdentities:  localIdentities,
		targetIdentities: targetIdentities,
		isClient:         c.isClient,
		s2aAddr:          c.s2aAddr,
	}
}

func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
	c.info.ServerName = serverNameOverride
	return nil
}

// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the hostname found in server's certificate.
	// 		tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//			ServerName: "example.com",
	//		})
	ServerName string
}

// TLSClientConfigFactory defines the interface for a client TLS config factory.
type TLSClientConfigFactory interface {
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}

// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
	if opts == nil {
		return nil, fmt.Errorf("opts must be non-nil")
	}
	if opts.EnableLegacyMode {
		return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
	}
	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
	if err != nil {
		// The only possible error is: access token not set in the environment,
		// which is okay in environments other than serverless.
		grpclog.Infof("Access token manager not initialized: %v", err)
		return &s2aTLSClientConfigFactory{
			s2av2Address:              opts.S2AAddress,
			transportCreds:            opts.TransportCreds,
			tokenManager:              nil,
			verificationMode:          getVerificationMode(opts.VerificationMode),
			serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
		}, nil
	}
	return &s2aTLSClientConfigFactory{
		s2av2Address:              opts.S2AAddress,
		transportCreds:            opts.TransportCreds,
		tokenManager:              tokenManager,
		verificationMode:          getVerificationMode(opts.VerificationMode),
		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
	}, nil
}

type s2aTLSClientConfigFactory struct {
	s2av2Address              string
	transportCreds            credentials.TransportCredentials
	tokenManager              tokenmanager.AccessTokenManager
	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	serverAuthorizationPolicy []byte
}

func (f *s2aTLSClientConfigFactory) Build(
	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
	serverName := ""
	if opts != nil && opts.ServerName != "" {
		serverName = opts.ServerName
	}
	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
}

func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
	switch verificationMode {
	case ConnectToGoogle:
		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
	case Spiffe:
		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
	default:
		return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
	}
}

// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
// Example use with http.RoundTripper:
//
//		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
//			S2AAddress:         s2aAddress, // required
//		})
//	 	transport := http.DefaultTransport
//	 	transport.DialTLSContext = dialTLSContext
func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {

	return func(ctx context.Context, network, addr string) (net.Conn, error) {

		fallback := func(err error) (net.Conn, error) {
			if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
				opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
				fbDialer := opts.FallbackOpts.FallbackDialer
				grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
				fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
				if fbErr != nil {
					return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
				}
				return fbConn, nil
			}
			return nil, err
		}

		factory, err := NewTLSClientConfigFactory(opts)
		if err != nil {
			grpclog.Infof("error creating S2A client config factory: %v", err)
			return fallback(err)
		}

		serverName, _, err := net.SplitHostPort(addr)
		if err != nil {
			serverName = addr
		}
		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
		defer cancel()

		var s2aTLSConfig *tls.Config
		retry.Run(timeoutCtx,
			func() error {
				s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
					ServerName: serverName,
				})
				return err
			})
		if err != nil {
			grpclog.Infof("error building S2A TLS config: %v", err)
			return fallback(err)
		}

		s2aDialer := &tls.Dialer{
			Config: s2aTLSConfig,
		}
		var c net.Conn
		retry.Run(timeoutCtx,
			func() error {
				c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
				return err
			})
		if err != nil {
			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
			return fallback(err)
		}
		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
		return c, nil
	}
}