Skip to content

Commit

Permalink
protocol: prove erroneous timeout after timeout is set
Browse files Browse the repository at this point in the history
As reported in #75

Signed-off-by: Pires <[email protected]>
  • Loading branch information
pires committed Sep 8, 2021
1 parent a55009f commit 67d28b3
Showing 1 changed file with 91 additions and 2 deletions.
93 changes: 91 additions & 2 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"testing"
"time"
)
Expand All @@ -29,7 +31,7 @@ func TestPassthrough(t *testing.T) {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
}
defer conn.Close()

conn.Write([]byte("ping"))
Expand Down Expand Up @@ -71,7 +73,7 @@ func TestReadHeaderTimeout(t *testing.T) {

pl := &Listener{
Listener: l,
ReadHeaderTimeout: 1 * time.Millisecond,
ReadHeaderTimeout: time.Millisecond * 250,
}

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -97,6 +99,93 @@ func TestReadHeaderTimeout(t *testing.T) {
recv := make([]byte, 4)
_, err = conn.Read(recv)

if err != nil && !errors.Is(err, os.ErrDeadlineExceeded){
t.Fatal("should timeout")
}
}

func TestReadHeaderTimeoutIsReset(t *testing.T) {
const timeout = time.Millisecond * 250

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{
Listener: l,
ReadHeaderTimeout: timeout,
}

header := &Header{
Version: 2,
Command: PROXY,
TransportProtocol: TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},
}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

// Write out the header!
header.WriteTo(conn)

// Sleep here longer than the configured timeout.
time.Sleep(timeout * 2)

conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}

if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}

// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}

h := conn.(*Conn).ProxyHeader()
if !h.EqualsTo(header) {
t.Errorf("bad: %v", h)
}
}

func TestParse_ipv4(t *testing.T) {
Expand Down

0 comments on commit 67d28b3

Please sign in to comment.