diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 358e8b9a5993b..670ad255dfc4a 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -38,7 +38,6 @@ import ( "github.com/gravitational/teleport/api/defaults" transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" "github.com/gravitational/teleport/api/metadata" - "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/utils/grpc/interceptors" ) @@ -228,11 +227,9 @@ type clusterCredentials struct { clusterName *clusterName } -var ( - // teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding - // origin teleport cluster name into certificates. - teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7} -) +// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding +// origin teleport cluster name into certificates. +var teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7} // ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and // then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine @@ -399,18 +396,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) (client.Confi CircuitBreakerConfig: breaker.NoopBreakerConfig(), DialInBackground: true, Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) { - // Don't dial if the context has timed out. - select { - case <-dialCtx.Done(): - return nil, dialCtx.Err() - default: - } - - // Intentionally not using the dial context because it is only valid - // for the lifetime of the dial. Using it causes the stream to be terminated - // immediately after the dial completes. - connContext := tracing.WithPropagationContext(context.Background(), tracing.PropagationContextFromContext(dialCtx)) - conn, err := c.transport.DialCluster(connContext, cluster, nil) + conn, err := c.transport.DialCluster(dialCtx, cluster, nil) return conn, trace.Wrap(err) }), DialOpts: c.cfg.DialOpts, diff --git a/api/client/proxy/transport/transportv1/client.go b/api/client/proxy/transport/transportv1/client.go index ed2ee09a97cae..39a21874fe071 100644 --- a/api/client/proxy/transport/transportv1/client.go +++ b/api/client/proxy/transport/transportv1/client.go @@ -59,22 +59,37 @@ func (c *Client) ClusterDetails(ctx context.Context) (*transportv1pb.ClusterDeta // DialCluster establishes a connection to the provided cluster. The provided // src address will be used as the LocalAddr of the returned [net.Conn]. func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) (net.Conn, error) { - stream, err := c.clt.ProxyCluster(ctx) + // we do this rather than using context.Background to inherit any OTEL data + // from the dial context + connCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) + stop := context.AfterFunc(ctx, cancel) + defer stop() + + stream, err := c.clt.ProxyCluster(connCtx) if err != nil { + cancel() return nil, trace.Wrap(err, "unable to establish proxy stream") } if err := stream.Send(&transportv1pb.ProxyClusterRequest{Cluster: cluster}); err != nil { + cancel() return nil, trace.Wrap(err, "failed to send cluster request") } - streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream}) + if !stop() { + cancel() + return nil, trace.Wrap(connCtx.Err(), "unable to establish proxy stream") + } + + streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream, cancel: cancel}) if err != nil { + cancel() return nil, trace.Wrap(err, "unable to create stream reader") } p, ok := peer.FromContext(stream.Context()) if !ok { + streamRW.Close() return nil, trace.BadParameter("unable to retrieve peer information") } @@ -85,6 +100,7 @@ func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) // for a [transportv1pb.TransportService_ProxyClusterClient]. type clusterStream struct { stream transportv1pb.TransportService_ProxyClusterClient + cancel context.CancelFunc } func (c clusterStream) Recv() ([]byte, error) { @@ -105,7 +121,10 @@ func (c clusterStream) Send(frame []byte) error { } func (c clusterStream) Close() error { - return trace.Wrap(c.stream.CloseSend()) + if c.cancel != nil { + c.cancel() + } + return nil } // DialHost establishes a connection to the instance in the provided cluster that matches diff --git a/api/utils/grpc/stream/stream.go b/api/utils/grpc/stream/stream.go index b1af771379313..7fe4694da7954 100644 --- a/api/utils/grpc/stream/stream.go +++ b/api/utils/grpc/stream/stream.go @@ -128,8 +128,6 @@ func (c *ReadWriter) Write(b []byte) (int, error) { // Close cleans up resources used by the stream. func (c *ReadWriter) Close() error { if cs, ok := c.source.(io.Closer); ok { - c.wLock.Lock() - defer c.wLock.Unlock() return trace.Wrap(cs.Close()) } diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index 3fe4262c9ab89..a9c3463258164 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -157,12 +157,57 @@ func (c *ClientConfig) checkAndSetDefaults() error { // clientConn hold info about a dialed grpc connection type clientConn struct { *grpc.ClientConn - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup id string addr string + + // if closing is set, count is not allowed to increase from zero; upon + // reaching zero, cond should be broadcast + mu sync.Mutex + cond sync.Cond + closing bool + count int +} + +func (c *clientConn) maybeAcquire() (release func()) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closing && c.count < 1 { + return nil + } + c.count++ + + return sync.OnceFunc(func() { + c.mu.Lock() + defer c.mu.Unlock() + c.count-- + if c.count == 0 { + c.cond.Broadcast() + } + }) +} + +// Shutdown closes the clientConn after all connections through it are closed, +// or after the context is done. +func (c *clientConn) Shutdown(ctx context.Context) { + defer c.Close() + + c.mu.Lock() + defer c.mu.Unlock() + + c.closing = true + if c.count == 0 { + return + } + + if c.cond.L == nil { + c.cond.L = &c.mu + } + defer context.AfterFunc(ctx, c.cond.Broadcast)() + for c.count > 0 && ctx.Err() == nil { + c.cond.Wait() + } } // Client is a peer proxy service client using grpc and tls. @@ -335,7 +380,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { for _, id := range toDelete { if conn, ok := c.conns[id]; ok { - go c.shutdownConn(conn) + go conn.Shutdown(c.ctx) } } c.conns = toKeep @@ -367,8 +412,9 @@ func (c *Client) DialNode( return nil, trace.ConnectionProblem(err, "error dialing peer proxies %s: %v", proxyIDs, err) } - streamRW, err := streamutils.NewReadWriter(frameStream{stream: stream}) + streamRW, err := streamutils.NewReadWriter(stream) if err != nil { + _ = stream.Close() return nil, trace.Wrap(err) } @@ -385,6 +431,7 @@ type stream interface { // frameStream implements [streamutils.Source]. type frameStream struct { stream stream + cancel context.CancelFunc } func (s frameStream) Send(p []byte) error { @@ -405,15 +452,14 @@ func (s frameStream) Recv() ([]byte, error) { } func (s frameStream) Close() error { - if cs, ok := s.stream.(grpc.ClientStream); ok { - return trace.Wrap(cs.CloseSend()) + if s.cancel != nil { + s.cancel() } - return nil } // Shutdown gracefully shuts down all existing client connections. -func (c *Client) Shutdown() { +func (c *Client) Shutdown(ctx context.Context) { c.Lock() defer c.Unlock() @@ -422,23 +468,7 @@ func (c *Client) Shutdown() { wg.Add(1) go func(conn *clientConn) { defer wg.Done() - - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.config.GracefulShutdownTimeout) - defer cancel() - - go func() { - if err := c.shutdownConn(conn); err != nil { - c.config.Log.Infof("proxy peer connection %+v graceful shutdown error: %+v", conn.id, err) - } - }() - - select { - case <-conn.ctx.Done(): - case <-timeoutCtx.Done(): - if err := c.stopConn(conn); err != nil { - c.config.Log.Infof("proxy peer connection %+v close error: %+v", conn.id, err) - } - } + conn.Shutdown(ctx) }(conn) } wg.Wait() @@ -452,7 +482,7 @@ func (c *Client) Stop() error { var errs []error for _, conn := range c.conns { - if err := c.stopConn(conn); err != nil { + if err := conn.Close(); err != nil { errs = append(errs, err) } } @@ -466,67 +496,67 @@ func (c *Client) GetConnectionsCount() int { return len(c.conns) } -// shutdownConn gracefully shuts down a clientConn -// by waiting for open streams to finish. -func (c *Client) shutdownConn(conn *clientConn) error { - conn.wg.Wait() // wait for streams to gracefully end - conn.cancel() - return conn.Close() -} - -// stopConn immediately closes a clientConn -func (c *Client) stopConn(conn *clientConn) error { - conn.cancel() - return conn.Close() -} - // dial opens a new stream to one of the supplied proxy ids. // it tries to find an existing grpc.ClientConn or initializes a new rpc // to one of the proxies otherwise. // The boolean returned in the second argument is intended for testing purposes, // to indicates whether the connection was cached or newly established. -func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (clientapi.ProxyService_DialNodeClient, bool, error) { +func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (frameStream, bool, error) { conns, existing, err := c.getConnections(proxyIDs) if err != nil { - return nil, existing, trace.Wrap(err) + return frameStream{}, existing, trace.Wrap(err) } var errs []error for _, conn := range conns { - stream, err := c.startStream(conn) + release := conn.maybeAcquire() + if release == nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + errs = append(errs, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down")) + continue + } + + ctx, cancel := context.WithCancel(context.Background()) + context.AfterFunc(ctx, release) + + stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) if err != nil { + cancel() c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) c.config.Log.Debugf("Error opening tunnel rpc to proxy %+v at %+v", conn.id, conn.addr) errs = append(errs, trace.ConnectionProblem(err, "error starting stream: %v", err)) continue } + err = stream.Send(&clientapi.Frame{ Message: &clientapi.Frame_DialRequest{ DialRequest: dialRequest, }, }) if err != nil { + cancel() errs = append(errs, trace.ConnectionProblem(err, "error sending dial frame: %v", err)) continue } msg, err := stream.Recv() if err != nil { + cancel() errs = append(errs, trace.ConnectionProblem(err, "error receiving dial response: %v", err)) continue } if msg.GetConnectionEstablished() == nil { - err := stream.CloseSend() - if err != nil { - c.config.Log.Debugf("error closing stream: %w", err) - } + cancel() errs = append(errs, trace.ConnectionProblem(nil, "received malformed connection established frame")) continue } - return stream, existing, nil + return frameStream{ + stream: stream, + cancel: cancel, + }, existing, nil } - return nil, existing, trace.NewAggregate(errs...) + return frameStream{}, existing, trace.NewAggregate(errs...) } // getConnections returns connections to the supplied proxy ids. @@ -611,17 +641,13 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { return nil, trace.Wrap(err, "Error updating client tls config") } - connCtx, cancel := context.WithCancel(c.ctx) - wg := new(sync.WaitGroup) - expectedPeer := auth.HostFQDN(peerID, c.config.ClusterName) - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.Dial( peerAddr, grpc.WithTransportCredentials(newClientCredentials(expectedPeer, peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))), grpc.WithStatsHandler(newStatsHandler(c.reporter)), - grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor, streamCounterInterceptor(wg)), + grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: peerKeepAlive, Timeout: peerTimeout, @@ -630,28 +656,13 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`), ) if err != nil { - cancel() return nil, trace.Wrap(err, "Error dialing proxy %q", peerID) } return &clientConn{ ClientConn: conn, - ctx: connCtx, - cancel: cancel, - wg: wg, - id: peerID, - addr: peerAddr, - }, nil -} -// startStream opens a new stream to the provided connection. -func (c *Client) startStream(conn *clientConn) (clientapi.ProxyService_DialNodeClient, error) { - client := clientapi.NewProxyServiceClient(conn.ClientConn) - - stream, err := client.DialNode(conn.ctx) - if err != nil { - return nil, trace.Wrap(err, "Error opening stream to proxy %+v", conn.id) - } - - return stream, nil + id: peerID, + addr: peerAddr, + }, nil } diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index 93e6931fec6ae..017eeb8b3f4dd 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -19,6 +19,7 @@ package peer import ( + "context" "crypto/tls" "crypto/x509" "testing" @@ -29,6 +30,7 @@ import ( "google.golang.org/grpc/connectivity" "github.com/gravitational/teleport/api/client/proto" + clientapi "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" ) @@ -49,22 +51,22 @@ func TestClientConn(t *testing.T) { stream, cached, err := client.dial([]string{"s1"}, &proto.DialRequest{}) require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream) - stream.CloseSend() + require.NotNil(t, stream.stream) + stream.Close() // dial second server stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream) - stream.CloseSend() + require.NotNil(t, stream.stream) + stream.Close() // redial second server stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream) - stream.CloseSend() + require.NotNil(t, stream.stream) + stream.Close() // close second server // and attempt to redial it @@ -72,7 +74,7 @@ func TestClientConn(t *testing.T) { stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) require.Error(t, err) require.True(t, cached) - require.Nil(t, stream) + require.Nil(t, stream.stream) } // TestClientUpdate checks the client's watcher update behavior @@ -92,10 +94,10 @@ func TestClientUpdate(t *testing.T) { s1, _, err := client.dial([]string{"s1"}, &proto.DialRequest{}) require.NoError(t, err) - require.NotNil(t, s1) + require.NotNil(t, s1.stream) s2, _, err := client.dial([]string{"s2"}, &proto.DialRequest{}) require.NoError(t, err) - require.NotNil(t, s2) + require.NotNil(t, s2.stream) // watcher finds one of the two servers err = client.updateConnections([]types.Server{def1}) @@ -105,7 +107,7 @@ func TestClientUpdate(t *testing.T) { sendMsg(t, s1) // stream is not broken across updates sendMsg(t, s2) // stream is not forcefully closed. ClientConn waits for a graceful shutdown before it closes. - s2.CloseSend() + s2.Close() // watcher finds two servers with one broken connection server2.Shutdown() @@ -128,8 +130,8 @@ func TestClientUpdate(t *testing.T) { require.NoError(t, err) require.NotNil(t, s3) - s1.CloseSend() - s3.CloseSend() + s1.Close() + s3.Close() } func TestCAChange(t *testing.T) { @@ -143,10 +145,11 @@ func TestCAChange(t *testing.T) { conn, err := client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) - stream, err := client.startStream(conn) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) - stream.CloseSend() // rotate server ca require.NoError(t, server.Close()) @@ -158,7 +161,7 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = client.startStream(conn) + stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) require.Error(t, err) require.Nil(t, stream) @@ -175,10 +178,9 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = client.startStream(conn) + stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) - stream.CloseSend() } func TestBackupClient(t *testing.T) { diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index ce8eb17294946..ef846a1d95a11 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -206,9 +206,7 @@ func setupClient(t *testing.T, clientCA, serverCA *tlsca.CertAuthority, role typ }) require.NoError(t, err) - t.Cleanup(func() { - client.Shutdown() - }) + t.Cleanup(func() { client.Stop() }) return client } @@ -264,11 +262,7 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor return server, ts } -func sendMsg(t *testing.T, stream clientapi.ProxyService_DialNodeClient) { - err := stream.Send(&clientapi.Frame{ - Message: &clientapi.Frame_Data{ - Data: &clientapi.Data{Bytes: []byte("ping")}, - }, - }) +func sendMsg(t *testing.T, stream frameStream) { + err := stream.Send([]byte("ping")) require.NoError(t, err) } diff --git a/lib/proxy/peer/interceptor.go b/lib/proxy/peer/interceptor.go deleted file mode 100644 index 0cbfc6d4254da..0000000000000 --- a/lib/proxy/peer/interceptor.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package peer - -import ( - "context" - "sync" - - "github.com/gravitational/trace/trail" - "google.golang.org/grpc" -) - -// streamWrapper wraps around the embedded grpc.ClientStream -// and intercepts the RecvMsg method calls decreading the number of -// streams counter. -type streamWrapper struct { - grpc.ClientStream - wg *sync.WaitGroup - once sync.Once -} - -func (s *streamWrapper) CloseSend() error { - err := s.ClientStream.CloseSend() - s.decreaseCounter() - return err -} - -func (s *streamWrapper) SendMsg(m interface{}) error { - err := s.ClientStream.SendMsg(m) - if err != nil { - s.decreaseCounter() - } - return err -} - -func (s *streamWrapper) RecvMsg(m interface{}) error { - err := s.ClientStream.RecvMsg(m) - if err != nil { - s.decreaseCounter() - } - return err -} - -func (s *streamWrapper) decreaseCounter() { - s.once.Do(func() { - s.wg.Done() - }) -} - -// streamCounterInterceptor is gRPC client stream interceptor that -// counts the number of current open streams for the purpose of -// gracefully shutdown a draining gRPC client. -func streamCounterInterceptor(wg *sync.WaitGroup) grpc.StreamClientInterceptor { - return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - s, err := streamer(ctx, desc, cc, method, opts...) - if err != nil { - return nil, trail.ToGRPC(err) - } - wg.Add(1) - return &streamWrapper{ - ClientStream: s, - wg: wg, - }, nil - } -} diff --git a/lib/proxy/peer/server_test.go b/lib/proxy/peer/server_test.go index b8d634439e936..0f88b332c4864 100644 --- a/lib/proxy/peer/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -41,7 +41,7 @@ func TestServerTLS(t *testing.T) { stream, _, err := client1.dial([]string{"s1"}, &proto.DialRequest{}) require.NoError(t, err) require.NotNil(t, stream) - stream.CloseSend() + stream.Close() // trusted certificates with incorrect server role. client2 := setupClient(t, ca1, ca1, types.RoleNode) @@ -59,5 +59,5 @@ func TestServerTLS(t *testing.T) { stream, _, err = client3.dial([]string{"s3"}, &proto.DialRequest{}) require.NoError(t, err) require.NotNil(t, stream) - stream.CloseSend() + stream.Close() } diff --git a/lib/service/service.go b/lib/service/service.go index 8f7899a02cb58..0d33fc47bb79b 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4865,7 +4865,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { warnOnErr(ctx, proxyServer.Shutdown(), logger) } if peerClient != nil { - peerClient.Shutdown() + peerClient.Shutdown(ctx) } if kubeServer != nil { warnOnErr(ctx, kubeServer.Shutdown(ctx), logger)