diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index ba9e07813c20c..bb53ec64500f1 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -26,11 +26,11 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" ) @@ -81,7 +81,7 @@ func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration) ContextDialer { return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := newDirectDialer(keepAlivePeriod, dialTimeout) - if proxyURL := proxy.GetProxyURL(addr); proxyURL != nil { + if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { return DialProxyWithDialer(ctx, proxyURL, addr, dialer) } return dialer.DialContext(ctx, network, addr) diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go new file mode 100644 index 0000000000000..317c90d3b6ae2 --- /dev/null +++ b/api/client/proxy/client.go @@ -0,0 +1,395 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "crypto/tls" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/gravitational/teleport/api/breaker" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/defaults" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" +) + +// SSHDialer provides a mechanism to create a ssh client. +type SSHDialer interface { + // Dial establishes a client connection to an SSH server. + Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) +} + +// SSHDialerFunc implements SSHDialer +type SSHDialerFunc func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) + +// Dial calls f(ctx, network, addr, config). +func (f SSHDialerFunc) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { + return f(ctx, network, addr, config) +} + +// ClientConfig contains configuration needed for a Client +// to be able to connect to the cluster. +type ClientConfig struct { + // ProxyWebAddress is the address of the Proxy Web server. + ProxyWebAddress string + // ProxySSHAddress is the address of the Proxy SSH server. + ProxySSHAddress string + // TLSRoutingEnabled indicates if the cluster is using TLS Routing. + TLSRoutingEnabled bool + // TLSConfig contains the tls.Config required for mTLS connections. + TLSConfig *tls.Config + // SSHDialer allows callers to control how a [tracessh.Client] is created. + SSHDialer SSHDialer + // SSHConfig is the [ssh.ClientConfig] used to connect to the Proxy SSH server. + SSHConfig *ssh.ClientConfig + // DialTimeout defines how long to attempt dialing before timing out. + DialTimeout time.Duration + + // The client credentials to use when establishing the connection to auth. + clientCreds func() client.Credentials +} + +// CheckAndSetDefaults ensures required options are present and +// sets the default value of any that are omitted. +func (c *ClientConfig) CheckAndSetDefaults() error { + if c.ProxyWebAddress == "" { + return trace.BadParameter("missing required parameter ProxyWebAddress") + } + if c.ProxySSHAddress == "" { + return trace.BadParameter("missing required parameter ProxySSHAddress") + } + if c.SSHDialer == nil { + return trace.BadParameter("missing required parameter SSHDialer") + } + if c.SSHConfig == nil { + return trace.BadParameter("missing required parameter SSHConfig") + } + if c.DialTimeout <= 0 { + c.DialTimeout = defaults.DefaultIOTimeout + } + + if c.TLSConfig != nil { + c.clientCreds = func() client.Credentials { + return client.LoadTLS(c.TLSConfig.Clone()) + } + } else { + c.clientCreds = func() client.Credentials { + return insecureCredentials{} + } + } + + return nil +} + +// insecureCredentials implements [client.Credentials] and is used by tests +// to connect to the Auth server without mTLS. +type insecureCredentials struct{} + +func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) { + return nil, trace.NotImplemented("no dialer") +} + +func (mc insecureCredentials) TLSConfig() (*tls.Config, error) { + return nil, nil +} + +func (mc insecureCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { + return nil, trace.NotImplemented("no ssh config") +} + +// Client is a client to the Teleport Proxy SSH server on behalf of a user. +type Client struct { + // cfg are the user provided configuration parameters required to + // connect and interact with the Proxy. + cfg *ClientConfig + // sshClient is the established SSH connection to the Proxy. + sshClient *tracessh.Client + // clusterName as determined by inspecting the certificate presented by + // the Proxy during the connection handshake. + clusterName *clusterName +} + +// NewClient creates a new Client that attempts to connect to the SSH +// server being served by the Proxy SSH port by default. +func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + clt, err := newSSHClient(ctx, &cfg) + return clt, trace.Wrap(err) +} + +// clusterName stores the name of the cluster +// in a protected manner which allows it to +// be set during handshakes with the server. +type clusterName struct { + name atomic.Pointer[string] +} + +func (c *clusterName) get() string { + name := c.name.Load() + if name != nil { + return *name + } + return "" +} + +func (c *clusterName) set(name string) { + c.name.CompareAndSwap(nil, &name) +} + +// teleportAuthority is the extension set by the server +// which contains the name of the cluster it is in. +const teleportAuthority = "x-teleport-authority" + +// clusterCallback is a [ssh.HostKeyCallback] that obtains the name +// of the cluster being connected to from the certificate presented by the server. +// This allows the client to determine the cluster name when using jump hosts. +func clusterCallback(c *clusterName, wrapped ssh.HostKeyCallback) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := wrapped(hostname, remote, key); err != nil { + return trace.Wrap(err) + } + + cert, ok := key.(*ssh.Certificate) + if !ok { + return nil + } + + clusterName, ok := cert.Permissions.Extensions[teleportAuthority] + if ok { + c.set(clusterName) + } + + return nil + } +} + +// newSSHClient creates a Client that is connected via SSH. +func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) { + c := &clusterName{} + clientCfg := &ssh.ClientConfig{ + User: cfg.SSHConfig.User, + Auth: cfg.SSHConfig.Auth, + HostKeyCallback: clusterCallback(c, cfg.SSHConfig.HostKeyCallback), + BannerCallback: cfg.SSHConfig.BannerCallback, + ClientVersion: cfg.SSHConfig.ClientVersion, + HostKeyAlgorithms: cfg.SSHConfig.HostKeyAlgorithms, + Timeout: cfg.SSHConfig.Timeout, + } + + clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, clientCfg) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Client{ + cfg: cfg, + sshClient: clt, + clusterName: c, + }, nil +} + +// ClusterName returns the name of the cluster that the +// connected Proxy is a member of. +func (c *Client) ClusterName() string { + return c.clusterName.get() +} + +// Close attempts to close the SSH connections. +func (c *Client) Close() error { + return trace.Wrap(c.sshClient.Close()) +} + +// SSHConfig returns the [ssh.ClientConfig] for the provided user which +// should be used when creating a [tracessh.Client] with the returned +// [net.Conn] from [Client.DialHost]. +func (c *Client) SSHConfig(user string) *ssh.ClientConfig { + return &ssh.ClientConfig{ + Config: c.cfg.SSHConfig.Config, + User: user, + Auth: c.cfg.SSHConfig.Auth, + HostKeyCallback: c.cfg.SSHConfig.HostKeyCallback, + BannerCallback: c.cfg.SSHConfig.BannerCallback, + ClientVersion: c.cfg.SSHConfig.ClientVersion, + HostKeyAlgorithms: c.cfg.SSHConfig.HostKeyAlgorithms, + Timeout: c.cfg.SSHConfig.Timeout, + } +} + +// ClusterDetails provide cluster configuration +// details as known by the connected Proxy. +type ClusterDetails struct { + // FIPS dictates whether FIPS mode is enabled. + FIPS bool +} + +// ClientConfig returns a [client.Config] that may be used to connect to the +// Auth server in the provided cluster via [client.New] or similar. The [client.Config] +// returned will have the correct credentials and dialer set based on the ClientConfig +// that was provided to create this Client. +func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config { + if c.cfg.TLSRoutingEnabled { + return client.Config{ + Context: ctx, + Addrs: []string{c.cfg.ProxyWebAddress}, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + ALPNSNIAuthDialClusterName: cluster, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + } + } + + return client.Config{ + Context: ctx, + Credentials: []client.Credentials{c.cfg.clientCreds()}, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + DialInBackground: true, + Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) { + // Don't dial if the context has timed out. + select { + case <-dialCtx.Done(): + return nil, dialCtx.Err() + default: + } + + conn, err := dialSSH(dialCtx, c.sshClient, c.cfg.ProxySSHAddress, "@"+cluster, nil) + return conn, trace.Wrap(err) + }), + } +} + +// DialHost establishes a connection to the `target` in cluster named `cluster`. If a keyring +// is provided it will only be forwarded if proxy recording mode is enabled in the cluster. +func (c *Client) DialHost(ctx context.Context, target, cluster string, keyring agent.ExtendedAgent) (net.Conn, ClusterDetails, error) { + conn, details, err := c.dialHostSSH(ctx, target, cluster, keyring) + return conn, details, trace.Wrap(err) +} + +// dialHostSSH connects to the target via SSH. To match backwards compatibility the +// cluster details are retrieved from the Proxy SSH server via a clusterDetailsRequest +// request to determine if the keyring should be forwarded. +func (c *Client) dialHostSSH(ctx context.Context, target, cluster string, keyring agent.ExtendedAgent) (net.Conn, ClusterDetails, error) { + details, err := c.clusterDetailsSSH(ctx) + if err != nil { + return nil, ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) + } + + // Prevent forwarding the keychain if the proxy is + // not doing the recording. + if !details.RecordingProxy { + keyring = nil + } + + conn, err := dialSSH(ctx, c.sshClient, c.cfg.ProxySSHAddress, target+"@"+cluster, keyring) + return conn, ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) +} + +// ClusterDetails retrieves cluster information as seen by the Proxy. +func (c *Client) ClusterDetails(ctx context.Context) (ClusterDetails, error) { + details, err := c.clusterDetailsSSH(ctx) + return ClusterDetails{FIPS: details.FIPSEnabled}, trace.Wrap(err) +} + +// sshDetails is the response from a clusterDetailsRequest. +type sshDetails struct { + RecordingProxy bool + FIPSEnabled bool +} + +const clusterDetailsRequest = "cluster-details@goteleport.com" + +// clusterDetailsSSH retrieves the cluster details via a clusterDetailsRequest. +func (c *Client) clusterDetailsSSH(ctx context.Context) (sshDetails, error) { + ok, resp, err := c.sshClient.SendRequest(ctx, clusterDetailsRequest, true, nil) + if err != nil { + return sshDetails{}, trace.Wrap(err) + } + + if !ok { + return sshDetails{}, trace.ConnectionProblem(nil, "failed to get cluster details") + } + + var details sshDetails + if err := ssh.Unmarshal(resp, &details); err != nil { + return sshDetails{}, trace.Wrap(err) + } + + return details, trace.Wrap(err) +} + +// dialSSH creates a SSH session to the target address and proxies a [net.Conn] +// over the standard input and output of the session. +func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddress string, keyring agent.ExtendedAgent) (_ net.Conn, err error) { + session, err := clt.NewSession(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + defer func() { + if err != nil { + _ = session.Close() + } + }() + + conn, err := newSessionConn(session, proxyAddress, targetAddress) + if err != nil { + return nil, trace.Wrap(err) + } + + defer func() { + if err != nil { + _ = conn.Close() + } + }() + + sessionError, err := session.StderrPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + // If a keyring was provided then set up agent forwarding. + if keyring != nil { + // Add a handler to receive requests on the auth-agent@openssh.com channel. If there is + // already a handler it's safe to ignore the error because we only need one active handler + // to process requests. + err = agent.ForwardToAgent(clt.Client, keyring) + if err != nil && !strings.Contains(err.Error(), "agent: already have handler for") { + return nil, trace.Wrap(err) + } + + err = agent.RequestAgentForwarding(session.Session) + if err != nil { + return nil, trace.Wrap(err) + } + } + + if err := session.RequestSubsystem(ctx, "proxy:"+targetAddress); err != nil { + // read the stderr output from the failed SSH session and append + // it to the end of our own message: + serverErrorMsg, _ := io.ReadAll(sessionError) + return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %s. %v", targetAddress, serverErrorMsg, err) + } + + return conn, nil +} diff --git a/api/client/proxy/client_test.go b/api/client/proxy/client_test.go new file mode 100644 index 0000000000000..4c2835dce1bda --- /dev/null +++ b/api/client/proxy/client_test.go @@ -0,0 +1,622 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "io" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils/sshutils" +) + +type fakeSSHServer struct { + listener net.Listener + cfg fakeSSHServerConfig +} + +func (s *fakeSSHServer) run() { + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + + go func() { + sconn, chans, reqs, err := ssh.NewServerConn(conn, s.cfg.config) + if err != nil { + return + } + s.cfg.handler(sconn, chans, reqs) + }() + } +} + +func (s *fakeSSHServer) Stop() error { + return s.listener.Close() +} + +func generateSigner(t *testing.T) ssh.Signer { + private, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(private), + } + + privatePEM := pem.EncodeToMemory(block) + signer, err := ssh.ParsePrivateKey(privatePEM) + require.NoError(t, err) + return signer +} + +func (s *fakeSSHServer) clientConfig() *ssh.ClientConfig { + return &ssh.ClientConfig{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cfg.cSigner)}, + HostKeyCallback: ssh.FixedHostKey(s.cfg.hSigner.PublicKey()), + } +} + +func (s *fakeSSHServer) newClientConn() (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + conn, err := net.Dial("tcp", s.listener.Addr().String()) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + sconn, nc, r, err := ssh.NewClientConn(conn, "", s.clientConfig()) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sconn, nc, r, nil +} + +type sshHandler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) + +type fakeSSHServerConfig struct { + config *ssh.ServerConfig + handler sshHandler + cSigner ssh.Signer + hSigner ssh.Signer +} + +func discardHandler(conn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + defer func() { _ = conn.Close() }() + + go ssh.DiscardRequests(reqs) + + for ch := range chans { + _ = ch.Reject(ssh.Prohibited, "discard") + } +} + +func proxySubsystemHandler(details sshDetails, handleConn func(conn *ssh.ServerConn, ch ssh.Channel)) sshHandler { + return func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + defer func() { _ = conn.Close() }() + + go func() { + for req := range requests { + if req.Type == clusterDetailsRequest { + _ = req.Reply(true, ssh.Marshal(details)) + } + } + }() + + for nch := range channels { + if nch.ChannelType() != "session" { + _ = nch.Reject(ssh.UnknownChannelType, "unknown channel") + continue + } + + ch, reqs, err := nch.Accept() + if err != nil { + return + } + + go func() { + defer func() { _ = ch.Close() }() + + for req := range reqs { + ok := req.Type == "subsystem" + + if req.WantReply { + _ = req.Reply(ok, nil) + } + + if !ok { + continue + } + + handleConn(conn, ch) + } + }() + } + } +} + +func echoHandler(details sshDetails) sshHandler { + return proxySubsystemHandler(details, func(conn *ssh.ServerConn, ch ssh.Channel) { + _, _ = io.Copy(ch, ch) + }) +} + +func authHandler(t *testing.T) sshHandler { + return proxySubsystemHandler(sshDetails{}, func(conn *ssh.ServerConn, ch ssh.Channel) { + auth := newFakeAuthServer(t, sshutils.NewChConn(conn, ch)) + t.Cleanup(auth.Stop) + _ = auth.Serve() + }) +} + +type fakeAuthServer struct { + *proto.UnimplementedAuthServiceServer + listener net.Listener + srv *grpc.Server +} + +func newFakeAuthServer(t *testing.T, conn net.Conn) *fakeAuthServer { + f := &fakeAuthServer{ + listener: newOneShotListener(conn), + UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, + srv: grpc.NewServer(), + } + + t.Cleanup(f.Stop) + proto.RegisterAuthServiceServer(f.srv, f) + return f +} + +func (f *fakeAuthServer) Ping(context.Context, *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{ + ClusterName: "test", + ServerVersion: "1.0.0", + IsBoring: true, + }, nil +} + +func (f *fakeAuthServer) Serve() error { + return f.srv.Serve(f.listener) +} + +func (f *fakeAuthServer) Stop() { + _ = f.listener.Close() + f.srv.Stop() +} + +type oneShotListener struct { + conn net.Conn + closedCh chan struct{} + listenedCh chan struct{} +} + +func newOneShotListener(conn net.Conn) oneShotListener { + return oneShotListener{ + conn: conn, + closedCh: make(chan struct{}), + listenedCh: make(chan struct{}), + } +} + +func (l oneShotListener) Accept() (net.Conn, error) { + select { + case <-l.listenedCh: + <-l.closedCh + return nil, net.ErrClosed + default: + close(l.listenedCh) + return l.conn, nil + } +} + +func (l oneShotListener) Close() error { + select { + case <-l.closedCh: + default: + close(l.closedCh) + } + + return nil +} + +func (l oneShotListener) Addr() net.Addr { + return addr("127.0.0.1") +} + +func newSSHServer(t *testing.T, cfg fakeSSHServerConfig) *fakeSSHServer { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + srv := &fakeSSHServer{ + listener: listener, + cfg: cfg, + } + + go srv.run() + + t.Cleanup(func() { require.NoError(t, srv.Stop()) }) + return srv +} + +type fakeProxy struct { + *fakeSSHServer +} + +func newFakeProxy(t *testing.T, sshHandler sshHandler) *fakeProxy { + cSigner := generateSigner(t) + hSigner := generateSigner(t) + + sshConfig := &ssh.ServerConfig{ + NoClientAuth: true, + ServerVersion: "SSH-2.0-Teleport", + } + sshConfig.AddHostKey(hSigner) + + sshSrv := newSSHServer(t, fakeSSHServerConfig{ + config: sshConfig, + handler: sshHandler, + cSigner: cSigner, + hSigner: hSigner, + }) + + return &fakeProxy{ + fakeSSHServer: sshSrv, + } +} + +func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig { + return ClientConfig{ + ProxyWebAddress: "127.0.0.1", + ProxySSHAddress: "127.0.0.1", + SSHDialer: SSHDialerFunc(func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { + conn, chans, reqs, err := f.fakeSSHServer.newClientConn() + if err != nil { + return nil, err + } + + clt := &tracessh.Client{Client: ssh.NewClient(conn, chans, reqs)} + t.Cleanup(func() { _ = clt.Close() }) + return clt, err + }), + SSHConfig: f.fakeSSHServer.clientConfig(), + } +} + +func TestNewClient(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tests := []struct { + name string + sshHandler sshHandler + assertion func(t *testing.T, clt *Client, err error) + }{ + { + name: "no grpc server and ssh server", + sshHandler: discardHandler, + assertion: func(t *testing.T, clt *Client, err error) { + require.NoError(t, err) + require.NotNil(t, clt) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + if clt != nil { + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + } + test.assertion(t, clt, err) + }) + } +} + +func TestClient_ClusterDetails(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + sshHandler sshHandler + assertion func(t *testing.T, details ClusterDetails, err error) + }{ + { + name: "cluster details via ssh", + sshHandler: echoHandler(sshDetails{ + RecordingProxy: true, + FIPSEnabled: true, + }), + assertion: func(t *testing.T, details ClusterDetails, err error) { + require.NoError(t, err) + require.True(t, details.FIPS) + }, + }, + { + name: "cluster details via ssh fails", + sshHandler: discardHandler, + assertion: func(t *testing.T, details ClusterDetails, err error) { + require.Error(t, err) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + details, err := clt.ClusterDetails(ctx) + test.assertion(t, details, err) + }) + } +} + +func TestClient_DialHost(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + sshHandler sshHandler + keyring agent.ExtendedAgent + assertion func(t *testing.T, conn net.Conn, details ClusterDetails, err error) + }{ + { + name: "ssh connection fails", + sshHandler: discardHandler, + assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { + require.Error(t, err) + require.Nil(t, conn) + require.False(t, details.FIPS) + }, + }, + { + name: "ssh connection established", + sshHandler: echoHandler(sshDetails{RecordingProxy: false, FIPSEnabled: true}), + assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + require.True(t, details.FIPS) + + // test that the server echos data back over the connection + msg := []byte("hello123") + n, err := conn.Write(msg) + require.NoError(t, err) + require.Equal(t, len(msg), n) + + out := make([]byte, len(msg)) + n, err = conn.Read(out) + require.NoError(t, err) + require.Equal(t, len(msg), n) + require.Equal(t, msg, out) + + require.NoError(t, conn.Close()) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + conn, details, err := clt.DialHost(ctx, "test", "cluster", test.keyring) + test.assertion(t, conn, details, err) + }) + } +} + +func TestClient_DialCluster(t *testing.T) { + t.Parallel() + ctx := context.Background() + + tests := []struct { + name string + authCfg func(config *client.Config) + sshHandler sshHandler + keyring agent.ExtendedAgent + assertion func(t *testing.T, clt *client.Client, err error) + }{ + { + name: "ssh connection fails", + authCfg: func(config *client.Config) { + config.DialTimeout = 500 * time.Millisecond // speed up dial failure + }, + sshHandler: discardHandler, + assertion: func(t *testing.T, clt *client.Client, err error) { + require.Error(t, err) + require.Nil(t, clt) + }, + }, + { + name: "ssh connection established", + authCfg: func(config *client.Config) {}, + sshHandler: authHandler(t), + assertion: func(t *testing.T, clt *client.Client, err error) { + require.NoError(t, err) + require.NotNil(t, clt) + + expected := &proto.PingResponse{ + ClusterName: "test", + ServerVersion: "1.0.0", + IsBoring: true, + } + + resp, err := clt.Ping(ctx) + require.NoError(t, err) + require.Empty(t, cmp.Diff(expected, resp, protocmp.Transform())) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proxy := newFakeProxy(t, test.sshHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + authCfg := clt.ClientConfig(ctx, "cluster") + authCfg.DialOpts = []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithReturnConnectionError(), + grpc.WithDisableRetry(), + grpc.FailOnNonTempDialError(true), + } + authCfg.Credentials = []client.Credentials{insecureCredentials{}} + authCfg.DialTimeout = 3 * time.Second + + test.authCfg(&authCfg) + + authClt, err := client.New(ctx, authCfg) + if authClt != nil { + t.Cleanup(func() { + require.NoError(t, authClt.Close()) + }) + } + test.assertion(t, authClt, err) + }) + } +} + +func TestClient_SSHConfig(t *testing.T) { + t.Parallel() + + proxy := newFakeProxy(t, discardHandler) + cfg := proxy.clientConfig(t) + + clt, err := NewClient(context.Background(), cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + const user = "test-user" + sshConfig := clt.SSHConfig(user) + require.NotNil(t, sshConfig) + require.Equal(t, user, sshConfig.User) + require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback"))) +} + +type fakePublicKey struct{} + +func (f fakePublicKey) Type() string { + return "test" +} + +func (f fakePublicKey) Marshal() []byte { + return nil +} + +func (f fakePublicKey) Verify(data []byte, sig *ssh.Signature) error { + return trace.NotImplemented("") +} + +func TestClusterCallback(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + hostKeyCB ssh.HostKeyCallback + publicKey ssh.PublicKey + expectedClusterName string + errAssertion require.ErrorAssertionFunc + }{ + { + name: "handshake failure", + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return context.Canceled + }, + errAssertion: require.Error, + }, + { + name: "invalid certificate", + publicKey: fakePublicKey{}, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + { + name: "no authority present", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + + { + name: "cluster name presented", + expectedClusterName: "test-cluster", + publicKey: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleportAuthority: "test-cluster", + }, + }, + }, + hostKeyCB: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + errAssertion: require.NoError, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + c := &clusterName{} + err := clusterCallback(c, test.hostKeyCB)("test", addr("127.0.0.1"), test.publicKey) + test.errAssertion(t, err) + require.Equal(t, test.expectedClusterName, c.get()) + + }) + } +} diff --git a/api/client/proxy/session_conn.go b/api/client/proxy/session_conn.go new file mode 100644 index 0000000000000..46b3ef5a50ed4 --- /dev/null +++ b/api/client/proxy/session_conn.go @@ -0,0 +1,102 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "io" + "net" + "sync" + "time" + + "github.com/gravitational/trace" + + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" +) + +// addr is a [net.Addr] implementation for static tcp addresses. +type addr string + +func (a addr) Network() string { + return "tcp" +} + +func (a addr) String() string { + return string(a) +} + +// sessionConn is a [net.Conn] implementation that reads and writes data +// over the standard input and standard output, respectively, of a [tracessh.Session]. +type sessionConn struct { + io.Reader + session *tracessh.Session + localAddr net.Addr + remoteAddr net.Addr + + mu sync.Mutex + w io.WriteCloser +} + +// newSessionConn creates a [net.Conn] for over the provided [tracessh.Session]. +func newSessionConn(session *tracessh.Session, local, remote string) (*sessionConn, error) { + sessionW, err := session.StdinPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + sessionR, err := session.StdoutPipe() + if err != nil { + return nil, trace.Wrap(err) + } + + return &sessionConn{ + session: session, + Reader: sessionR, + w: sessionW, + localAddr: addr(local), + remoteAddr: addr(remote), + }, nil +} + +func (s *sessionConn) Write(b []byte) (n int, err error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.w.Write(b) +} + +func (s *sessionConn) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + return trace.NewAggregate(s.w.Close(), s.session.Close()) +} + +func (s *sessionConn) LocalAddr() net.Addr { + return s.localAddr +} + +func (s *sessionConn) RemoteAddr() net.Addr { + return s.remoteAddr +} + +func (s *sessionConn) SetDeadline(time.Time) error { + return nil +} + +func (s *sessionConn) SetReadDeadline(time.Time) error { + return nil +} + +func (s *sessionConn) SetWriteDeadline(time.Time) error { + return nil +} diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index a4ddebd70b797..7612f86a97252 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -37,7 +37,6 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/net/http/httpproxy" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" @@ -93,7 +92,7 @@ func newWebClient(cfg *Config) (*http.Client, error) { return nil, trace.Wrap(err) } - rt := proxy.NewHTTPRoundTripper(&http.Transport{ + rt := utils.NewHTTPRoundTripper(&http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: cfg.Insecure, RootCAs: cfg.Pool, diff --git a/api/client/proxy/proxy.go b/api/utils/proxy.go similarity index 99% rename from api/client/proxy/proxy.go rename to api/utils/proxy.go index b7f85b682274c..2f25b5ad5f906 100644 --- a/api/client/proxy/proxy.go +++ b/api/utils/proxy.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package proxy +package utils import ( "net/http" diff --git a/api/client/proxy/proxy_test.go b/api/utils/proxy_test.go similarity index 99% rename from api/client/proxy/proxy_test.go rename to api/utils/proxy_test.go index ab87e875cb14c..6cacc9f7b0ffa 100644 --- a/api/client/proxy/proxy_test.go +++ b/api/utils/proxy_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package proxy +package utils import ( "crypto/tls" diff --git a/lib/client/https_client.go b/lib/client/https_client.go index 795186aceab62..c209868a679d0 100644 --- a/lib/client/https_client.go +++ b/lib/client/https_client.go @@ -28,7 +28,6 @@ import ( "golang.org/x/net/http/httpproxy" "github.com/gravitational/teleport" - apiproxy "github.com/gravitational/teleport/api/client/proxy" tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/httplib" @@ -41,7 +40,7 @@ func NewInsecureWebClient() *http.Client { func newClient(insecure bool, pool *x509.CertPool, extraHeaders map[string]string) *http.Client { return &http.Client{ - Transport: tracehttp.NewTransport(apiproxy.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders)), + Transport: tracehttp.NewTransport(apiutils.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders)), } } diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index e2c0900a86046..8cbc75b393f38 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -29,9 +29,9 @@ import ( "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" - apiproxy "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/utils" ) @@ -288,7 +288,7 @@ func WithInsecureSkipTLSVerify(insecure bool) DialerOptionFunc { // server directly. func DialerFromEnvironment(addr string, opts ...DialerOptionFunc) Dialer { // Try and get proxy addr from the environment. - proxyURL := apiproxy.GetProxyURL(addr) + proxyURL := apiutils.GetProxyURL(addr) var options dialerOptions for _, opt := range opts { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 91cefcbd0bb06..a08b9463c77a9 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -37,11 +37,11 @@ import ( "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/db/dbcmd" @@ -233,7 +233,7 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy // if sp.tlsRouting is true, remoteProxyAddr is the ALPN listener port. // if it is false, then remoteProxyAddr is the SSH proxy port. remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) - httpsProxy := proxy.GetProxyURL(remoteProxyAddr) + httpsProxy := apiutils.GetProxyURL(remoteProxyAddr) // If HTTPS_PROXY is configured, we need to open a TCP connection via // the specified HTTPS Proxy, otherwise, we can just open a plain TCP diff --git a/tool/tsh/resolve_default_addr.go b/tool/tsh/resolve_default_addr.go index 82ebfea53de32..edc93a06a8aae 100644 --- a/tool/tsh/resolve_default_addr.go +++ b/tool/tsh/resolve_default_addr.go @@ -30,8 +30,8 @@ import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/api/client/proxy" tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" + "github.com/gravitational/teleport/api/utils" ) type raceResult struct { @@ -121,7 +121,7 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in InsecureSkipVerify: insecure, }, Proxy: func(req *http.Request) (*url.URL, error) { - return proxy.GetProxyURL(req.URL.String()), nil + return utils.GetProxyURL(req.URL.String()), nil }, }, ),