[konnectivity-client] Ensure grpc tunnel is closed on dial failure #398
Changes from all commits: c5ec7c9, 41ceca0, 866231b, 6539562, 529507c
```diff
@@ -24,6 +24,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"google.golang.org/grpc"
```
```diff
@@ -42,7 +43,7 @@ type Tunnel interface {
 }
 
 type dialResult struct {
-	err    string
+	err    *dialFailure
 	connid int64
 }
```
```diff
@@ -53,13 +54,70 @@ type pendingDial struct {
 	cancelCh <-chan struct{}
 }
 
+// TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+type pendingDialManager struct {
+	pendingDials map[int64]pendingDial
+	mutex        sync.RWMutex
+}
+
+func (p *pendingDialManager) add(dialID int64, pd pendingDial) {
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+	p.pendingDials[dialID] = pd
+}
+
+func (p *pendingDialManager) remove(dialID int64) {
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+	delete(p.pendingDials, dialID)
+}
+
+func (p *pendingDialManager) get(dialID int64) (pendingDial, bool) {
+	p.mutex.RLock()
+	defer p.mutex.RUnlock()
+	pd, ok := p.pendingDials[dialID]
+	return pd, ok
+}
+
+// TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+type connectionManager struct {
+	conns map[int64]*conn
+	mutex sync.RWMutex
+}
+
+func (cm *connectionManager) add(connID int64, c *conn) {
+	cm.mutex.Lock()
+	defer cm.mutex.Unlock()
+	cm.conns[connID] = c
+}
+
+func (cm *connectionManager) remove(connID int64) {
+	cm.mutex.Lock()
+	defer cm.mutex.Unlock()
+	delete(cm.conns, connID)
+}
+
+func (cm *connectionManager) get(connID int64) (*conn, bool) {
+	cm.mutex.RLock()
+	defer cm.mutex.RUnlock()
+	c, ok := cm.conns[connID]
+	return c, ok
+}
+
+func (cm *connectionManager) closeAll() {
+	cm.mutex.Lock()
+	defer cm.mutex.Unlock()
+	for _, conn := range cm.conns {
+		close(conn.readCh)
+	}
+}
+
 // grpcTunnel implements Tunnel
 type grpcTunnel struct {
-	stream          client.ProxyService_ProxyClient
-	pendingDial     map[int64]pendingDial
-	conns           map[int64]*conn
-	pendingDialLock sync.RWMutex
-	connsLock       sync.RWMutex
+	stream      client.ProxyService_ProxyClient
+	clientConn  clientConn
+	pendingDial pendingDialManager
+	conns       connectionManager
 
 	// The tunnel will be closed if the caller fails to read via conn.Read()
 	// more than readTimeoutSeconds after a packet has been received.
```
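The two manager types above are shape-identical, differing only in key and value types, which is what the TODO comments point at. A minimal sketch of the generic replacement they envision, assuming go1.18+ can be required (hypothetical, not part of this diff):

```go
package main

import (
	"fmt"
	"sync"
)

// syncMap is a hypothetical generic stand-in for both pendingDialManager
// and connectionManager, usable once the client can require go1.18+.
type syncMap[K comparable, V any] struct {
	m     map[K]V
	mutex sync.RWMutex
}

func newSyncMap[K comparable, V any]() *syncMap[K, V] {
	return &syncMap[K, V]{m: make(map[K]V)}
}

func (s *syncMap[K, V]) add(k K, v V) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.m[k] = v
}

func (s *syncMap[K, V]) remove(k K) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	delete(s.m, k)
}

func (s *syncMap[K, V]) get(k K) (V, bool) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	v, ok := s.m[k]
	return v, ok
}

func main() {
	conns := newSyncMap[int64, string]()
	conns.add(1, "conn-1")
	if v, ok := conns.get(1); ok {
		fmt.Println(v)
	}
}
```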
```diff
@@ -68,6 +126,11 @@ type grpcTunnel struct {
 	// The done channel is closed after the tunnel has cleaned up all connections and is no longer
 	// serving.
 	done chan struct{}
+
+	// closing is an atomic bool represented as a 0 or 1, and set to true when the tunnel is being closed.
+	// closing should only be accessed through atomic methods.
+	// TODO: switch this to an atomic.Bool once the client is exclusively built with go1.19+
+	closing uint32
 }
 
 type clientConn interface {
```
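The TODO on the closing field anticipates go1.19's typed atomics. A sketch of how the flag and its accessors could look with atomic.Bool (hypothetical; the PR keeps uint32 for compatibility with older toolchains):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// tunnelState sketches the go1.19+ variant: atomic.Bool replaces the
// uint32 flag plus the atomic.StoreUint32/atomic.LoadUint32 call pair.
type tunnelState struct {
	closing atomic.Bool // zero value is false, matching the uint32's zero value
}

func (s *tunnelState) markClosing()    { s.closing.Store(true) }
func (s *tunnelState) isClosing() bool { return s.closing.Load() }

func main() {
	var st tunnelState
	fmt.Println(st.isClosing()) // false
	st.markClosing()
	fmt.Println(st.isClosing()) // true
}
```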
```diff
@@ -106,42 +169,39 @@ func CreateSingleUseGrpcTunnelWithContext(createCtx, tunnelCtx context.Context,
 		return nil, err
 	}
 
-	tunnel := newUnstartedTunnel(stream)
+	tunnel := newUnstartedTunnel(stream, c)
 
-	go tunnel.serve(tunnelCtx, c)
+	go tunnel.serve(tunnelCtx)
 
 	return tunnel, nil
 }
 
-func newUnstartedTunnel(stream client.ProxyService_ProxyClient) *grpcTunnel {
+func newUnstartedTunnel(stream client.ProxyService_ProxyClient, c clientConn) *grpcTunnel {
 	return &grpcTunnel{
 		stream:             stream,
-		pendingDial:        make(map[int64]pendingDial),
-		conns:              make(map[int64]*conn),
+		clientConn:         c,
+		pendingDial:        pendingDialManager{pendingDials: make(map[int64]pendingDial)},
+		conns:              connectionManager{conns: make(map[int64]*conn)},
 		readTimeoutSeconds: 10,
 		done:               make(chan struct{}),
 	}
 }
 
-func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
+func (t *grpcTunnel) serve(tunnelCtx context.Context) {
 	defer func() {
-		c.Close()
+		t.clientConn.Close()
 
 		// A connection in t.conns after serve() returns means
 		// we never received a CLOSE_RSP for it, so we need to
 		// close any channels remaining for these connections.
-		t.connsLock.Lock()
-		for _, conn := range t.conns {
-			close(conn.readCh)
-		}
-		t.connsLock.Unlock()
+		t.conns.closeAll()
 
 		close(t.done)
 	}()
 
 	for {
 		pkt, err := t.stream.Recv()
-		if err == io.EOF {
+		if err == io.EOF || t.isClosing() {
 			return
 		}
 		if err != nil || pkt == nil {
```
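The net effect of the serve()/newUnstartedTunnel changes is an ownership transfer: the tunnel is now constructed with the gRPC client connection and is responsible for closing it, rather than serve() borrowing it as a parameter. A standalone sketch of that pattern (names hypothetical):

```go
package main

import (
	"context"
	"fmt"
	"io"
)

// tunnel sketches the ownership change: the tunnel holds the connection it
// was built with, and serve() closes it on exit.
type tunnel struct {
	conn io.Closer
	done chan struct{}
}

func newTunnel(c io.Closer) *tunnel {
	return &tunnel{conn: c, done: make(chan struct{})}
}

func (t *tunnel) serve(ctx context.Context) {
	defer func() {
		t.conn.Close() // the tunnel, not the caller, closes the connection
		close(t.done)
	}()
	<-ctx.Done() // stand-in for the stream.Recv() loop
}

type fakeConn struct{}

func (fakeConn) Close() error { fmt.Println("conn closed"); return nil }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	t := newTunnel(fakeConn{})
	go t.serve(ctx)
	cancel()
	<-t.done // closed only after cleanup, like grpcTunnel.done
}
```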
```diff
@@ -154,28 +214,29 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
 		switch pkt.Type {
 		case client.PacketType_DIAL_RSP:
 			resp := pkt.GetDialResponse()
-			t.pendingDialLock.RLock()
-			pendingDial, ok := t.pendingDial[resp.Random]
-			t.pendingDialLock.RUnlock()
+			pendingDial, ok := t.pendingDial.get(resp.Random)
 
 			if !ok {
+				// If the DIAL_RSP does not match a pending dial, it means one of two things:
+				// 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible)
+				// 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
+				//
+				// In either scenario, we should return here and close the tunnel as it is no longer needed.
 				klog.V(1).InfoS("DialResp not recognized; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random)
 				return
 			} else {
-				result := dialResult{
-					err:    resp.Error,
-					connid: resp.ConnectID,
-				}
+				result := dialResult{connid: resp.ConnectID}
+				if resp.Error != "" {
+					result.err = &dialFailure{resp.Error, DialFailureEndpoint}
+				}
 				select {
 				// try to send to the result channel
 				case pendingDial.resultCh <- result:
 				// unblock if the cancel channel is closed
 				case <-pendingDial.cancelCh:
-					// If there are no readers of the pending dial channel above, it means one of two things:
-					// 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible)
-					// 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
-					//
-					// In either scenario, we should return here as this tunnel is no longer needed.
+					// Note: this condition can only be hit by a race condition where the
+					// DialContext() returns early (timeout) after the pendingDial is already
+					// fetched here, but before the result is sent.
					klog.V(1).InfoS("Pending dial has been cancelled; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random)
 					return
 				case <-tunnelCtx.Done():
```

Review thread on the "1. There was a second DIAL_RSP" comment:

Contributor: I don't think we should ever get two DIAL_RSP for a pending dial. Collisions shouldn't be possible given the client sets req.Random. Maybe this could happen if the server had a race between sending the DIAL_RESP and a DIAL_CLS on something like a timeout?

Contributor: Almost the inverse of #1 below.

Author: Yeah, I was also confused by this comment. The comment was there before (moved to here), but I can't think of when this would occur. I'll update it.
```diff
@@ -189,12 +250,36 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
 				return
 			}
 
+		case client.PacketType_DIAL_CLS:
+			resp := pkt.GetCloseDial()
+			pendingDial, ok := t.pendingDial.get(resp.Random)
+
+			if !ok {
+				// If the DIAL_CLS does not match a pending dial, it means one of two things:
+				// 1. There was a DIAL_CLS received after a DIAL_RSP (unlikely but possible)
+				// 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
+				//
+				// In either scenario, we should return here and close the tunnel as it is no longer needed.
+				klog.V(1).InfoS("DIAL_CLS after dial finished", "dialID", resp.Random)
+			} else {
+				result := dialResult{
+					err: &dialFailure{"dial closed", DialFailureDialClosed},
+				}
+				select {
+				case pendingDial.resultCh <- result:
+				case <-pendingDial.cancelCh:
+					// Note: this condition can only be hit by a race condition where the
+					// DialContext() returns early (timeout) after the pendingDial is already
+					// fetched here, but before the result is sent.
+				case <-tunnelCtx.Done():
+				}
+			}
+			return // Stop serving & close the tunnel.
+
 		case client.PacketType_DATA:
 			resp := pkt.GetData()
 			// TODO: flow control
-			t.connsLock.RLock()
-			conn, ok := t.conns[resp.ConnectID]
-			t.connsLock.RUnlock()
+			conn, ok := t.conns.get(resp.ConnectID)
 
 			if ok {
 				timer := time.NewTimer((time.Duration)(t.readTimeoutSeconds) * time.Second)
```
```diff
@@ -210,19 +295,16 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
 			} else {
 				klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID)
 			}
 
 		case client.PacketType_CLOSE_RSP:
 			resp := pkt.GetCloseResponse()
-			t.connsLock.RLock()
-			conn, ok := t.conns[resp.ConnectID]
-			t.connsLock.RUnlock()
+			conn, ok := t.conns.get(resp.ConnectID)
 
 			if ok {
 				close(conn.readCh)
 				conn.closeCh <- resp.Error
 				close(conn.closeCh)
-				t.connsLock.Lock()
-				delete(t.conns, resp.ConnectID)
-				t.connsLock.Unlock()
+				t.conns.remove(resp.ConnectID)
 				return
 			}
 			klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID)
```
```diff
@@ -252,14 +334,8 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
 	// This channel MUST NOT be buffered. The sender needs to know when we are not receiving things, so they can abort.
 	resCh := make(chan dialResult)
 
-	t.pendingDialLock.Lock()
-	t.pendingDial[random] = pendingDial{resultCh: resCh, cancelCh: cancelCh}
-	t.pendingDialLock.Unlock()
-	defer func() {
-		t.pendingDialLock.Lock()
-		delete(t.pendingDial, random)
-		t.pendingDialLock.Unlock()
-	}()
+	t.pendingDial.add(random, pendingDial{resultCh: resCh, cancelCh: cancelCh})
+	defer t.pendingDial.remove(random)
 
 	req := &client.Packet{
 		Type: client.PacketType_DIAL_REQ,
```
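The "MUST NOT be buffered" comment is doing real work: with an unbuffered channel, a successful send proves a receiver was still waiting, and pairing it with a cancel channel keeps the sender from blocking forever once the receiver gives up. A standalone sketch of the handshake (names hypothetical):

```go
package main

import "fmt"

// sendResult mirrors the resultCh/cancelCh pattern: an unbuffered send only
// succeeds while a receiver is still waiting, and a closed cancel channel
// unblocks the sender once the receiver has given up.
func sendResult(resultCh chan<- string, cancelCh <-chan struct{}, v string) bool {
	select {
	case resultCh <- v: // a receiver was ready; delivery confirmed
		return true
	case <-cancelCh: // receiver gave up; don't block forever
		return false
	}
}

func main() {
	resultCh := make(chan string) // MUST be unbuffered for the handshake to work
	cancelCh := make(chan struct{})
	done := make(chan struct{})

	go func() {
		fmt.Println("got:", <-resultCh)
		close(done)
	}()

	fmt.Println("delivered:", sendResult(resultCh, cancelCh, "dial ok"))
	<-done
}
```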
```diff
@@ -280,25 +356,32 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
 
 	klog.V(5).Infoln("DIAL_REQ sent to proxy server")
 
-	c := &conn{stream: t.stream, random: random}
+	c := &conn{
+		stream:      t.stream,
+		random:      random,
+		closeTunnel: t.closeTunnel,
+	}
 
 	select {
 	case res := <-resCh:
-		if res.err != "" {
-			return nil, &dialFailure{res.err, DialFailureEndpoint}
+		if res.err != nil {
+			return nil, res.err
 		}
 		c.connID = res.connid
 		c.readCh = make(chan []byte, 10)
 		c.closeCh = make(chan string, 1)
-		t.connsLock.Lock()
-		t.conns[res.connid] = c
-		t.connsLock.Unlock()
+		t.conns.add(res.connid, c)
 	case <-time.After(30 * time.Second):
 		klog.V(5).InfoS("Timed out waiting for DialResp", "dialID", random)
+		go t.closeDial(random)
 		return nil, &dialFailure{"dial timeout, backstop", DialFailureTimeout}
 	case <-requestCtx.Done():
 		klog.V(5).InfoS("Context canceled waiting for DialResp", "ctxErr", requestCtx.Err(), "dialID", random)
+		go t.closeDial(random)
 		return nil, &dialFailure{"dial timeout, context", DialFailureContext}
+	case <-t.done:
+		klog.V(5).InfoS("Tunnel closed while waiting for DialResp", "dialID", random)
+		return nil, &dialFailure{"tunnel closed", DialFailureTunnelClosed}
 	}
 
 	return c, nil
```

Review thread on the new `case <-t.done:` branch:

Contributor: Do we really need 3 paths for "we're shutting down"?

Author: Unfortunately yes, I don't see a clean way to consolidate these. The timeout & context canceled cases could easily be collapsed, but then we wouldn't be able to distinguish between them in the logs. This …
```diff
@@ -308,6 +391,31 @@ func (t *grpcTunnel) Done() <-chan struct{} {
 	return t.done
 }
 
+// Send a best-effort DIAL_CLS request for the given dial ID.
+func (t *grpcTunnel) closeDial(dialID int64) {
+	req := &client.Packet{
+		Type: client.PacketType_DIAL_CLS,
+		Payload: &client.Packet_CloseDial{
+			CloseDial: &client.CloseDial{
+				Random: dialID,
+			},
+		},
+	}
+	if err := t.stream.Send(req); err != nil {
+		klog.V(5).InfoS("Failed to send DIAL_CLS", "err", err, "dialID", dialID)
+	}
+	t.closeTunnel()
+}
+
+func (t *grpcTunnel) closeTunnel() {
+	atomic.StoreUint32(&t.closing, 1)
+	t.clientConn.Close()
+}
+
+func (t *grpcTunnel) isClosing() bool {
+	return atomic.LoadUint32(&t.closing) != 0
+}
+
 func GetDialFailureReason(err error) (isDialFailure bool, reason DialFailureReason) {
 	var df *dialFailure
 	if errors.As(err, &df) {
```
```diff
@@ -336,4 +444,10 @@ const (
 	DialFailureContext DialFailureReason = "context"
 	// DialFailureEndpoint indicates that the konnectivity-agent was unable to reach the backend endpoint.
 	DialFailureEndpoint DialFailureReason = "endpoint"
+	// DialFailureDialClosed indicates that the client received a CloseDial response, indicating the
+	// connection was closed before the dial could complete.
+	DialFailureDialClosed DialFailureReason = "dialclosed"
+	// DialFailureTunnelClosed indicates that the client connection was closed before the dial could
+	// complete.
+	DialFailureTunnelClosed DialFailureReason = "tunnelclosed"
 )
```
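With the two new reasons, callers can branch on why a dial failed via GetDialFailureReason. A usage sketch; the import alias and retry guidance in the comments are illustrative, not prescribed by this PR:

```go
package main

import (
	"context"
	"net"

	konn "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client"
)

// dialViaTunnel sketches caller-side handling of the new failure reasons.
func dialViaTunnel(ctx context.Context, tunnel konn.Tunnel, addr string) (net.Conn, error) {
	conn, err := tunnel.DialContext(ctx, "tcp", addr)
	if err != nil {
		if isDialFailure, reason := konn.GetDialFailureReason(err); isDialFailure {
			switch reason {
			case konn.DialFailureTimeout, konn.DialFailureTunnelClosed:
				// The tunnel is done; a caller would create a fresh tunnel before retrying.
			case konn.DialFailureDialClosed:
				// The dial was closed before it could complete (new in this PR).
			case konn.DialFailureEndpoint:
				// The agent could not reach the backend endpoint; retrying may not help.
			}
		}
		return nil, err
	}
	return conn, nil
}

func main() {}
```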