Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/gravitational/teleport/api/defaults"
transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/utils/grpc/interceptors"
)

Expand Down Expand Up @@ -228,11 +227,9 @@ type clusterCredentials struct {
clusterName *clusterName
}

var (
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}
)
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
var teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}

// ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and
// then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine
Expand Down Expand Up @@ -399,18 +396,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) (client.Confi
CircuitBreakerConfig: breaker.NoopBreakerConfig(),
DialInBackground: true,
Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) {
// Don't dial if the context has timed out.
select {
case <-dialCtx.Done():
return nil, dialCtx.Err()
default:
}

// Intentionally not using the dial context because it is only valid
// for the lifetime of the dial. Using it causes the stream to be terminated
// immediately after the dial completes.
connContext := tracing.WithPropagationContext(context.Background(), tracing.PropagationContextFromContext(dialCtx))
conn, err := c.transport.DialCluster(connContext, cluster, nil)
conn, err := c.transport.DialCluster(dialCtx, cluster, nil)
return conn, trace.Wrap(err)
}),
DialOpts: c.cfg.DialOpts,
Expand Down
25 changes: 22 additions & 3 deletions api/client/proxy/transport/transportv1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,37 @@ func (c *Client) ClusterDetails(ctx context.Context) (*transportv1pb.ClusterDeta
// DialCluster establishes a connection to the provided cluster. The provided
// src address will be used as the LocalAddr of the returned [net.Conn].
func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) (net.Conn, error) {
stream, err := c.clt.ProxyCluster(ctx)
// we do this rather than using context.Background to inherit any OTEL data
// from the dial context
connCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
stop := context.AfterFunc(ctx, cancel)
defer stop()

stream, err := c.clt.ProxyCluster(connCtx)
if err != nil {
cancel()
return nil, trace.Wrap(err, "unable to establish proxy stream")
}

if err := stream.Send(&transportv1pb.ProxyClusterRequest{Cluster: cluster}); err != nil {
cancel()
return nil, trace.Wrap(err, "failed to send cluster request")
}

streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream})
if !stop() {
cancel()
return nil, trace.Wrap(connCtx.Err(), "unable to establish proxy stream")
}

streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream, cancel: cancel})
if err != nil {
cancel()
return nil, trace.Wrap(err, "unable to create stream reader")
}

p, ok := peer.FromContext(stream.Context())
if !ok {
streamRW.Close()
return nil, trace.BadParameter("unable to retrieve peer information")
}

Expand All @@ -85,6 +100,7 @@ func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr)
// for a [transportv1pb.TransportService_ProxyClusterClient].
// clusterStream adapts a [transportv1pb.TransportService_ProxyClusterClient]
// gRPC stream to the streamutils.Source interface used by NewReadWriter.
type clusterStream struct {
	// stream is the underlying bidirectional ProxyCluster gRPC stream.
	stream transportv1pb.TransportService_ProxyClusterClient
	// cancel tears down the context backing the stream; invoked by Close.
	// May be nil, in which case Close is a no-op.
	cancel context.CancelFunc
}

