Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (
"github.com/gravitational/teleport/api/defaults"
transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/observability/tracing"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/utils/grpc/interceptors"
)
Expand Down Expand Up @@ -252,11 +251,9 @@ type clusterCredentials struct {
clusterName *clusterName
}

var (
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}
)
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
var teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}

// ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and
// then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine
Expand Down Expand Up @@ -492,18 +489,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config
CircuitBreakerConfig: breaker.NoopBreakerConfig(),
DialInBackground: true,
Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) {
// Don't dial if the context has timed out.
select {
case <-dialCtx.Done():
return nil, dialCtx.Err()
default:
}

// Intentionally not using the dial context because it is only valid
// for the lifetime of the dial. Using it causes the stream to be terminated
// immediately after the dial completes.
connContext := tracing.WithPropagationContext(context.Background(), tracing.PropagationContextFromContext(dialCtx))
conn, err := c.transport.DialCluster(connContext, cluster, nil)
conn, err := c.transport.DialCluster(dialCtx, cluster, nil)
return conn, trace.Wrap(err)
}),
DialOpts: c.cfg.DialOpts,
Expand Down
46 changes: 43 additions & 3 deletions api/client/proxy/transport/transportv1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"google.golang.org/grpc/peer"

transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1"
"github.com/gravitational/teleport/api/internal/context121"
streamutils "github.com/gravitational/teleport/api/utils/grpc/stream"
)

Expand Down Expand Up @@ -59,22 +60,57 @@ func (c *Client) ClusterDetails(ctx context.Context) (*transportv1pb.ClusterDeta
// DialCluster establishes a connection to the provided cluster. The provided
// src address will be used as the LocalAddr of the returned [net.Conn].
func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) (net.Conn, error) {
stream, err := c.clt.ProxyCluster(ctx)
// we do this rather than using context.Background to inherit any OTEL data
// from the dial context
connCtx, cancel := context.WithCancel(context121.WithoutCancel(ctx))

// a rough replacement for `stop := context.AfterFunc(ctx, cancel)` in 1.20
stopped := make(chan bool, 1)
stopper := make(chan struct{})
var stopOnce sync.Once
// there's little reason to optimize for ctx.Done() == nil since this is
// always called with a cancellable context in all current code
go func() {
select {
case <-ctx.Done():
close(stopped)
cancel()
case <-stopper:
stopped <- true
close(stopped)
}
}()
stop := func() bool {
stopOnce.Do(func() { close(stopper) })
return <-stopped
}
defer stop()

stream, err := c.clt.ProxyCluster(connCtx)
if err != nil {
cancel()
return nil, trace.Wrap(err, "unable to establish proxy stream")
}

if err := stream.Send(&transportv1pb.ProxyClusterRequest{Cluster: cluster}); err != nil {
cancel()
return nil, trace.Wrap(err, "failed to send cluster request")
}

streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream})
if !stop() {
cancel()
return nil, trace.Wrap(connCtx.Err(), "unable to establish proxy stream")
}

streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream, cancel: cancel})
if err != nil {
cancel()
return nil, trace.Wrap(err, "unable to create stream reader")
}

p, ok := peer.FromContext(stream.Context())
if !ok {
streamRW.Close()
return nil, trace.BadParameter("unable to retrieve peer information")
}

Expand All @@ -85,6 +121,7 @@ func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr)
// for a [transportv1pb.TransportService_ProxyClusterClient].
type clusterStream struct {
stream transportv1pb.TransportService_ProxyClusterClient
cancel context.CancelFunc
}

func (c clusterStream) Recv() ([]byte, error) {
Expand All @@ -105,7 +142,10 @@ func (c clusterStream) Send(frame []byte) error {
}

func (c clusterStream) Close() error {
return trace.Wrap(c.stream.CloseSend())
if c.cancel != nil {
c.cancel()
}
return nil
}

// DialHost establishes a connection to the instance in the provided cluster that matches
Expand Down
2 changes: 0 additions & 2 deletions api/utils/grpc/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ func (c *ReadWriter) Write(b []byte) (int, error) {
// Close cleans up resources used by the stream.
func (c *ReadWriter) Close() error {
if cs, ok := c.source.(io.Closer); ok {
c.wLock.Lock()
defer c.wLock.Unlock()
return trace.Wrap(cs.Close())
}

Expand Down
Loading