diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 9d9738323a643..a8efd7fe880d8 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -17,9 +17,11 @@ package proxy import ( "context" "crypto/tls" + "encoding/asn1" "io" "net" "strings" + "sync/atomic" "time" "github.com/gravitational/trace" @@ -64,9 +66,6 @@ type ClientConfig struct { ProxySSHAddress string // TLSRoutingEnabled indicates if the cluster is using TLS Routing. TLSRoutingEnabled bool - // ClusterName is the name of the Teleport cluster that the client - // will be connected to. - ClusterName string // TLSConfig contains the tls.Config required for mTLS connections. TLSConfig *tls.Config // UnaryInterceptors are optional [grpc.UnaryClientInterceptor] to apply @@ -91,6 +90,8 @@ type ClientConfig struct { clientCreds func() client.Credentials } +// CheckAndSetDefaults ensures required options are present and +// sets the default value of any that are omitted. func (c *ClientConfig) CheckAndSetDefaults() error { if c.ProxyWebAddress == "" { return trace.BadParameter("missing required parameter ProxyWebAddress") @@ -98,9 +99,6 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.ProxySSHAddress == "" { return trace.BadParameter("missing required parameter ProxySSHAddress") } - if c.ClusterName == "" { - return trace.BadParameter("missing required parameter ClusterName") - } if c.SSHDialer == nil { return trace.BadParameter("missing required parameter SSHDialer") } @@ -112,16 +110,27 @@ func (c *ClientConfig) CheckAndSetDefaults() error { } if c.TLSConfig != nil { - if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) { - tlsCfg := c.TLSConfig.Clone() - tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC) - c.TLSConfig = tlsCfg - } c.clientCreds = func() client.Credentials { return client.LoadTLS(c.TLSConfig.Clone()) } c.creds = func() credentials.TransportCredentials { - return credentials.NewTLS(c.TLSConfig.Clone()) + tlsCfg := c.TLSConfig.Clone() + if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) { + tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC) + } + + // This logic still appears to be necessary to force client to always send + // a certificate regardless of the server setting. Otherwise the client may pick + // not to send the client certificate by looking at certificate request. + if len(tlsCfg.Certificates) > 0 { + cert := tlsCfg.Certificates[0] + tlsCfg.Certificates = nil + tlsCfg.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &cert, nil + } + } + + return credentials.NewTLS(tlsCfg) } } else { c.clientCreds = func() client.Credentials { @@ -135,6 +144,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error { return nil } +// insecureCredentials implements [client.Credentials] and is used by tests +// to connect to the Auth server without mTLS. type insecureCredentials struct{} func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) { @@ -164,6 +175,9 @@ type Client struct { transport *transportv1.Client // sshClient is the established SSH connection to the Proxy. sshClient *tracessh.Client + // clusterName as determined by inspecting the certificate presented by + // the Proxy during the connection handshake. + clusterName *clusterName } // protocolProxySSHGRPC is TLS ALPN protocol value used to indicate gRPC @@ -205,16 +219,84 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { return nil, trace.NewAggregate(grpcErr, sshErr) } +// clusterName stores the name of the cluster +// in a protected manner which allows it to +// be set during handshakes with the server. +type clusterName struct { + name atomic.Pointer[string] +} + +func (c *clusterName) get() string { + name := c.name.Load() + if name != nil { + return *name + } + return "" +} + +func (c *clusterName) set(name string) { + c.name.CompareAndSwap(nil, &name) +} + +// clusterCredentials is a [credentials.TransportCredentials] implementation +// that obtains the name of the cluster being connected to from the certificate +// presented by the server. This allows the client to determine the cluster name when +// connecting via using jump hosts. +type clusterCredentials struct { + credentials.TransportCredentials + clusterName *clusterName +} + +var ( + // teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding + // origin teleport cluster name into certificates. + teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7} +) + +// ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and +// then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine +// the cluster that the server belongs to. +func (c *clusterCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, info, err := c.TransportCredentials.ClientHandshake(ctx, authority, conn) + if err != nil { + return conn, info, trace.Wrap(err) + } + + tlsInfo, ok := info.(credentials.TLSInfo) + if !ok { + return conn, info, nil + } + + certs := tlsInfo.State.PeerCertificates + if len(certs) == 0 { + return conn, info, nil + } + + clientCert := certs[0] + for _, attr := range clientCert.Subject.Names { + if attr.Type.Equal(teleportClusterASN1ExtensionOID) { + val, ok := attr.Value.(string) + if ok { + c.clusterName.set(val) + break + } + } + } + + return conn, info, nil +} + // newGRPCClient creates a Client that is connected via gRPC. func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error) { dialCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout) defer cancel() + c := &clusterName{} conn, err := grpc.DialContext( dialCtx, cfg.ProxySSHAddress, append(cfg.DialOpts, - grpc.WithTransportCredentials(cfg.creds()), + grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}), grpc.WithChainUnaryInterceptor( append(cfg.UnaryInterceptors, otelgrpc.UnaryClientInterceptor(), @@ -245,25 +327,71 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error } return &Client{ - cfg: cfg, - grpcConn: conn, - transport: transport, + cfg: cfg, + grpcConn: conn, + transport: transport, + clusterName: c, }, nil } +// teleportAuthority is the extension set by the server +// which contains the name of the cluster it is in. +const teleportAuthority = "x-teleport-authority" + +// clusterCallback is a [ssh.HostKeyCallback] that obtains the name +// of the cluster being connected to from the certificate presented by the server. +// This allows the client to determine the cluster name when using jump hosts. +func clusterCallback(c *clusterName, wrapped ssh.HostKeyCallback) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := wrapped(hostname, remote, key); err != nil { + return trace.Wrap(err) + } + + cert, ok := key.(*ssh.Certificate) + if !ok { + return nil + } + + clusterName, ok := cert.Permissions.Extensions[teleportAuthority] + if ok { + c.set(clusterName) + } + + return nil + } +} + // newSSHClient creates a Client that is connected via SSH. func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) { - clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, cfg.SSHConfig) + c := &clusterName{} + clientCfg := &ssh.ClientConfig{ + User: cfg.SSHConfig.User, + Auth: cfg.SSHConfig.Auth, + HostKeyCallback: clusterCallback(c, cfg.SSHConfig.HostKeyCallback), + BannerCallback: cfg.SSHConfig.BannerCallback, + ClientVersion: cfg.SSHConfig.ClientVersion, + HostKeyAlgorithms: cfg.SSHConfig.HostKeyAlgorithms, + Timeout: cfg.SSHConfig.Timeout, + } + + clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, clientCfg) if err != nil { return nil, trace.Wrap(err) } return &Client{ - cfg: cfg, - sshClient: clt, + cfg: cfg, + sshClient: clt, + clusterName: c, }, nil } +// ClusterName returns the name of the cluster that the +// connected Proxy is a member of. +func (c *Client) ClusterName() string { + return c.clusterName.get() +} + // Close attempts to close both the gRPC and SSH connections. func (c *Client) Close() error { var errs []error @@ -486,7 +614,7 @@ func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddr // read the stderr output from the failed SSH session and append // it to the end of our own message: serverErrorMsg, _ := io.ReadAll(sessionError) - return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %v. %v", targetAddress, serverErrorMsg, err) + return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %s. %v", targetAddress, serverErrorMsg, err) } return conn, nil diff --git a/api/client/proxy/client_test.go b/api/client/proxy/client_test.go index d579c905ca3bf..ebee00737f62a 100644 --- a/api/client/proxy/client_test.go +++ b/api/client/proxy/client_test.go @@ -18,7 +18,10 @@ import ( "context" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" "encoding/pem" "errors" "fmt" @@ -36,6 +39,7 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/testing/protocmp" @@ -375,7 +379,6 @@ func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig { return ClientConfig{ ProxyWebAddress: "127.0.0.1", ProxySSHAddress: "127.0.0.1", - ClusterName: "test", SSHDialer: SSHDialerFunc(func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { conn, chans, reqs, err := f.fakeSSHServer.newClientConn() if err != nil { @@ -884,3 +887,193 @@ func TestClient_SSHConfig(t *testing.T) { require.Equal(t, user, sshConfig.User) require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback"))) } + +type fakeTransportCredentials struct { + credentials.TransportCredentials + info credentials.AuthInfo + err error +} + +type fakeAuthInfo struct{} + +func (f fakeAuthInfo) AuthType() string { + return "test" +} + +func (t fakeTransportCredentials) ClientHandshake(ctx context.Context, addr string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return conn, t.info, t.err +} + +func TestClusterCredentials(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + expectedClusterName string + credentials fakeTransportCredentials + errAssertion require.ErrorAssertionFunc + }{ + { + name: "handshake error", + credentials: fakeTransportCredentials{err: context.Canceled}, + errAssertion: require.Error, + }, + { + name: "no tls auth info", + credentials: fakeTransportCredentials{info: fakeAuthInfo{}}, + errAssertion: require.NoError, + }, + { + name: "no server cert", + credentials: fakeTransportCredentials{info: credentials.TLSInfo{}}, + errAssertion: require.NoError, + }, + { + name: "no cluster oid set", + credentials: fakeTransportCredentials{info: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 1}, + }, + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1}, + }, + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2}, + }, + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2}, + }, + }, + }, + }, + }, + }, + }}, + errAssertion: require.NoError, + }, { + name: "cluster name presented", + expectedClusterName: "test-cluster", + credentials: fakeTransportCredentials{info: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1}, + }, + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2}, + }, + { + Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2}, + }, + { + Type: teleportClusterASN1ExtensionOID, + Value: "test-cluster", + }, + }, + }, + }, + }, + }, + }}, + errAssertion: require.NoError, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + c := &clusterName{} + creds := clusterCredentials{TransportCredentials: test.credentials, clusterName: c} + _, _, err := creds.ClientHandshake(context.Background(), "127.0.0.1", nil) + test.errAssertion(t, err) + require.Equal(t, test.expectedClusterName, c.get()) + }) + } +} + +type fakePublicKey struct{} + +func (f fakePublicKey) Type() string { + return "test" +} + +func (f fakePublicKey) Marshal() []byte { + return nil +} + +func (f fakePublicKey) Verify(data []byte, sig *ssh.Signature) error { + return trace.NotImplemented("") +} + +func TestClusterCallback(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + hostKeyCB ssh.HostKeyCallback + publicKey ssh.PublicKey + expectedClusterName string + errAssertion require.ErrorAssertionFunc + }{ + { + name: "handshake failure", + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return context.Canceled + }, + errAssertion: require.Error, + }, + { + name: "invalid certificate", + publicKey: fakePublicKey{}, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + { + name: "no authority present", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + + { + name: "cluster name presented", + expectedClusterName: "test-cluster", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleportAuthority: "test-cluster", + }, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + c := &clusterName{} + err := clusterCallback(c, test.hostKeyCB)("test", addr("127.0.0.1"), test.publicKey) + test.errAssertion(t, err) + require.Equal(t, test.expectedClusterName, c.get()) + + }) + } +}