package netutil

import (
	"bytes"
	"io"
	"net"
	"reflect"
	"testing"
)

// withProxyProtoV2Signature returns a fresh slice holding the fixed 12-byte
// PROXY protocol v2 signature followed by rest.
//
// In the test cases below, rest starts with the remaining 4 bytes of the
// fixed v2 header:
//   - version (high nibble) | command (low nibble), e.g. 0x21 = v2 PROXY;
//   - address family (high nibble) | transport (low nibble), e.g. 0x11 = IPv4/TCP;
//   - big-endian length of the address block that follows, e.g. 0x00, 0x0C = 12.
func withProxyProtoV2Signature(rest ...byte) []byte {
	sig := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	return append(sig, rest...)
}

func TestParseProxyProtocolSuccess(t *testing.T) {
	f := func(body, wantTail []byte, wantAddr net.Addr) {
		t.Helper()
		r := bytes.NewBuffer(body)
		gotAddr, err := readProxyProto(r)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if !reflect.DeepEqual(gotAddr, wantAddr) {
			t.Fatalf("unexpected address\ngot:\n%v\nwant:\n%v", gotAddr, wantAddr)
		}
		// Everything after the header must be left unread in r.
		gotTail, err := io.ReadAll(r)
		if err != nil {
			t.Fatalf("cannot read tail: %s", err)
		}
		if !bytes.Equal(gotTail, wantTail) {
			t.Fatalf("unexpected tail after parsing proxy protocol\ngot:\n%q\nwant:\n%q", gotTail, wantTail)
		}
	}

	// LOCAL command (0x20): the address block is consumed and no address is returned
	f(withProxyProtoV2Signature(0x20, 0x11, 0x00, 0x0C,
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0), nil, nil)

	// PROXY over TCP/IPv4
	f(withProxyProtoV2Signature(0x21, 0x11, 0x00, 0x0C,
		// src ip, dst ip, src port, dst port
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0), nil, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})

	// PROXY over TCP/IPv4 followed by a payload that must stay unread
	f(withProxyProtoV2Signature(0x21, 0x11, 0x00, 0x0C,
		// src ip, dst ip, src port, dst port
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0,
		// payload
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0,
	), []byte{0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0}, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})

	// PROXY over TCP/IPv6
	f(withProxyProtoV2Signature(0x21, 0x21, 0x00, 0x24,
		// src ip ::1
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
		// dst ip ::1
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
		// src port, dst port
		0, 80, 0, 0), nil, &net.TCPAddr{IP: net.ParseIP("::1"), Port: 80})
}

func TestParseProxyProtocolFail(t *testing.T) {
	f := func(body []byte) {
		t.Helper()
		r := bytes.NewBuffer(body)
		gotAddr, err := readProxyProto(r)
		if err == nil {
			t.Fatalf("expected error at input %v", body)
		}
		if gotAddr != nil {
			t.Fatalf("expected nil address, got: %v", gotAddr)
		}
	}

	// truncated signature (only the first 7 of 12 bytes)
	f(withProxyProtoV2Signature()[:7])

	// truncated fixed header: full signature and version/command byte only
	f(withProxyProtoV2Signature(0x21))

	// corrupted signature (second byte is 0x1A instead of 0x0A)
	f([]byte{0x0D, 0x1A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C})

	// unsupported version 3
	f(withProxyProtoV2Signature(0x31, 0x11, 0x00, 0x0C))

	// oversized address block length (0xff0c) with no address data following
	f(withProxyProtoV2Signature(0x21, 0x11, 0xff, 0x0C))

	// address block shorter than the declared 12 bytes (dst port missing)
	f(withProxyProtoV2Signature(0x21, 0x11, 0x00, 0x0C,
		// src ip, dst ip, src port
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80))

	// declared length 8 is too short for a TCP/IPv4 address block (12 bytes)
	f(withProxyProtoV2Signature(0x21, 0x11, 0x00, 0x08,
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0))

	// unsupported address family 0x3 (AF_UNIX in the PROXY v2 spec)
	f(withProxyProtoV2Signature(0x21, 0x31, 0x00, 0x0C,
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0))

	// unsupported command 0x2 (neither LOCAL nor PROXY)
	f(withProxyProtoV2Signature(0x22, 0x11, 0x00, 0x0C,
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0))

	// IPv6 family declared, but the address block has the 12-byte IPv4 layout
	f(withProxyProtoV2Signature(0x21, 0x21, 0x00, 0x0C,
		// src ip, dst ip, src port, dst port (IPv4 layout)
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0))

	// UDP over IPv4 isn't supported
	f(withProxyProtoV2Signature(0x21, 0x12, 0x00, 0x0C,
		// src ip, dst ip, src port, dst port
		0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0))

	// UDP over IPv6 isn't supported
	f(withProxyProtoV2Signature(0x21, 0x22, 0x00, 0x24,
		// src ip ::1
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
		// dst ip ::1
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
		// src port, dst port
		0, 80, 0, 0))
}