package netutil

import (
	"errors"
	"fmt"
	"io"
	"net"
	"sync/atomic"

	"github.com/VictoriaMetrics/VictoriaMetrics/lib/fasttime"
	"github.com/VictoriaMetrics/metrics"
)

// connMetrics groups the counters and the gauge that statConn updates while
// reading from, writing to and closing a connection. The fields are populated
// by init.
type connMetrics struct {
	// Read-side statistics: calls, bytes, non-timeout errors and timeouts.
	readCalls, readBytes, readErrors, readTimeouts *metrics.Counter

	// Write-side statistics: calls, bytes, non-timeout errors and timeouts.
	writeCalls, writtenBytes, writeErrors, writeTimeouts *metrics.Counter

	// closeErrors counts errors returned by the underlying Close.
	closeErrors *metrics.Counter

	// conns is decremented by statConn.Close.
	conns *metrics.Gauge
}

// init creates every metric held by cm and registers it in ms.
//
// Each metric name is prefixed with group and carries the name and addr
// labels. Read, write and close errors share the `<group>_errors_total`
// family and are told apart by the type label.
func (cm *connMetrics) init(ms *metrics.Set, group, name, addr string) {
	// newCounter renders format with group, name and addr (in that order)
	// and registers the resulting counter in ms.
	newCounter := func(format string) *metrics.Counter {
		return ms.NewCounter(fmt.Sprintf(format, group, name, addr))
	}

	cm.readCalls = newCounter(`%s_read_calls_total{name=%q, addr=%q}`)
	cm.readBytes = newCounter(`%s_read_bytes_total{name=%q, addr=%q}`)
	cm.readErrors = newCounter(`%s_errors_total{name=%q, addr=%q, type="read"}`)
	cm.readTimeouts = newCounter(`%s_read_timeouts_total{name=%q, addr=%q}`)

	cm.writeCalls = newCounter(`%s_write_calls_total{name=%q, addr=%q}`)
	cm.writtenBytes = newCounter(`%s_written_bytes_total{name=%q, addr=%q}`)
	cm.writeErrors = newCounter(`%s_errors_total{name=%q, addr=%q, type="write"}`)
	cm.writeTimeouts = newCounter(`%s_write_timeouts_total{name=%q, addr=%q}`)

	cm.closeErrors = newCounter(`%s_errors_total{name=%q, addr=%q, type="close"}`)

	// The gauge is created without a callback, so its value is adjusted
	// explicitly (statConn.Close calls Dec); NOTE(review): the matching Inc
	// presumably happens where connections are created — confirm.
	cm.conns = ms.NewGauge(fmt.Sprintf(`%s_conns{name=%q, addr=%q}`, group, name, addr), nil)
}

// statConn wraps a net.Conn and records read, write and close statistics in cm.
type statConn struct {
	// closeCalls counts Close invocations; only the call that bumps it to 1
	// closes the underlying connection and decrements cm.conns.
	closeCalls atomic.Uint64

	// The wrapped connection; methods not overridden by statConn are promoted
	// from it unchanged.
	net.Conn

	// cm receives the statistics. NOTE(review): presumably shared by all
	// connections created by the same owner — confirm against callers.
	cm *connMetrics
}

// Read reads from the wrapped connection and records the call, the number of
// bytes read and any failure in sc.cm.
//
// io.EOF is a normal end of stream and is not counted as an error. Timeout
// errors are counted in readTimeouts only when the read took more than one
// unit of fasttime.UnixTimestamp; all other errors go to readErrors.
// The original n and err are always returned unchanged.
func (sc *statConn) Read(p []byte) (int, error) {
	began := fasttime.UnixTimestamp()
	n, err := sc.Conn.Read(p)

	cm := sc.cm
	cm.readCalls.Inc()
	cm.readBytes.Add(n)

	if err == nil || err == io.EOF {
		return n, err
	}

	var netErr net.Error
	switch {
	case !errors.As(err, &netErr) || !netErr.Timeout():
		cm.readErrors.Inc()
	case fasttime.UnixTimestamp()-began > 1:
		// Timeouts that fire immediately are skipped: they are the artificial
		// timeout generated by net/http.Server.
		// See https://cs.opensource.google/go/go/+/refs/tags/go1.20.5:src/net/http/server.go;l=701
		cm.readTimeouts.Inc()
	}
	return n, err
}

// Write writes to the wrapped connection and records the call, the number of
// bytes written and any failure in sc.cm.
//
// A failed write increments writeTimeouts when the error is a net.Error that
// reports Timeout(), and writeErrors otherwise. The original n and err are
// always returned unchanged.
func (sc *statConn) Write(p []byte) (int, error) {
	n, err := sc.Conn.Write(p)

	m := sc.cm
	m.writeCalls.Inc()
	m.writtenBytes.Add(n)

	if err == nil {
		return n, err
	}

	// Pick which failure counter to bump.
	failures := m.writeErrors
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		failures = m.writeTimeouts
	}
	failures.Inc()

	return n, err
}

// Close closes the wrapped connection exactly once.
//
// An atomic counter ensures that only the first call closes the underlying
// connection, decrements the conns gauge and counts a close error if one
// occurs. Every later call, including concurrent ones, returns nil without
// touching the connection or the metrics.
func (sc *statConn) Close() error {
	if calls := sc.closeCalls.Add(1); calls > 1 {
		// A previous call has already closed the connection.
		return nil
	}

	closeErr := sc.Conn.Close()
	cm := sc.cm
	cm.conns.Dec()
	if closeErr != nil {
		cm.closeErrors.Inc()
	}
	return closeErr
}