From eaa4ffcf9cb3b1084e3a0489bc867663bf4f9f7f Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 15 Mar 2023 09:41:56 -0400 Subject: [PATCH 1/2] Add a dedicated client to communicate with the Proxy SSH server (#22629) A new `api/client/proxy/Client` has been added to interact with the SSH and gRPC servers that the Proxy serves on its SSH port. The client will first try connecting to the gRPC server and if that fails it will fall back to the SSH server. Much of the SSH functionality mimics the existing behavior of the `ProxyClient` in `lib/client`. This is the first part of phasing out that client in favor of the new client. There will be a follow up PR that migrates `lib/client` to make use of the new client instead. Part of #19812 --- api/client/contextdialer.go | 4 +- api/client/proxy/client.go | 329 +++++++++++++ api/client/proxy/client_test.go | 543 ++++++++++++++++++++++ api/client/proxy/session_conn.go | 102 ++++ api/client/webclient/webclient.go | 3 +- api/{client/proxy => utils}/proxy.go | 2 +- api/{client/proxy => utils}/proxy_test.go | 2 +- lib/client/https_client.go | 3 +- lib/utils/proxy/proxy.go | 4 +- tool/tsh/proxy.go | 4 +- tool/tsh/resolve_default_addr.go | 4 +- 11 files changed, 986 insertions(+), 14 deletions(-) create mode 100644 api/client/proxy/client.go create mode 100644 api/client/proxy/client_test.go create mode 100644 api/client/proxy/session_conn.go rename api/{client/proxy => utils}/proxy.go (99%) rename api/{client/proxy => utils}/proxy_test.go (99%) diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index ba9e07813c20c..bb53ec64500f1 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -26,11 +26,11 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" ) @@ -81,7 +81,7 @@ func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration) ContextDialer { return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := newDirectDialer(keepAlivePeriod, dialTimeout) - if proxyURL := proxy.GetProxyURL(addr); proxyURL != nil { + if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { return DialProxyWithDialer(ctx, proxyURL, addr, dialer) } return dialer.DialContext(ctx, network, addr) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go new file mode 100644 index 0000000000000..9d2b97cf13682 --- /dev/null +++ b/api/client/proxy/client.go @@ -0,0 +1,329 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "crypto/tls" + "io" + "net" + "strings" + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/gravitational/teleport/api/breaker" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/defaults" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" +) + +// SSHDialer provides a mechanism to create a ssh client. +type SSHDialer interface { + // Dial establishes a client connection to an SSH server. + Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) +} + +// SSHDialerFunc implements SSHDialer +type SSHDialerFunc func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) + +// Dial calls f(ctx, network, addr, config). +func (f SSHDialerFunc) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { + return f(ctx, network, addr, config) +} + +// ClientConfig contains configuration needed for a Client +// to be able to connect to the cluster. +type ClientConfig struct { + // ProxyWebAddress is the address of the Proxy Web server. + ProxyWebAddress string + // ProxySSHAddress is the address of the Proxy SSH server. + ProxySSHAddress string + // TLSRoutingEnabled indicates if the cluster is using TLS Routing. + TLSRoutingEnabled bool + // ClusterName is the name of the Teleport cluster that the client + // will be connected to. + ClusterName string + // TLSConfig contains the tls.Config required for mTLS connections. + TLSConfig *tls.Config + // SSHDialer allows callers to control how a [tracessh.Client] is created. + SSHDialer SSHDialer + // SSHConfig is the [ssh.ClientConfig] used to connect to the Proxy SSH server. + SSHConfig *ssh.ClientConfig + // DialTimeout defines how long to attempt dialing before timing out. + DialTimeout time.Duration + + // The client credentials to use when establishing the connection to auth. + clientCreds func() client.Credentials +} + +func (c *ClientConfig) CheckAndSetDefaults() error { + if c.ProxyWebAddress == "" { + return trace.BadParameter("missing required parameter ProxyWebAddress") + } + if c.ProxySSHAddress == "" { + return trace.BadParameter("missing required parameter ProxySSHAddress") + } + if c.ClusterName == "" { + return trace.BadParameter("missing required parameter ClusterName") + } + if c.SSHDialer == nil { + return trace.BadParameter("missing required parameter SSHDialer") + } + if c.SSHConfig == nil { + return trace.BadParameter("missing required parameter SSHConfig") + } + if c.DialTimeout <= 0 { + c.DialTimeout = defaults.DefaultIOTimeout + } + + if c.TLSConfig != nil { + c.clientCreds = func() client.Credentials { + return client.LoadTLS(c.TLSConfig.Clone()) + } + } else { + c.clientCreds = func() client.Credentials { + return insecureCredentials{} + } + } + + return nil +} + +type insecureCredentials struct{} + +func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) { + return nil, trace.NotImplemented("no dialer") +} + +func (mc insecureCredentials) TLSConfig() (*tls.Config, error) { + return nil, nil +} + +func (mc insecureCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { + return nil, trace.NotImplemented("no ssh config") +} + +// Client is a client to the Teleport Proxy SSH server on behalf of a user. +type Client struct { + // cfg are the user provided configuration parameters required to + // connect and interact with the Proxy. + cfg *ClientConfig + // sshClient is the established SSH connection to the Proxy. + sshClient *tracessh.Client +} + +// NewClient creates a new Client that attempts to connect to the SSH +// server being served by the Proxy SSH port by default. +func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + clt, err := newSSHClient(ctx, &cfg) + return clt, trace.Wrap(err) +} + +// newSSHClient creates a Client that is connected via SSH. +func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) { + clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, cfg.SSHConfig) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Client{ + cfg: cfg, + sshClient: clt, + }, nil +} + +// Close attempts to close the SSH connections. +func (c *Client) Close() error { + return trace.Wrap(c.sshClient.Close()) +} + +// SSHConfig returns the [ssh.ClientConfig] for the provided user which +// should be used when creating a [tracessh.Client] with the returned +// [net.Conn] from [Client.DialHost]. +func (c *Client) SSHConfig(user string) *ssh.ClientConfig { + return &ssh.ClientConfig{ + Config: c.cfg.SSHConfig.Config, + User: user, + Auth: c.cfg.SSHConfig.Auth, + HostKeyCallback: c.cfg.SSHConfig.HostKeyCallback, + BannerCallback: c.cfg.SSHConfig.BannerCallback, + ClientVersion: c.cfg.SSHConfig.ClientVersion, + HostKeyAlgorithms: c.cfg.SSHConfig.HostKeyAlgorithms, + Timeout: c.cfg.SSHConfig.Timeout, + } +} + +// ClusterDetails provide cluster configuration +// details as known by the connected Proxy. +type ClusterDetails struct { + // FIPS dictates whether FIPS mode is enabled. + FIPS bool +} + +// ClientConfig returns a [client.Config] that may be used to connect to the +// Auth server in the provided cluster via [client.New] or similar. The [client.Config] +// returned will have the correct credentials and dialer set based on the ClientConfig +// that was provided to create this Client. +func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config { + if c.cfg.TLSRoutingEnabled { + return client.Config{ + Context: ctx, + Addrs: []string{c.cfg.ProxyWebAddress}, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + ALPNSNIAuthDialClusterName: cluster, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + } + } + + return client.Config{ + Context: ctx, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + DialInBackground: true, + Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) { + // Don't dial if the context has timed out. + select { + case <-dialCtx.Done(): + return nil, dialCtx.Err() + default: + } + + conn, err := dialSSH(dialCtx, c.sshClient, c.cfg.ProxySSHAddress, "@"+cluster, nil) + return conn, trace.Wrap(err) + }), + } +} + +// DialHost establishes a connection to the `target` in cluster named `cluster`. If a keyring +// is provided it will only be forwarded if proxy recording mode is enabled in the cluster. +func (c *Client) DialHost(ctx context.Context, target, cluster string, keyring agent.ExtendedAgent) (net.Conn, ClusterDetails, error) { + conn, details, err := c.dialHostSSH(ctx, target, cluster, keyring) + return conn, details, trace.Wrap(err) +} + +// dialHostSSH connects to the target via SSH. To match backwards compatibility the +// cluster details are retrieved from the Proxy SSH server via a clusterDetailsRequest +// request to determine if the keyring should be forwarded. +func (c *Client) dialHostSSH(ctx context.Context, target, cluster string, keyring agent.ExtendedAgent) (net.Conn, ClusterDetails, error) { + details, err := c.clusterDetailsSSH(ctx) + if err != nil { + return nil, ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) + } + + // Prevent forwarding the keychain if the proxy is + // not doing the recording. + if !details.RecordingProxy { + keyring = nil + } + + conn, err := dialSSH(ctx, c.sshClient, c.cfg.ProxySSHAddress, target+"@"+cluster, keyring) + return conn, ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) +} + +// ClusterDetails retrieves cluster information as seen by the Proxy. +func (c *Client) ClusterDetails(ctx context.Context) (ClusterDetails, error) { + details, err := c.clusterDetailsSSH(ctx) + return ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) +} + +// sshDetails is the response from a clusterDetailsRequest. +type sshDetails struct { + RecordingProxy bool + FIPSEnabled bool +} + +const clusterDetailsRequest = "cluster-details@goteleport.com" + +// clusterDetailsSSH retrieves the cluster details via a clusterDetailsRequest. +func (c *Client) clusterDetailsSSH(ctx context.Context) (sshDetails, error) { + ok, resp, err := c.sshClient.SendRequest(ctx, clusterDetailsRequest, true, nil) + if err != nil { + return sshDetails{}, trace.Wrap(err) + } + + if !ok { + return sshDetails{}, trace.ConnectionProblem(nil, "failed to get cluster details") + } + + var details sshDetails + if err := ssh.Unmarshal(resp, &details); err != nil { + return sshDetails{}, trace.Wrap(err) + } + + return details, trace.Wrap(err) +} + +// dialSSH creates a SSH session to the target address and proxies a [net.Conn] +// over the standard input and output of the session. +func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddress string, keyring agent.ExtendedAgent) (_ net.Conn, err error) { + session, err := clt.NewSession(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + defer func() { + if err != nil { + _ = session.Close() + } + }() + + conn, err := newSessionConn(session, proxyAddress, targetAddress) + if err != nil { + return nil, trace.Wrap(err) + } + + defer func() { + if err != nil { + _ = conn.Close() + } + }() + + sessionError, err := session.StderrPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + // If a keyring was provided then set up agent forwarding. + if keyring != nil { + // Add a handler to receive requests on the auth-agent@openssh.com channel. If there is + // already a handler it's safe to ignore the error because we only need one active handler + // to process requests. + err = agent.ForwardToAgent(clt.Client, keyring) + if err != nil && !strings.Contains(err.Error(), "agent: already have handler for") { + return nil, trace.Wrap(err) + } + + err = agent.RequestAgentForwarding(session.Session) + if err != nil { + return nil, trace.Wrap(err) + } + } + + if err := session.RequestSubsystem(ctx, "proxy:"+targetAddress); err != nil { + // read the stderr output from the failed SSH session and append + // it to the end of our own message: + serverErrorMsg, _ := io.ReadAll(sessionError) + return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %v. %v", targetAddress, serverErrorMsg, err) + } + + return conn, nil +} diff --git a/api/client/proxy/client_test.go b/api/client/proxy/client_test.go new file mode 100644 index 0000000000000..9047faea17325 --- /dev/null +++ b/api/client/proxy/client_test.go @@ -0,0 +1,543 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "io" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils/sshutils" +) + +type fakeSSHServer struct { + listener net.Listener + cfg fakeSSHServerConfig +} + +func (s *fakeSSHServer) run() { + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + + go func() { + sconn, chans, reqs, err := ssh.NewServerConn(conn, s.cfg.config) + if err != nil { + return + } + s.cfg.handler(sconn, chans, reqs) + }() + } +} + +func (s *fakeSSHServer) Stop() error { + return s.listener.Close() +} + +func generateSigner(t *testing.T) ssh.Signer { + private, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(private), + } + + privatePEM := pem.EncodeToMemory(block) + signer, err := ssh.ParsePrivateKey(privatePEM) + require.NoError(t, err) + return signer +} + +func (s *fakeSSHServer) clientConfig() *ssh.ClientConfig { + return &ssh.ClientConfig{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cfg.cSigner)}, + HostKeyCallback: ssh.FixedHostKey(s.cfg.hSigner.PublicKey()), + } +} + +func (s *fakeSSHServer) newClientConn() (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + conn, err := net.Dial("tcp", s.listener.Addr().String()) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + sconn, nc, r, err := ssh.NewClientConn(conn, "", s.clientConfig()) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sconn, nc, r, nil +} + +type sshHandler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) + +type fakeSSHServerConfig struct { + config *ssh.ServerConfig + handler sshHandler + cSigner ssh.Signer + hSigner ssh.Signer +} + +func discardHandler(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + defer func() { _ = conn.Close() }() + + go ssh.DiscardRequests(reqs) + + for ch := range chans { + _ = ch.Reject(ssh.Prohibited, "discard") + } +} + +func proxySubsystemHandler(details sshDetails, handleConn func(conn *ssh.ServerConn, ch ssh.Channel)) sshHandler { + return func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + defer func() { _ = conn.Close() }() + + go func() { + for req := range requests { + if req.Type == clusterDetailsRequest { + _ = req.Reply(true, ssh.Marshal(details)) + } + } + }() + + for nch := range channels { + if nch.ChannelType() != "session" { + _ = nch.Reject(ssh.UnknownChannelType, "unknown channel") + continue + } + + ch, reqs, err := nch.Accept() + if err != nil { + return + } + + go func() { + defer func() { _ = ch.Close() }() + + for req := range reqs { + ok := req.Type == "subsystem" + + if req.WantReply { + _ = req.Reply(ok, nil) + } + + if !ok { + continue + } + + handleConn(conn, ch) + } + }() + } + } +} + +func echoHandler(details sshDetails) sshHandler { + return proxySubsystemHandler(details, func(conn *ssh.ServerConn, ch ssh.Channel) { + _, _ = io.Copy(ch, ch) + }) +} + +func authHandler(t *testing.T) sshHandler { + return proxySubsystemHandler(sshDetails{}, func(conn *ssh.ServerConn, ch ssh.Channel) { + auth := newFakeAuthServer(t, sshutils.NewChConn(conn, ch)) + t.Cleanup(auth.Stop) + _ = auth.Serve() + }) +} + +type fakeAuthServer struct { + *proto.UnimplementedAuthServiceServer + listener net.Listener + srv *grpc.Server +} + +func newFakeAuthServer(t *testing.T, conn net.Conn) *fakeAuthServer { + f := &fakeAuthServer{ + listener: newOneShotListener(conn), + UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, + srv: grpc.NewServer(), + } + + t.Cleanup(f.Stop) + proto.RegisterAuthServiceServer(f.srv, f) + return f +} + +func (f *fakeAuthServer) Ping(context.Context, *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{ + ClusterName: "test", + ServerVersion: "1.0.0", + IsBoring: true, + }, nil +} + +func (f *fakeAuthServer) Serve() error { + return f.srv.Serve(f.listener) +} + +func (f *fakeAuthServer) Stop() { + _ = f.listener.Close() + f.srv.Stop() +} + +type oneShotListener struct { + conn net.Conn + closedCh chan struct{} + listenedCh chan struct{} +} + +func newOneShotListener(conn net.Conn) oneShotListener { + return oneShotListener{ + conn: conn, + closedCh: make(chan struct{}), + listenedCh: make(chan struct{}), + } +} + +func (l oneShotListener) Accept() (net.Conn, error) { + select { + case <-l.listenedCh: + <-l.closedCh + return nil, net.ErrClosed + default: + close(l.listenedCh) + return l.conn, nil + } +} + +func (l oneShotListener) Close() error { + select { + case <-l.closedCh: + default: + close(l.closedCh) + } + + return nil +} + +func (l oneShotListener) Addr() net.Addr { + return addr("127.0.0.1") +} + +func newSSHServer(t *testing.T, cfg fakeSSHServerConfig) *fakeSSHServer { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + srv := &fakeSSHServer{ + listener: listener, + cfg: cfg, + } + + go srv.run() + + t.Cleanup(func() { require.NoError(t, srv.Stop()) }) + return srv +} + +type fakeProxy struct { + *fakeSSHServer +} + +func newFakeProxy(t *testing.T, sshHandler sshHandler) *fakeProxy { + cSigner := generateSigner(t) + hSigner := generateSigner(t) + + sshConfig := &ssh.ServerConfig{ + NoClientAuth: true, + ServerVersion: "SSH-2.0-Teleport", + } + sshConfig.AddHostKey(hSigner) + + sshSrv := newSSHServer(t, fakeSSHServerConfig{ + config: sshConfig, + handler: sshHandler, + cSigner: cSigner, + hSigner: hSigner, + }) + + return &fakeProxy{ + fakeSSHServer: sshSrv, + } +} + +func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig { + return ClientConfig{ + ProxyWebAddress: "127.0.0.1", + ProxySSHAddress: "127.0.0.1", + ClusterName: "test", + SSHDialer: SSHDialerFunc(func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { + conn, chans, reqs, err := f.fakeSSHServer.newClientConn() + if err != nil { + return nil, err + } + + clt := &tracessh.Client{Client: ssh.NewClient(conn, chans, reqs)} + t.Cleanup(func() { _ = clt.Close() }) + return clt, err + }), + SSHConfig: f.fakeSSHServer.clientConfig(), + } +} + +func TestNewClient(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tests := []struct { + name string + sshHandler sshHandler + assertion func(t *testing.T, clt *Client, err error) + }{ + { + name: "no grpc server and ssh server", + sshHandler: discardHandler, + assertion: func(t *testing.T, clt *Client, err error) { + require.NoError(t, err) + require.NotNil(t, clt) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + if clt != nil { + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + } + test.assertion(t, clt, err) + }) + } +} + +func TestClient_ClusterDetails(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + sshHandler sshHandler + assertion func(t *testing.T, details ClusterDetails, err error) + }{ + { + name: "cluster details via ssh", + sshHandler: echoHandler(sshDetails{ + RecordingProxy: true, + FIPSEnabled: true, + }), + assertion: func(t *testing.T, details ClusterDetails, err error) { + require.NoError(t, err) + require.True(t, details.FIPS) + }, + }, + { + name: "cluster details via ssh fails", + sshHandler: discardHandler, + assertion: func(t *testing.T, details ClusterDetails, err error) { + require.Error(t, err) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + details, err := clt.ClusterDetails(ctx) + test.assertion(t, details, err) + }) + } +} + +func TestClient_DialHost(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + sshHandler sshHandler + keyring agent.ExtendedAgent + assertion func(t *testing.T, conn net.Conn, details ClusterDetails, err error) + }{ + { + name: "ssh connection fails", + sshHandler: discardHandler, + assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { + require.Error(t, err) + require.Nil(t, conn) + require.False(t, details.FIPS) + }, + }, + { + name: "ssh connection established", + sshHandler: echoHandler(sshDetails{RecordingProxy: false, FIPSEnabled: true}), + assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + require.True(t, details.FIPS) + + // test that the server echos data back over the connection + msg := []byte("hello123") + n, err := conn.Write(msg) + require.NoError(t, err) + require.Equal(t, len(msg), n) + + out := make([]byte, len(msg)) + n, err = conn.Read(out) + require.NoError(t, err) + require.Equal(t, len(msg), n) + require.Equal(t, msg, out) + + require.NoError(t, conn.Close()) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + conn, details, err := clt.DialHost(ctx, "test", "cluster", test.keyring) + test.assertion(t, conn, details, err) + }) + } +} + +func TestClient_DialCluster(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + authCfg func(config *client.Config) + sshHandler sshHandler + keyring agent.ExtendedAgent + assertion func(t *testing.T, clt *client.Client, err error) + }{ + { + name: "ssh connection fails", + authCfg: func(config *client.Config) { + config.DialTimeout = 500 * time.Millisecond // speed up dial failure + }, + sshHandler: discardHandler, + assertion: func(t *testing.T, clt *client.Client, err error) { + require.Error(t, err) + require.Nil(t, clt) + }, + }, + { + name: "ssh connection established", + authCfg: func(config *client.Config) {}, + sshHandler: authHandler(t), + assertion: func(t *testing.T, clt *client.Client, err error) { + require.NoError(t, err) + require.NotNil(t, clt) + + expected := &proto.PingResponse{ + ClusterName: "test", + ServerVersion: "1.0.0", + IsBoring: true, + } + + resp, err := clt.Ping(ctx) + require.NoError(t, err) + require.Empty(t, cmp.Diff(expected, resp, protocmp.Transform())) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + authCfg := clt.ClientConfig(ctx, "cluster") + authCfg.DialOpts = []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithReturnConnectionError(), + grpc.WithDisableRetry(), + grpc.FailOnNonTempDialError(true), + } + authCfg.Credentials = []client.Credentials{insecureCredentials{}} + authCfg.DialTimeout = 3 * time.Second + + test.authCfg(&authCfg) + + authClt, err := client.New(ctx, authCfg) + if authClt != nil { + t.Cleanup(func() { + require.NoError(t, authClt.Close()) + }) + } + test.assertion(t, authClt, err) + }) + } +} + +func TestClient_SSHConfig(t *testing.T) { + t.Parallel() + + proxy := newFakeProxy(t, discardHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(context.Background(), cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + const user = "test-user" + sshConfig := clt.SSHConfig(user) + require.NotNil(t, sshConfig) + require.Equal(t, user, sshConfig.User) + require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback"))) +} diff --git a/api/client/proxy/session_conn.go b/api/client/proxy/session_conn.go new file mode 100644 index 0000000000000..46b3ef5a50ed4 --- /dev/null +++ b/api/client/proxy/session_conn.go @@ -0,0 +1,102 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "io" + "net" + "sync" + "time" + + "github.com/gravitational/trace" + + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" +) + +// addr is a [net.Addr] implementation for static tcp addresses. +type addr string + +func (a addr) Network() string { + return "tcp" +} + +func (a addr) String() string { + return string(a) +} + +// sessionConn is a [net.Conn] implementation that reads and writes data +// over the standard input and standard output, respectively, of a [tracessh.Session]. +type sessionConn struct { + io.Reader + session *tracessh.Session + localAddr net.Addr + remoteAddr net.Addr + + mu sync.Mutex + w io.WriteCloser +} + +// newSessionConn creates a [net.Conn] for over the provided [tracessh.Session]. +func newSessionConn(session *tracessh.Session, local, remote string) (*sessionConn, error) { + sessionW, err := session.StdinPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + sessionR, err := session.StdoutPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + return &sessionConn{ + session: session, + Reader: sessionR, + w: sessionW, + localAddr: addr(local), + remoteAddr: addr(remote), + }, nil +} + +func (s *sessionConn) Write(b []byte) (n int, err error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.w.Write(b) +} + +func (s *sessionConn) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + return trace.NewAggregate(s.w.Close(), s.session.Close()) +} + +func (s *sessionConn) LocalAddr() net.Addr { + return s.localAddr +} + +func (s *sessionConn) RemoteAddr() net.Addr { + return s.remoteAddr +} + +func (s *sessionConn) SetDeadline(time.Time) error { + return nil +} + +func (s *sessionConn) SetReadDeadline(time.Time) error { + return nil +} + +func (s *sessionConn) SetWriteDeadline(time.Time) error { + return nil +} diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index a4ddebd70b797..7612f86a97252 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -37,7 +37,6 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/net/http/httpproxy" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" @@ -93,7 +92,7 @@ func newWebClient(cfg *Config) (*http.Client, error) { return nil, trace.Wrap(err) } - rt := proxy.NewHTTPRoundTripper(&http.Transport{ + rt := utils.NewHTTPRoundTripper(&http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: cfg.Insecure, RootCAs: cfg.Pool, diff --git a/api/client/proxy/proxy.go b/api/utils/proxy.go similarity index 99% rename from api/client/proxy/proxy.go rename to api/utils/proxy.go index b7f85b682274c..2f25b5ad5f906 100644 --- a/api/client/proxy/proxy.go +++ b/api/utils/proxy.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package proxy +package utils import ( "net/http" diff --git a/api/client/proxy/proxy_test.go b/api/utils/proxy_test.go similarity index 99% rename from api/client/proxy/proxy_test.go rename to api/utils/proxy_test.go index ab87e875cb14c..6cacc9f7b0ffa 100644 --- a/api/client/proxy/proxy_test.go +++ b/api/utils/proxy_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package proxy +package utils import ( "crypto/tls" diff --git a/lib/client/https_client.go b/lib/client/https_client.go index 795186aceab62..c209868a679d0 100644 --- a/lib/client/https_client.go +++ b/lib/client/https_client.go @@ -28,7 +28,6 @@ import ( "golang.org/x/net/http/httpproxy" "github.com/gravitational/teleport" - apiproxy "github.com/gravitational/teleport/api/client/proxy" tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/httplib" @@ -41,7 +40,7 @@ func NewInsecureWebClient() *http.Client { func newClient(insecure bool, pool *x509.CertPool, extraHeaders map[string]string) *http.Client { return &http.Client{ - Transport: tracehttp.NewTransport(apiproxy.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders)), + Transport: tracehttp.NewTransport(apiutils.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders)), } } diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index e2c0900a86046..8cbc75b393f38 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -29,9 +29,9 @@ import ( "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" - apiproxy "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/utils" ) @@ -288,7 +288,7 @@ func WithInsecureSkipTLSVerify(insecure bool) DialerOptionFunc { // server directly. func DialerFromEnvironment(addr string, opts ...DialerOptionFunc) Dialer { // Try and get proxy addr from the environment. - proxyURL := apiproxy.GetProxyURL(addr) + proxyURL := apiutils.GetProxyURL(addr) var options dialerOptions for _, opt := range opts { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 91cefcbd0bb06..a08b9463c77a9 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -37,11 +37,11 @@ import ( "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/db/dbcmd" @@ -233,7 +233,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy // if sp.tlsRouting is true, remoteProxyAddr is the ALPN listener port. // if it is false, then remoteProxyAddr is the SSH proxy port. remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) - httpsProxy := proxy.GetProxyURL(remoteProxyAddr) + httpsProxy := apiutils.GetProxyURL(remoteProxyAddr) // If HTTPS_PROXY is configured, we need to open a TCP connection via // the specified HTTPS Proxy, otherwise, we can just open a plain TCP diff --git a/tool/tsh/resolve_default_addr.go b/tool/tsh/resolve_default_addr.go index 82ebfea53de32..edc93a06a8aae 100644 --- a/tool/tsh/resolve_default_addr.go +++ b/tool/tsh/resolve_default_addr.go @@ -30,8 +30,8 @@ import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/api/client/proxy" tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" + "github.com/gravitational/teleport/api/utils" ) type raceResult struct { @@ -121,7 +121,7 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in InsecureSkipVerify: insecure, }, Proxy: func(req *http.Request) (*url.URL, error) { - return proxy.GetProxyURL(req.URL.String()), nil + return utils.GetProxyURL(req.URL.String()), nil }, }, ), From 7988ca0faa399d0dc1ef2aaf3bf6b038dbae4be3 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Mon, 3 Apr 2023 08:34:43 -0400 Subject: [PATCH 2/2] Make `proxy.Client` infer the cluster name from Proxy (#23644) Instead of relying on users to provide the cluster name, the client now determines the cluster name by inspecting the certificate presented by the Proxy during the TLS or SSH handshake. This is required when connecting to a Proxy via a jump host since the name of the cluster may not match the currently logged in cluster. This is achieved by leveraging a custom `credentials.TransportCredentials` when connecting via gRPC and a custom `ssh.HostKeyCallback` when connecting SSH. --- api/client/proxy/client.go | 86 +++++++++++++++++++++++++++++---- api/client/proxy/client_test.go | 81 ++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 11 deletions(-) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 9d2b97cf13682..317c90d3b6ae2 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -20,6 +20,7 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "github.com/gravitational/trace" @@ -55,9 +56,6 @@ type ClientConfig struct { ProxySSHAddress string // TLSRoutingEnabled indicates if the cluster is using TLS Routing. TLSRoutingEnabled bool - // ClusterName is the name of the Teleport cluster that the client - // will be connected to. - ClusterName string // TLSConfig contains the tls.Config required for mTLS connections. TLSConfig *tls.Config // SSHDialer allows callers to control how a [tracessh.Client] is created. @@ -71,6 +69,8 @@ type ClientConfig struct { clientCreds func() client.Credentials } +// CheckAndSetDefaults ensures required options are present and +// sets the default value of any that are omitted. func (c *ClientConfig) CheckAndSetDefaults() error { if c.ProxyWebAddress == "" { return trace.BadParameter("missing required parameter ProxyWebAddress") @@ -78,9 +78,6 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.ProxySSHAddress == "" { return trace.BadParameter("missing required parameter ProxySSHAddress") } - if c.ClusterName == "" { - return trace.BadParameter("missing required parameter ClusterName") - } if c.SSHDialer == nil { return trace.BadParameter("missing required parameter SSHDialer") } @@ -104,6 +101,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error { return nil } +// insecureCredentials implements [client.Credentials] and is used by tests +// to connect to the Auth server without mTLS. type insecureCredentials struct{} func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) { @@ -125,6 +124,9 @@ type Client struct { cfg *ClientConfig // sshClient is the established SSH connection to the Proxy. sshClient *tracessh.Client + // clusterName as determined by inspecting the certificate presented by + // the Proxy during the connection handshake. + clusterName *clusterName } // NewClient creates a new Client that attempts to connect to the SSH @@ -138,19 +140,83 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { return clt, trace.Wrap(err) } +// clusterName stores the name of the cluster +// in a protected manner which allows it to +// be set during handshakes with the server. +type clusterName struct { + name atomic.Pointer[string] +} + +func (c *clusterName) get() string { + name := c.name.Load() + if name != nil { + return *name + } + return "" +} + +func (c *clusterName) set(name string) { + c.name.CompareAndSwap(nil, &name) +} + +// teleportAuthority is the extension set by the server +// which contains the name of the cluster it is in. +const teleportAuthority = "x-teleport-authority" + +// clusterCallback is a [ssh.HostKeyCallback] that obtains the name +// of the cluster being connected to from the certificate presented by the server. +// This allows the client to determine the cluster name when using jump hosts. +func clusterCallback(c *clusterName, wrapped ssh.HostKeyCallback) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := wrapped(hostname, remote, key); err != nil { + return trace.Wrap(err) + } + + cert, ok := key.(*ssh.Certificate) + if !ok { + return nil + } + + clusterName, ok := cert.Permissions.Extensions[teleportAuthority] + if ok { + c.set(clusterName) + } + + return nil + } +} + // newSSHClient creates a Client that is connected via SSH. func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) { - clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, cfg.SSHConfig) + c := &clusterName{} + clientCfg := &ssh.ClientConfig{ + User: cfg.SSHConfig.User, + Auth: cfg.SSHConfig.Auth, + HostKeyCallback: clusterCallback(c, cfg.SSHConfig.HostKeyCallback), + BannerCallback: cfg.SSHConfig.BannerCallback, + ClientVersion: cfg.SSHConfig.ClientVersion, + HostKeyAlgorithms: cfg.SSHConfig.HostKeyAlgorithms, + Timeout: cfg.SSHConfig.Timeout, + } + + clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, clientCfg) if err != nil { return nil, trace.Wrap(err) } return &Client{ - cfg: cfg, - sshClient: clt, + cfg: cfg, + sshClient: clt, + clusterName: c, }, nil } +// ClusterName returns the name of the cluster that the +// connected Proxy is a member of. +func (c *Client) ClusterName() string { + return c.clusterName.get() +} + // Close attempts to close the SSH connections. func (c *Client) Close() error { return trace.Wrap(c.sshClient.Close()) @@ -322,7 +388,7 @@ func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddr // read the stderr output from the failed SSH session and append // it to the end of our own message: serverErrorMsg, _ := io.ReadAll(sessionError) - return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %v. %v", targetAddress, serverErrorMsg, err) + return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %s. %v", targetAddress, serverErrorMsg, err) } return conn, nil diff --git a/api/client/proxy/client_test.go b/api/client/proxy/client_test.go index 9047faea17325..4c2835dce1bda 100644 --- a/api/client/proxy/client_test.go +++ b/api/client/proxy/client_test.go @@ -299,7 +299,6 @@ func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig { return ClientConfig{ ProxyWebAddress: "127.0.0.1", ProxySSHAddress: "127.0.0.1", - ClusterName: "test", SSHDialer: SSHDialerFunc(func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { conn, chans, reqs, err := f.fakeSSHServer.newClientConn() if err != nil { @@ -541,3 +540,83 @@ func TestClient_SSHConfig(t *testing.T) { require.Equal(t, user, sshConfig.User) require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback"))) } + +type fakePublicKey struct{} + +func (f fakePublicKey) Type() string { + return "test" +} + +func (f fakePublicKey) Marshal() []byte { + return nil +} + +func (f fakePublicKey) Verify(data []byte, sig *ssh.Signature) error { + return trace.NotImplemented("") +} + +func TestClusterCallback(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + hostKeyCB ssh.HostKeyCallback + publicKey ssh.PublicKey + expectedClusterName string + errAssertion require.ErrorAssertionFunc + }{ + { + name: "handshake failure", + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return context.Canceled + }, + errAssertion: require.Error, + }, + { + name: "invalid certificate", + publicKey: fakePublicKey{}, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + { + name: "no authority present", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + + { + name: "cluster name presented", + expectedClusterName: "test-cluster", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleportAuthority: "test-cluster", + }, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + c := &clusterName{} + err := clusterCallback(c, test.hostKeyCB)("test", addr("127.0.0.1"), test.publicKey) + test.errAssertion(t, err) + require.Equal(t, test.expectedClusterName, c.get()) + + }) + } +}