func (c clusterStream) Recv() ([]byte, error) {
Expand All @@ -105,7 +121,10 @@ func (c clusterStream) Send(frame []byte) error {
}

func (c clusterStream) Close() error {
return trace.Wrap(c.stream.CloseSend())
if c.cancel != nil {
c.cancel()
}
return nil
}

// DialHost establishes a connection to the instance in the provided cluster that matches
Expand Down
2 changes: 0 additions & 2 deletions api/utils/grpc/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ func (c *ReadWriter) Write(b []byte) (int, error) {
// Close cleans up resources used by the stream.
func (c *ReadWriter) Close() error {
if cs, ok := c.source.(io.Closer); ok {
c.wLock.Lock()
defer c.wLock.Unlock()
return trace.Wrap(cs.Close())
}

Expand Down
159 changes: 85 additions & 74 deletions lib/proxy/peer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,57 @@ func (c *ClientConfig) checkAndSetDefaults() error {
// clientConn hold info about a dialed grpc connection
type clientConn struct {
*grpc.ClientConn
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup

id string
addr string

// if closing is set, count is not allowed to increase from zero; upon
// reaching zero, cond should be broadcast
mu sync.Mutex
cond sync.Cond
closing bool
count int
}

// maybeAcquire registers a new user of the connection and returns a release
// function that the caller must invoke exactly when done. It returns nil if
// the connection is shutting down with no remaining users, in which case the
// caller must not use the connection.
//
// Note the asymmetry: once closing is set, the count may not grow from zero,
// but it may still grow while other users are holding the connection.
func (c *clientConn) maybeAcquire() (release func()) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closing && c.count < 1 {
		// Shutting down and idle: refuse new acquisitions.
		return nil
	}
	c.count++

	// sync.OnceFunc makes release idempotent, so calling it more than once
	// cannot drive the count negative.
	return sync.OnceFunc(func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.count--
		if c.count == 0 {
			// Wake Shutdown, which may be waiting for the last user to leave.
			c.cond.Broadcast()
		}
	})
}

// Shutdown closes the clientConn after all connections through it are closed,
// or after the context is done.
func (c *clientConn) Shutdown(ctx context.Context) {
	// Deferred first, so it runs last (after the Unlock below): the
	// underlying ClientConn is closed only once the wait has finished.
	defer c.Close()

	c.mu.Lock()
	defer c.mu.Unlock()

	// Forbid the user count from growing from zero again (see maybeAcquire).
	c.closing = true
	if c.count == 0 {
		return
	}

	// Lazily bind the condition variable to the mutex; maybeAcquire only
	// ever Broadcasts, which is safe even before L is set.
	if c.cond.L == nil {
		c.cond.L = &c.mu
	}
	// Wake the wait loop when ctx is done so we never block forever; the
	// returned stop function is deferred to release the callback afterwards.
	defer context.AfterFunc(ctx, c.cond.Broadcast)()
	for c.count > 0 && ctx.Err() == nil {
		c.cond.Wait()
	}
}

// Client is a peer proxy service client using grpc and tls.
Expand Down Expand Up @@ -335,7 +380,7 @@ func (c *Client) updateConnections(proxies []types.Server) error {

for _, id := range toDelete {
if conn, ok := c.conns[id]; ok {
go c.shutdownConn(conn)
go conn.Shutdown(c.ctx)
}
}
c.conns = toKeep
Expand Down Expand Up @@ -367,8 +412,9 @@ func (c *Client) DialNode(
return nil, trace.ConnectionProblem(err, "error dialing peer proxies %s: %v", proxyIDs, err)
}

streamRW, err := streamutils.NewReadWriter(frameStream{stream: stream})
streamRW, err := streamutils.NewReadWriter(stream)
if err != nil {
_ = stream.Close()
return nil, trace.Wrap(err)
}

Expand All @@ -385,6 +431,7 @@ type stream interface {
// frameStream implements [streamutils.Source].
type frameStream struct {
stream stream
cancel context.CancelFunc
}

func (s frameStream) Send(p []byte) error {
Expand All @@ -405,15 +452,14 @@ func (s frameStream) Recv() ([]byte, error) {
}

func (s frameStream) Close() error {
if cs, ok := s.stream.(grpc.ClientStream); ok {
return trace.Wrap(cs.CloseSend())
if s.cancel != nil {
s.cancel()
}

return nil
}

// Shutdown gracefully shuts down all existing client connections.
func (c *Client) Shutdown() {
func (c *Client) Shutdown(ctx context.Context) {
c.Lock()
defer c.Unlock()

Expand All @@ -422,23 +468,7 @@ func (c *Client) Shutdown() {
wg.Add(1)
go func(conn *clientConn) {
defer wg.Done()

timeoutCtx, cancel := context.WithTimeout(context.Background(), c.config.GracefulShutdownTimeout)
defer cancel()

go func() {
if err := c.shutdownConn(conn); err != nil {
c.config.Log.Infof("proxy peer connection %+v graceful shutdown error: %+v", conn.id, err)
}
}()

select {
case <-conn.ctx.Done():
case <-timeoutCtx.Done():
if err := c.stopConn(conn); err != nil {
c.config.Log.Infof("proxy peer connection %+v close error: %+v", conn.id, err)
}
}
conn.Shutdown(ctx)
}(conn)
}
wg.Wait()
Expand All @@ -452,7 +482,7 @@ func (c *Client) Stop() error {

var errs []error
for _, conn := range c.conns {
if err := c.stopConn(conn); err != nil {
if err := conn.Close(); err != nil {
errs = append(errs, err)
}
}
Expand All @@ -466,67 +496,67 @@ func (c *Client) GetConnectionsCount() int {
return len(c.conns)
}

// shutdownConn gracefully shuts down a clientConn
// by waiting for open streams to finish.
func (c *Client) shutdownConn(conn *clientConn) error {
conn.wg.Wait() // wait for streams to gracefully end
conn.cancel()
return conn.Close()
}

// stopConn immediately closes a clientConn
func (c *Client) stopConn(conn *clientConn) error {
conn.cancel()
return conn.Close()
}

// dial opens a new stream to one of the supplied proxy ids.
// it tries to find an existing grpc.ClientConn or initializes a new rpc
// to one of the proxies otherwise.
// The boolean returned in the second argument is intended for testing purposes,
// to indicate whether the connection was cached or newly established.
func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (clientapi.ProxyService_DialNodeClient, bool, error) {
func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (frameStream, bool, error) {
conns, existing, err := c.getConnections(proxyIDs)
if err != nil {
return nil, existing, trace.Wrap(err)
return frameStream{}, existing, trace.Wrap(err)
}

var errs []error
for _, conn := range conns {
stream, err := c.startStream(conn)
release := conn.maybeAcquire()
if release == nil {
c.metrics.reportTunnelError(errorProxyPeerTunnelRPC)
errs = append(errs, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down"))
continue
}

ctx, cancel := context.WithCancel(context.Background())
context.AfterFunc(ctx, release)

stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx)
if err != nil {
cancel()
c.metrics.reportTunnelError(errorProxyPeerTunnelRPC)
c.config.Log.Debugf("Error opening tunnel rpc to proxy %+v at %+v", conn.id, conn.addr)
errs = append(errs, trace.ConnectionProblem(err, "error starting stream: %v", err))
continue
}

err = stream.Send(&clientapi.Frame{
Message: &clientapi.Frame_DialRequest{
DialRequest: dialRequest,
},
})
if err != nil {
cancel()
errs = append(errs, trace.ConnectionProblem(err, "error sending dial frame: %v", err))
continue
}
msg, err := stream.Recv()
if err != nil {
cancel()
errs = append(errs, trace.ConnectionProblem(err, "error receiving dial response: %v", err))
continue
}
if msg.GetConnectionEstablished() == nil {
err := stream.CloseSend()
if err != nil {
c.config.Log.Debugf("error closing stream: %w", err)
}
cancel()
errs = append(errs, trace.ConnectionProblem(nil, "received malformed connection established frame"))
continue
}

return stream, existing, nil
return frameStream{
stream: stream,
cancel: cancel,
}, existing, nil
}

return nil, existing, trace.NewAggregate(errs...)
return frameStream{}, existing, trace.NewAggregate(errs...)
}

// getConnections returns connections to the supplied proxy ids.
Expand Down Expand Up @@ -611,17 +641,13 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) {
return nil, trace.Wrap(err, "Error updating client tls config")
}

connCtx, cancel := context.WithCancel(c.ctx)
wg := new(sync.WaitGroup)

expectedPeer := auth.HostFQDN(peerID, c.config.ClusterName)

conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.Dial(
peerAddr,
grpc.WithTransportCredentials(newClientCredentials(expectedPeer, peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))),
grpc.WithStatsHandler(newStatsHandler(c.reporter)),
grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor, streamCounterInterceptor(wg)),
grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: peerKeepAlive,
Timeout: peerTimeout,
Expand All @@ -630,28 +656,13 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) {
grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`),
)
if err != nil {
cancel()
return nil, trace.Wrap(err, "Error dialing proxy %q", peerID)
}

return &clientConn{
ClientConn: conn,
ctx: connCtx,
cancel: cancel,
wg: wg,
id: peerID,
addr: peerAddr,
}, nil
}

// startStream opens a new stream to the provided connection.
func (c *Client) startStream(conn *clientConn) (clientapi.ProxyService_DialNodeClient, error) {
client := clientapi.NewProxyServiceClient(conn.ClientConn)

stream, err := client.DialNode(conn.ctx)
if err != nil {
return nil, trace.Wrap(err, "Error opening stream to proxy %+v", conn.id)
}

return stream, nil
id: peerID,
addr: peerAddr,
}, nil
}
Loading