diff --git a/konnectivity-client/pkg/client/client.go b/konnectivity-client/pkg/client/client.go index 3db8c0f8c..f901941b3 100644 --- a/konnectivity-client/pkg/client/client.go +++ b/konnectivity-client/pkg/client/client.go @@ -24,6 +24,7 @@ import ( "math/rand" "net" "sync" + "sync/atomic" "time" "google.golang.org/grpc" @@ -42,7 +43,7 @@ type Tunnel interface { } type dialResult struct { - err string + err *dialFailure connid int64 } @@ -53,13 +54,70 @@ type pendingDial struct { cancelCh <-chan struct{} } +// TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+ +type pendingDialManager struct { + pendingDials map[int64]pendingDial + mutex sync.RWMutex +} + +func (p *pendingDialManager) add(dialID int64, pd pendingDial) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.pendingDials[dialID] = pd +} + +func (p *pendingDialManager) remove(dialID int64) { + p.mutex.Lock() + defer p.mutex.Unlock() + delete(p.pendingDials, dialID) +} + +func (p *pendingDialManager) get(dialID int64) (pendingDial, bool) { + p.mutex.RLock() + defer p.mutex.RUnlock() + pd, ok := p.pendingDials[dialID] + return pd, ok +} + +// TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+ +type connectionManager struct { + conns map[int64]*conn + mutex sync.RWMutex +} + +func (cm *connectionManager) add(connID int64, c *conn) { + cm.mutex.Lock() + defer cm.mutex.Unlock() + cm.conns[connID] = c +} + +func (cm *connectionManager) remove(connID int64) { + cm.mutex.Lock() + defer cm.mutex.Unlock() + delete(cm.conns, connID) +} + +func (cm *connectionManager) get(connID int64) (*conn, bool) { + cm.mutex.RLock() + defer cm.mutex.RUnlock() + c, ok := cm.conns[connID] + return c, ok +} + +func (cm *connectionManager) closeAll() { + cm.mutex.Lock() + defer cm.mutex.Unlock() + for _, conn := range cm.conns { + close(conn.readCh) + } +} + // grpcTunnel implements Tunnel type grpcTunnel struct { - stream 
client.ProxyService_ProxyClient - pendingDial map[int64]pendingDial - conns map[int64]*conn - pendingDialLock sync.RWMutex - connsLock sync.RWMutex + stream client.ProxyService_ProxyClient + clientConn clientConn + pendingDial pendingDialManager + conns connectionManager // The tunnel will be closed if the caller fails to read via conn.Read() // more than readTimeoutSeconds after a packet has been received. @@ -68,6 +126,11 @@ type grpcTunnel struct { // The done channel is closed after the tunnel has cleaned up all connections and is no longer // serving. done chan struct{} + + // closing is an atomic bool represented as a 0 or 1, and set to true when the tunnel is being closed. + // closing should only be accessed through atomic methods. + // TODO: switch this to an atomic.Bool once the client is exclusively built with go1.19+ + closing uint32 } type clientConn interface { @@ -106,42 +169,39 @@ func CreateSingleUseGrpcTunnelWithContext(createCtx, tunnelCtx context.Context, return nil, err } - tunnel := newUnstartedTunnel(stream) + tunnel := newUnstartedTunnel(stream, c) - go tunnel.serve(tunnelCtx, c) + go tunnel.serve(tunnelCtx) return tunnel, nil } -func newUnstartedTunnel(stream client.ProxyService_ProxyClient) *grpcTunnel { +func newUnstartedTunnel(stream client.ProxyService_ProxyClient, c clientConn) *grpcTunnel { return &grpcTunnel{ stream: stream, - pendingDial: make(map[int64]pendingDial), - conns: make(map[int64]*conn), + clientConn: c, + pendingDial: pendingDialManager{pendingDials: make(map[int64]pendingDial)}, + conns: connectionManager{conns: make(map[int64]*conn)}, readTimeoutSeconds: 10, done: make(chan struct{}), } } -func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) { +func (t *grpcTunnel) serve(tunnelCtx context.Context) { defer func() { - c.Close() + t.clientConn.Close() // A connection in t.conns after serve() returns means // we never received a CLOSE_RSP for it, so we need to // close any channels remaining for these 
connections. - t.connsLock.Lock() - for _, conn := range t.conns { - close(conn.readCh) - } - t.connsLock.Unlock() + t.conns.closeAll() close(t.done) }() for { pkt, err := t.stream.Recv() - if err == io.EOF { + if err == io.EOF || t.isClosing() { return } if err != nil || pkt == nil { @@ -154,28 +214,29 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) { switch pkt.Type { case client.PacketType_DIAL_RSP: resp := pkt.GetDialResponse() - t.pendingDialLock.RLock() - pendingDial, ok := t.pendingDial[resp.Random] - t.pendingDialLock.RUnlock() + pendingDial, ok := t.pendingDial.get(resp.Random) if !ok { + // If the DIAL_RSP does not match a pending dial, it means one of two things: + // 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible) + // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context + // + // In either scenario, we should return here and close the tunnel as it is no longer needed. klog.V(1).InfoS("DialResp not recognized; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random) return } else { - result := dialResult{ - err: resp.Error, - connid: resp.ConnectID, + result := dialResult{connid: resp.ConnectID} + if resp.Error != "" { + result.err = &dialFailure{resp.Error, DialFailureEndpoint} } select { // try to send to the result channel case pendingDial.resultCh <- result: // unblock if the cancel channel is closed case <-pendingDial.cancelCh: - // If there are no readers of the pending dial channel above, it means one of two things: - // 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible) - // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context - // - // In either scenario, we should return here as this tunnel is no longer needed. 
+ // Note: this condition can only be hit by a race condition where the + // DialContext() returns early (timeout) after the pendingDial is already + // fetched here, but before the result is sent. klog.V(1).InfoS("Pending dial has been cancelled; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random) return case <-tunnelCtx.Done(): @@ -189,12 +250,36 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) { return } + case client.PacketType_DIAL_CLS: + resp := pkt.GetCloseDial() + pendingDial, ok := t.pendingDial.get(resp.Random) + + if !ok { + // If the DIAL_CLS does not match a pending dial, it means one of two things: + // 1. There was a DIAL_CLS received after a DIAL_RSP (unlikely but possible) + // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context + // + // In either scenario, we should return here and close the tunnel as it is no longer needed. + klog.V(1).InfoS("DIAL_CLS after dial finished", "dialID", resp.Random) + } else { + result := dialResult{ + err: &dialFailure{"dial closed", DialFailureDialClosed}, + } + select { + case pendingDial.resultCh <- result: + case <-pendingDial.cancelCh: + // Note: this condition can only be hit by a race condition where the + // DialContext() returns early (timeout) after the pendingDial is already + // fetched here, but before the result is sent. + case <-tunnelCtx.Done(): + } + } + return // Stop serving & close the tunnel. 
+ case client.PacketType_DATA: resp := pkt.GetData() // TODO: flow control - t.connsLock.RLock() - conn, ok := t.conns[resp.ConnectID] - t.connsLock.RUnlock() + conn, ok := t.conns.get(resp.ConnectID) if ok { timer := time.NewTimer((time.Duration)(t.readTimeoutSeconds) * time.Second) @@ -210,19 +295,16 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) { } else { klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID) } + case client.PacketType_CLOSE_RSP: resp := pkt.GetCloseResponse() - t.connsLock.RLock() - conn, ok := t.conns[resp.ConnectID] - t.connsLock.RUnlock() + conn, ok := t.conns.get(resp.ConnectID) if ok { close(conn.readCh) conn.closeCh <- resp.Error close(conn.closeCh) - t.connsLock.Lock() - delete(t.conns, resp.ConnectID) - t.connsLock.Unlock() + t.conns.remove(resp.ConnectID) return } klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID) @@ -252,14 +334,8 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s // This channel MUST NOT be buffered. The sender needs to know when we are not receiving things, so they can abort. 
resCh := make(chan dialResult) - t.pendingDialLock.Lock() - t.pendingDial[random] = pendingDial{resultCh: resCh, cancelCh: cancelCh} - t.pendingDialLock.Unlock() - defer func() { - t.pendingDialLock.Lock() - delete(t.pendingDial, random) - t.pendingDialLock.Unlock() - }() + t.pendingDial.add(random, pendingDial{resultCh: resCh, cancelCh: cancelCh}) + defer t.pendingDial.remove(random) req := &client.Packet{ Type: client.PacketType_DIAL_REQ, @@ -280,25 +356,32 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s klog.V(5).Infoln("DIAL_REQ sent to proxy server") - c := &conn{stream: t.stream, random: random} + c := &conn{ + stream: t.stream, + random: random, + closeTunnel: t.closeTunnel, + } select { case res := <-resCh: - if res.err != "" { - return nil, &dialFailure{res.err, DialFailureEndpoint} + if res.err != nil { + return nil, res.err } c.connID = res.connid c.readCh = make(chan []byte, 10) c.closeCh = make(chan string, 1) - t.connsLock.Lock() - t.conns[res.connid] = c - t.connsLock.Unlock() + t.conns.add(res.connid, c) case <-time.After(30 * time.Second): klog.V(5).InfoS("Timed out waiting for DialResp", "dialID", random) + go t.closeDial(random) return nil, &dialFailure{"dial timeout, backstop", DialFailureTimeout} case <-requestCtx.Done(): klog.V(5).InfoS("Context canceled waiting for DialResp", "ctxErr", requestCtx.Err(), "dialID", random) + go t.closeDial(random) return nil, &dialFailure{"dial timeout, context", DialFailureContext} + case <-t.done: + klog.V(5).InfoS("Tunnel closed while waiting for DialResp", "dialID", random) + return nil, &dialFailure{"tunnel closed", DialFailureTunnelClosed} } return c, nil @@ -308,6 +391,31 @@ func (t *grpcTunnel) Done() <-chan struct{} { return t.done } +// Send a best-effort DIAL_CLS request for the given dial ID. 
+func (t *grpcTunnel) closeDial(dialID int64) { + req := &client.Packet{ + Type: client.PacketType_DIAL_CLS, + Payload: &client.Packet_CloseDial{ + CloseDial: &client.CloseDial{ + Random: dialID, + }, + }, + } + if err := t.stream.Send(req); err != nil { + klog.V(5).InfoS("Failed to send DIAL_CLS", "err", err, "dialID", dialID) + } + t.closeTunnel() +} + +func (t *grpcTunnel) closeTunnel() { + atomic.StoreUint32(&t.closing, 1) + t.clientConn.Close() +} + +func (t *grpcTunnel) isClosing() bool { + return atomic.LoadUint32(&t.closing) != 0 +} + func GetDialFailureReason(err error) (isDialFailure bool, reason DialFailureReason) { var df *dialFailure if errors.As(err, &df) { @@ -336,4 +444,10 @@ const ( DialFailureContext DialFailureReason = "context" // DialFailureEndpoint indicates that the konnectivity-agent was unable to reach the backend endpoint. DialFailureEndpoint DialFailureReason = "endpoint" + // DialFailureDialClosed indicates that the client received a CloseDial response, indicating the + // connection was closed before the dial could complete. + DialFailureDialClosed DialFailureReason = "dialclosed" + // DialFailureTunnelClosed indicates that the client connection was closed before the dial could + // complete. 
+ DialFailureTunnelClosed DialFailureReason = "tunnelclosed" ) diff --git a/konnectivity-client/pkg/client/client_test.go b/konnectivity-client/pkg/client/client_test.go index bb46a8c1e..6bd7455b3 100644 --- a/konnectivity-client/pkg/client/client_test.go +++ b/konnectivity-client/pkg/client/client_test.go @@ -22,6 +22,7 @@ import ( "errors" "flag" "io" + "sync" "testing" "time" @@ -40,7 +41,7 @@ func TestMain(m *testing.M) { } func TestDial(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx := context.Background() s, ps := pipe() @@ -49,9 +50,9 @@ func TestDial(t *testing.T) { defer ps.Close() defer s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() _, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -71,7 +72,7 @@ func TestDial(t *testing.T) { // TestDialRace exercises the scenario where serve() observes and handles DIAL_RSP // before DialContext() does any work after sending the DIAL_REQ. 
func TestDialRace(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx := context.Background() s, ps := pipe() @@ -82,9 +83,9 @@ func TestDialRace(t *testing.T) { // artificially delay after calling Send, ensure handoff of result from serve to DialContext still works slowStream := fakeSlowSend{s} - tunnel := newUnstartedTunnel(slowStream) + tunnel := newUnstartedTunnel(slowStream, &fakeConn{}) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() _, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -117,7 +118,7 @@ func (s fakeSlowSend) Send(p *client.Packet) error { } func TestData(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx := context.Background() s, ps := pipe() @@ -126,9 +127,9 @@ func TestData(t *testing.T) { defer ps.Close() defer s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -173,7 +174,7 @@ func TestData(t *testing.T) { } func TestClose(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx := context.Background() s, ps := pipe() @@ -182,9 +183,9 @@ func TestClose(t *testing.T) { defer ps.Close() defer s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -208,7 +209,7 @@ func TestCloseTimeout(t *testing.T) { if testing.Short() { t.Skip() } - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx := context.Background() s, ps := pipe() @@ -223,9 +224,9 @@ func TestCloseTimeout(t *testing.T) { defer ps.Close() defer 
s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -248,7 +249,7 @@ func TestCloseTimeout(t *testing.T) { } func TestCreateSingleUseGrpcTunnel_NoLeakOnFailure(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) tunnel, err := CreateSingleUseGrpcTunnel(context.Background(), "127.0.0.1:12345", grpc.WithInsecure()) if tunnel != nil { @@ -260,7 +261,7 @@ func TestCreateSingleUseGrpcTunnel_NoLeakOnFailure(t *testing.T) { } func TestCreateSingleUseGrpcTunnelWithContext_NoLeakOnFailure(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) tunnel, err := CreateSingleUseGrpcTunnelWithContext(context.Background(), context.Background(), "127.0.0.1:12345", grpc.WithInsecure()) if tunnel != nil { @@ -272,7 +273,7 @@ func TestCreateSingleUseGrpcTunnelWithContext_NoLeakOnFailure(t *testing.T) { } func TestDialAfterTunnelCancelled(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -282,9 +283,9 @@ func TestDialAfterTunnelCancelled(t *testing.T) { defer ps.Close() defer s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(ctx, &fakeConn{}) + go tunnel.serve(ctx) go ts.serve() _, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80") @@ -292,51 +293,74 @@ func TestDialAfterTunnelCancelled(t *testing.T) { t.Fatalf("expect err when dialing after tunnel closed") } - t.Log("Wait for tunnel to close") select { case <-tunnel.Done(): - t.Log("Tunnel closed successfully") case <-time.After(30 * time.Second): t.Errorf("Timed out waiting for tunnel to close") } } func TestDial_RequestContextCancelled(t 
*testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) - reqCtx, reqCancel := context.WithCancel(context.Background()) s, ps := pipe() + defer ps.Close() + defer s.Close() + ts := testServer(ps, 100) + reqCtx, reqCancel := context.WithCancel(context.Background()) ts.handlers[client.PacketType_DIAL_REQ] = func(*client.Packet) *client.Packet { reqCancel() return nil // don't respond } + closeCh := make(chan struct{}) + ts.handlers[client.PacketType_DIAL_CLS] = func(*client.Packet) *client.Packet { + close(closeCh) + return nil // don't respond + } + go ts.serve() - defer ps.Close() - defer s.Close() + func() { + // Tunnel should be shut down when the dial fails. + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) + go tunnel.serve(context.Background()) - go tunnel.serve(context.Background(), &fakeConn{}) - go ts.serve() + _, err := tunnel.DialContext(reqCtx, "tcp", "127.0.0.1:80") + if err == nil { + t.Fatalf("Expected dial error, got none") + } - _, err := tunnel.DialContext(reqCtx, "tcp", "127.0.0.1:80") - if err == nil { - t.Fatalf("Expected dial error, got none") - } + isDialFailure, reason := GetDialFailureReason(err) + if !isDialFailure { + t.Errorf("Unexpected non-dial failure error: %v", err) + } else if reason != DialFailureContext { + t.Errorf("Expected DialFailureContext, got %v", reason) + } - isDialFailure, reason := GetDialFailureReason(err) - if !isDialFailure { - t.Errorf("Unexpected non-dial failure error: %v", err) - } else if reason != DialFailureContext { - t.Errorf("Expected DialFailureContext, got %v", reason) - } + ts.assertPacketType(0, client.PacketType_DIAL_REQ) + waitForDialClsStart := time.Now() + select { + case <-closeCh: + t.Logf("Dial closed after %#v", time.Since(waitForDialClsStart).String()) + ts.assertPacketType(1, client.PacketType_DIAL_CLS) + case <-time.After(30 * time.Second): + 
t.Fatal("Timed out waiting for DIAL_CLS packet") + } - ts.assertPacketType(0, client.PacketType_DIAL_REQ) + waitForTunnelCloseStart := time.Now() + select { + case <-tunnel.Done(): + t.Logf("Tunnel closed after %#v", time.Since(waitForTunnelCloseStart).String()) + case <-time.After(30 * time.Second): + t.Errorf("Timed out waiting for tunnel to close") + } + }() } func TestDial_BackendError(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) s, ps := pipe() ts := testServer(ps, 100) @@ -355,9 +379,9 @@ func TestDial_BackendError(t *testing.T) { defer ps.Close() defer s.Close() - tunnel := newUnstartedTunnel(s) + tunnel := newUnstartedTunnel(s, s.conn()) - go tunnel.serve(context.Background(), &fakeConn{}) + go tunnel.serve(context.Background()) go ts.serve() _, err := tunnel.DialContext(context.Background(), "tcp", "127.0.0.1:80") @@ -369,12 +393,61 @@ func TestDial_BackendError(t *testing.T) { if !isDialFailure { t.Errorf("Unexpected non-dial failure error: %v", err) } else if reason != DialFailureEndpoint { - t.Errorf("Expected DialFailureContext, got %v", reason) + t.Errorf("Expected DialFailureEndpoint, got %v", reason) } ts.assertPacketType(0, client.PacketType_DIAL_REQ) } +func TestDial_Closed(t *testing.T) { + defer goleakVerifyNone(t, goleak.IgnoreCurrent()) + + s, ps := pipe() + defer ps.Close() + defer s.Close() + + ts := testServer(ps, 100) + ts.handlers[client.PacketType_DIAL_REQ] = func(pkt *client.Packet) *client.Packet { + return &client.Packet{ + Type: client.PacketType_DIAL_CLS, + Payload: &client.Packet_CloseDial{ + CloseDial: &client.CloseDial{ + Random: pkt.GetDialRequest().Random, + }, + }, + } + } + go ts.serve() + + func() { + // Verify that the tunnel goroutines are not leaked before cleaning up the test server. 
+ goleakVerifyNone(t, goleak.IgnoreCurrent()) + + tunnel := newUnstartedTunnel(s, s.conn()) + go tunnel.serve(context.Background()) + + _, err := tunnel.DialContext(context.Background(), "tcp", "127.0.0.1:80") + if err == nil { + t.Fatalf("Expected dial error, got none") + } + + isDialFailure, reason := GetDialFailureReason(err) + if !isDialFailure { + t.Errorf("Unexpected non-dial failure error: %v", err) + } else if reason != DialFailureDialClosed { + t.Errorf("Expected DialFailureDialClosed, got %v", reason) + } + + ts.assertPacketType(0, client.PacketType_DIAL_REQ) + + select { + case <-tunnel.Done(): + case <-time.After(30 * time.Second): + t.Errorf("Timed out waiting for tunnel to close") + } + }() +} + // TODO: Move to common testing library // fakeStream implements ProxyService_ProxyClient @@ -387,9 +460,13 @@ type fakeStream struct { } type fakeConn struct { + stream *fakeStream } func (f *fakeConn) Close() error { + if f.stream != nil { + f.stream.Close() + } return nil } @@ -434,13 +511,21 @@ func (s *fakeStream) Recv() (*client.Packet, error) { case pkt := <-s.r: klog.V(4).InfoS("[DEBUG] recv", "packet", pkt) return pkt, nil - case <-time.After(5 * time.Second): + case <-time.After(30 * time.Second): return nil, errors.New("timeout recv") } } func (s *fakeStream) Close() { - close(s.closed) + select { + case <-s.closed: // Avoid double-closing + default: + close(s.closed) + } +} + +func (s *fakeStream) conn() *fakeConn { + return &fakeConn{s} } type proxyServer struct { @@ -449,7 +534,9 @@ type proxyServer struct { handlers map[client.PacketType]handler connid int64 data bytes.Buffer - packets []*client.Packet + + packets []*client.Packet + packetsLock sync.Mutex } func testServer(s client.ProxyService_ProxyClient, connid int64) *proxyServer { @@ -479,7 +566,11 @@ func (s *proxyServer) serve() { return } - s.packets = append(s.packets, pkt) + func() { + s.packetsLock.Lock() + defer s.packetsLock.Unlock() + s.packets = append(s.packets, pkt) + }() if 
handler, ok := s.handlers[pkt.Type]; ok { req := handler(pkt) @@ -493,6 +584,9 @@ func (s *proxyServer) serve() { } func (s *proxyServer) assertPacketType(index int, expectedType client.PacketType) { + s.packetsLock.Lock() + defer s.packetsLock.Unlock() + if index >= len(s.packets) { s.t.Fatalf("Expected %v packet[%d], but have only received %d packets", expectedType, index, len(s.packets)) } @@ -540,3 +634,13 @@ func (s *proxyServer) handleData(pkt *client.Packet) *client.Packet { }, } } + +// Override goleakVerifyNone to set t.Helper. +// TODO: delete this once goleak has been updated to include +// https://github.com/uber-go/goleak/commit/2dfebe88ddf19de216c4ab15a1189fc640b1ea9f +func goleakVerifyNone(t *testing.T, options ...goleak.Option) { + t.Helper() + if err := goleak.Find(options...); err != nil { + t.Error(err) + } +} diff --git a/konnectivity-client/pkg/client/conn.go b/konnectivity-client/pkg/client/conn.go index 822831b10..f76b1e37a 100644 --- a/konnectivity-client/pkg/client/conn.go +++ b/konnectivity-client/pkg/client/conn.go @@ -41,6 +41,9 @@ type conn struct { readCh chan []byte closeCh chan string rdata []byte + + // closeTunnel is an optional callback to close the underlying grpc connection. + closeTunnel func() } var _ net.Conn = &conn{} @@ -116,6 +119,10 @@ func (c *conn) SetWriteDeadline(t time.Time) error { // proxy service to notify remote to drop the connection. 
func (c *conn) Close() error { klog.V(4).Infoln("closing connection") + if c.closeTunnel != nil { + defer c.closeTunnel() + } + var req *client.Packet if c.connID != 0 { req = &client.Packet{ diff --git a/tests/proxy_test.go b/tests/proxy_test.go index 0bc835e72..5fbadf5b1 100644 --- a/tests/proxy_test.go +++ b/tests/proxy_test.go @@ -17,12 +17,14 @@ import ( "github.com/google/uuid" "go.uber.org/goleak" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client" clientproto "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client" "sigs.k8s.io/apiserver-network-proxy/pkg/agent" "sigs.k8s.io/apiserver-network-proxy/pkg/server" agentproto "sigs.k8s.io/apiserver-network-proxy/proto/agent" + "sigs.k8s.io/apiserver-network-proxy/proto/header" ) // Define a blackholed address, for which Dial is expected to hang. This address is reserved for @@ -268,10 +270,12 @@ func TestProxyDial_RequestCancelled_GRPC(t *testing.T) { } defer cleanup() - stopCh := make(chan struct{}) - defer close(stopCh) - clientset := runAgent(proxy.agent, stopCh) - waitForConnectedServerCount(t, 1, clientset) + agent := &unresponsiveAgent{} + if err := agent.Connect(proxy.agent); err != nil { + t.Fatalf("Failed to connect unresponsive agent: %v", err) + } + defer agent.Close() + waitForConnectedAgentCount(t, 1, proxy.server) func() { // Ensure that tunnels aren't leaked with long-running servers. @@ -286,8 +290,6 @@ func TestProxyDial_RequestCancelled_GRPC(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - // Note: We need to wait long enough for the DialContext to be sent, but the agent uses - // a hard-coded 5 second timeout. This is likely to be flaky. 
time.Sleep(1 * time.Second) cancel() // Cancel the request (client-side) }() @@ -295,14 +297,52 @@ func TestProxyDial_RequestCancelled_GRPC(t *testing.T) { _, err = tunnel.DialContext(ctx, "tcp", blackhole) if err == nil { t.Error("Expected error when context is cancelled, did not receive error") - } else if !strings.Contains(err.Error(), "dial timeout, context") { + } else if _, reason := client.GetDialFailureReason(err); reason != client.DialFailureContext { + t.Errorf("Unexpected error: %v", err) + } + + select { + case <-tunnel.Done(): + case <-time.After(wait.ForeverTestTimeout): + t.Errorf("Timed out waiting for tunnel to close") + } + }() +} + +func TestProxyDial_AgentTimeout_GRPC(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + proxy, cleanup, err := runGRPCProxyServer() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + stopCh := make(chan struct{}) + defer close(stopCh) + clientset := runAgent(proxy.agent, stopCh) + waitForConnectedServerCount(t, 1, clientset) + + func() { + // Ensure that tunnels aren't leaked with long-running servers. + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + // run test client + tunnel, err := client.CreateSingleUseGrpcTunnel(context.Background(), proxy.front, grpc.WithInsecure()) + if err != nil { + t.Fatal(err) + } + + // Agent should time out after 5 seconds and return a DIAL_RSP with an error. 
+ _, err = tunnel.DialContext(context.Background(), "tcp", blackhole) + if err == nil { + t.Error("Expected error when context is cancelled, did not receive error") + } else if _, reason := client.GetDialFailureReason(err); reason != client.DialFailureEndpoint { t.Errorf("Unexpected error: %v", err) } - t.Log("Wait for tunnel to close") select { case <-tunnel.Done(): - t.Log("Tunnel closed successfully") case <-time.After(wait.ForeverTestTimeout): t.Errorf("Timed out waiting for tunnel to close") } @@ -679,6 +719,33 @@ func runAgentWithID(agentID, addr string, stopCh <-chan struct{}) *agent.ClientS return client } +type unresponsiveAgent struct { + conn *grpc.ClientConn +} + +// Connect registers the unresponsive agent with the proxy server. +func (a *unresponsiveAgent) Connect(address string) error { + agentID := uuid.New().String() + conn, err := grpc.Dial(address, grpc.WithInsecure()) + if err != nil { + return err + } + ctx := metadata.AppendToOutgoingContext(context.Background(), + header.AgentID, agentID) + _, err = agentproto.NewAgentServiceClient(conn).Connect(ctx) + if err != nil { + conn.Close() + return err + } + + a.conn = conn + return nil +} + +func (a *unresponsiveAgent) Close() { + a.conn.Close() +} + // waitForConnectedServerCount waits for the agent ClientSet to have the expected number of health // server connections (HealthyClientsCount). func waitForConnectedServerCount(t *testing.T, expectedServerCount int, clientset *agent.ClientSet) {