diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index a8efd7fe880d8..b0299c8fa97bf 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -212,8 +212,10 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { } clt, sshErr := newSSHClient(ctx, &cfg) - if sshErr == nil { - return clt, nil + // Only aggregate errors if there was an issue dialing the grpc server so + // that helpers like trace.IsAccessDenied will still work. + if grpcErr == nil { + return clt, trace.Wrap(sshErr) } return nil, trace.NewAggregate(grpcErr, sshErr) @@ -241,7 +243,7 @@ func (c *clusterName) set(name string) { // clusterCredentials is a [credentials.TransportCredentials] implementation // that obtains the name of the cluster being connected to from the certificate // presented by the server. This allows the client to determine the cluster name when -// connecting via using jump hosts. +// connecting via jump hosts. type clusterCredentials struct { credentials.TransportCredentials clusterName *clusterName diff --git a/api/utils/grpc/stream/stream.go b/api/utils/grpc/stream/stream.go index bd9c40e31c954..54021755d273d 100644 --- a/api/utils/grpc/stream/stream.go +++ b/api/utils/grpc/stream/stream.go @@ -113,7 +113,7 @@ func (c *ReadWriter) Write(b []byte) (int, error) { } if err := c.source.Send(chunk); err != nil { - return sent, trace.ConnectionProblem(trail.FromGRPC(err), "failed to send on source") + return sent, trace.ConnectionProblem(trail.FromGRPC(err), "failed to send on source: %v", err) } sent += len(chunk) diff --git a/integration/integration_test.go b/integration/integration_test.go index b33ab0d56bb16..06a2330311130 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -212,25 +212,30 @@ func testDifferentPinnedIP(t *testing.T, suite *integrationTestSuite) { site := teleInstance.GetSiteAPI(helpers.Site) require.NotNil(t, site) + accessDenied := func(t require.TestingT, err error, i 
...interface{}) { + require.Error(t, err, i...) + require.True(t, trace.IsAccessDenied(err), "expected an access denied error, got: %v", err) + } + testCases := []struct { - desc string - ip string - wantErr string + desc string + ip string + errAssertion require.ErrorAssertionFunc }{ { - desc: "Correct connecting IP", - ip: "127.0.0.1", - wantErr: "", + desc: "Correct connecting IP", + ip: "127.0.0.1", + errAssertion: require.NoError, }, { - desc: "Wrong connecting IPv4", - ip: "1.2.3.4", - wantErr: "ssh: unable to authenticate", + desc: "Wrong connecting IPv4", + ip: "1.2.3.4", + errAssertion: accessDenied, }, { - desc: "Wrong connecting IPv6", - ip: "1843:4545::12", - wantErr: "ssh: unable to authenticate", + desc: "Wrong connecting IPv6", + ip: "1843:4545::12", + errAssertion: accessDenied, }, } @@ -243,15 +248,9 @@ func testDifferentPinnedIP(t *testing.T, suite *integrationTestSuite) { SourceIP: test.ip, }) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err = cl.SSH(ctx, []string{"echo hi"}, false) - if test.wantErr != "" { - require.Error(t, err) - require.Contains(t, err.Error(), "ssh: unable to authenticate") - } else { - require.NoError(t, err) - } + test.errAssertion(t, cl.SSH(ctx, []string{"echo hi"}, false)) }) } } @@ -1279,6 +1278,29 @@ func testEscapeSequenceNoTrigger(t *testing.T, terminal *Terminal, sess <-chan e } } +type localAddr struct { + mu sync.Mutex + addr net.Addr +} + +func (a *localAddr) set(addr net.Addr) { + a.mu.Lock() + defer a.mu.Unlock() + + a.addr = addr +} + +func (a *localAddr) get() net.Addr { + a.mu.Lock() + defer a.mu.Unlock() + + if a.addr == nil { + return &utils.NetAddr{} + } + + return a.addr +} + // testIPPropagation makes sure that we can correctly propagate initial client IP observed by proxy. 
func testIPPropagation(t *testing.T, suite *integrationTestSuite) { tr := utils.NewTracer(utils.ThisFunction()).Start() @@ -1341,7 +1363,7 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) { wg.Wait() } - testNodeConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName, nodeName string) { + testSSHNodeConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName, nodeName string) { person := NewTerminal(250) ctx := context.Background() @@ -1355,23 +1377,112 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) { tc.Stdout = person tc.Stdin = person - pc, err := tc.ConnectToProxy(ctx) + clt, err := tc.ConnectToProxy(ctx) + require.NoError(t, err) + defer clt.Close() + + nodeClient, err := clt.ConnectToNode( + ctx, + client.NodeDetails{Addr: nodeName, Namespace: tc.Namespace, Cluster: tc.SiteName}, + tc.Config.HostLogin, + sshutils.ClusterDetails{}, + ) + require.NoError(t, err) + defer nodeClient.Close() + + err = nodeClient.RunCommand(ctx, []string{"echo $SSH_CLIENT"}) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return getRemoteAddrString(person.Output(1000)) == clt.Client.LocalAddr().String() + }, time.Millisecond*100, time.Millisecond*10, "client IP:port that node sees doesn't match to real one") + } + + testGRPCNodeConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName, nodeName string) { + person := NewTerminal(250) + ctx := context.Background() + + tc, err := instance.NewClient(helpers.ClientConfig{ + Login: suite.Me.Username, + Cluster: clusterName, + Host: nodeName, + }) require.NoError(t, err) - defer pc.Close() - tc.Config.MockConnectToProxy = func(ctx context.Context) (*client.ProxyClient, error) { - return pc, nil + tc.Stdout = person + tc.Stdin = person + + local := &localAddr{} + + tc.Config.DialOpts = []grpc.DialOption{ + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + d := net.Dialer{Timeout: 
defaults.DefaultIOTimeout} + conn, err := d.DialContext(ctx, "tcp", s) + if err != nil { + return nil, trace.Wrap(err) + } + + local.set(conn.LocalAddr()) + return conn, nil + }), } - err = tc.SSH(ctx, []string{"echo $SSH_CLIENT"}, false) + clt, err := tc.ConnectToCluster(ctx) + require.NoError(t, err) + defer clt.Close() + + nodeClient, err := tc.ConnectToNode( + ctx, + clt, + client.NodeDetails{Addr: nodeName, Namespace: tc.Namespace, Cluster: clt.ClusterName()}, + tc.Config.HostLogin, + ) + require.NoError(t, err) + defer nodeClient.Close() + + err = nodeClient.RunCommand(ctx, []string{"echo $SSH_CLIENT"}) require.NoError(t, err) require.Eventually(t, func() bool { - return pc.Client.LocalAddr().String() == getRemoteAddrString(person.Output(1000)) + return getRemoteAddrString(person.Output(1000)) == local.get().String() }, time.Millisecond*100, time.Millisecond*10, "client IP:port that node sees doesn't match to real one") } - testAuthConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName string) { + testGRPCAuthConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName string) { + ctx := context.Background() + + tc, err := instance.NewClient(helpers.ClientConfig{ + Login: suite.Me.Username, + Cluster: clusterName, + Host: Host, + }) + require.NoError(t, err) + + local := &localAddr{} + + tc.Config.DialOpts = []grpc.DialOption{ + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + d := net.Dialer{Timeout: defaults.DefaultIOTimeout} + conn, err := d.DialContext(ctx, "tcp", s) + if err != nil { + return nil, trace.Wrap(err) + } + + local.set(conn.LocalAddr()) + return conn, nil + }), + } + + clt, err := tc.ConnectToCluster(ctx) + require.NoError(t, err) + defer clt.Close() + + pingResp, err := clt.AuthClient.Ping(ctx) + require.NoError(t, err) + require.Equal(t, local.get().String(), pingResp.RemoteAddr, "client IP:port that auth server sees doesn't match the real one") + } + + 
testSSHAuthConnection := func(t *testing.T, instance *helpers.TeleInstance, clusterName string) { ctx := context.Background() tc, err := instance.NewClient(helpers.ClientConfig{ @@ -1381,17 +1492,17 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) { }) require.NoError(t, err) - pc, err := tc.ConnectToProxy(ctx) + clt, err := tc.ConnectToProxy(ctx) require.NoError(t, err) - defer pc.Close() + defer clt.Close() - site, err := pc.ConnectToCluster(ctx, clusterName) + site, err := clt.ConnectToCluster(ctx, clusterName) require.NoError(t, err) pingResp, err := site.Ping(ctx) require.NoError(t, err) - expected := pc.Client.LocalAddr().String() + expected := clt.Client.LocalAddr().String() require.Equal(t, expected, pingResp.RemoteAddr, "client IP:port that auth server sees doesn't match the real one") } _, root, leaf := createTrustedClusterPair(t, suite, startNodes) @@ -1407,32 +1518,49 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) { testNodeCases := []struct { instance *helpers.TeleInstance clusterName string - nodeNme string + nodeAddr string }{ - {instance: root, clusterName: "root-test", nodeNme: "root-zero"}, - {instance: root, clusterName: "root-test", nodeNme: "root-one"}, - {instance: root, clusterName: "root-test", nodeNme: "root-two"}, - {instance: root, clusterName: "leaf-test", nodeNme: "leaf-zero"}, - {instance: root, clusterName: "leaf-test", nodeNme: "leaf-one"}, - {instance: root, clusterName: "leaf-test", nodeNme: "leaf-two"}, - {instance: leaf, clusterName: "leaf-test", nodeNme: "leaf-zero"}, - {instance: leaf, clusterName: "leaf-test", nodeNme: "leaf-one"}, - {instance: leaf, clusterName: "leaf-test", nodeNme: "leaf-two"}, - } - - for _, test := range testAuthCases { - t.Run(fmt.Sprintf("Auth test source cluster %q -> target cluster %q", - test.instance.Secrets.SiteName, test.clusterName), func(t *testing.T) { - testAuthConnection(t, test.instance, test.clusterName) - }) - } - for _, test := range testNodeCases 
{ - test := test - t.Run(fmt.Sprintf("Node test, node name %q source cluster %q -> target cluster %q", - test.nodeNme, test.instance.Secrets.SiteName, test.clusterName), func(t *testing.T) { - testNodeConnection(t, test.instance, test.clusterName, test.nodeNme) - }) - } + {instance: root, clusterName: "root-test", nodeAddr: "root-zero:0"}, + {instance: root, clusterName: "root-test", nodeAddr: "root-one:0"}, + {instance: root, clusterName: "root-test", nodeAddr: "root-two:0"}, + {instance: root, clusterName: "leaf-test", nodeAddr: "leaf-zero:0"}, + {instance: root, clusterName: "leaf-test", nodeAddr: "leaf-one:0"}, + {instance: root, clusterName: "leaf-test", nodeAddr: "leaf-two:0"}, + {instance: leaf, clusterName: "leaf-test", nodeAddr: "leaf-zero:0"}, + {instance: leaf, clusterName: "leaf-test", nodeAddr: "leaf-one:0"}, + {instance: leaf, clusterName: "leaf-test", nodeAddr: "leaf-two:0"}, + } + + t.Run("Auth Connections", func(t *testing.T) { + for _, test := range testAuthCases { + t.Run(fmt.Sprintf("source cluster=%q target cluster=%q", + test.instance.Secrets.SiteName, test.clusterName), func(t *testing.T) { + t.Run("ssh connection", func(t *testing.T) { + testSSHAuthConnection(t, test.instance, test.clusterName) + }) + + t.Run("grpc connection", func(t *testing.T) { + testGRPCAuthConnection(t, test.instance, test.clusterName) + }) + }) + } + }) + + t.Run("Host Connections", func(t *testing.T) { + for _, test := range testNodeCases { + test := test + t.Run(fmt.Sprintf("target=%q source cluster=%q target cluster=%q", + test.nodeAddr, test.instance.Secrets.SiteName, test.clusterName), func(t *testing.T) { + t.Run("ssh connection", func(t *testing.T) { + testSSHNodeConnection(t, test.instance, test.clusterName, test.nodeAddr) + }) + + t.Run("grpc connection", func(t *testing.T) { + testGRPCNodeConnection(t, test.instance, test.clusterName, test.nodeAddr) + }) + }) + } + }) } // verifySessionJoin covers SSH into shell and joining the same session from another 
client @@ -1670,6 +1798,7 @@ func errorContains(text string) errorVerifier { } type disconnectTestCase struct { + name string recordingMode string options types.RoleOptions disconnectTimeout time.Duration @@ -1692,6 +1821,7 @@ func testDisconnectScenarios(t *testing.T, suite *integrationTestSuite) { testCases := []disconnectTestCase{ { + name: "client idle timeout node recording", recordingMode: types.RecordAtNode, options: types.RoleOptions{ ClientIdleTimeout: types.NewDuration(500 * time.Millisecond), @@ -1699,6 +1829,7 @@ disconnectTimeout: time.Second, }, { + name: "client idle timeout proxy recording", recordingMode: types.RecordAtProxy, options: types.RoleOptions{ ForwardAgent: types.NewBool(true), @@ -1707,6 +1838,7 @@ disconnectTimeout: time.Second, }, { + name: "expired cert node recording", recordingMode: types.RecordAtNode, options: types.RoleOptions{ DisconnectExpiredCert: types.NewBool(true), @@ -1715,6 +1847,7 @@ disconnectTimeout: 4 * time.Second, }, { + name: "expired cert proxy recording", recordingMode: types.RecordAtProxy, options: types.RoleOptions{ ForwardAgent: types.NewBool(true), @@ -1724,7 +1857,7 @@ disconnectTimeout: 4 * time.Second, }, { - // "verify that concurrent connection limits are applied when recording at node", + name: "concurrent connection limits exceeded node recording", recordingMode: types.RecordAtNode, options: types.RoleOptions{ MaxConnections: 1, @@ -1734,7 +1867,7 @@ verifyError: errorContains("administratively prohibited"), }, { - // "verify that concurrent connection limits 
exceeded proxy recording", recordingMode: types.RecordAtProxy, options: types.RoleOptions{ ForwardAgent: types.NewBool(true), @@ -1745,7 +1878,7 @@ func testDisconnectScenarios(t *testing.T, suite *integrationTestSuite) { verifyError: errorContains("administratively prohibited"), }, { - // "verify that lost connections to auth server terminate controlled conns", + name: "verify that lost connections to auth server terminate controlled connections", recordingMode: types.RecordAtNode, options: types.RoleOptions{ MaxConnections: 1, @@ -1774,7 +1907,7 @@ func testDisconnectScenarios(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) require.Len(t, sems, 1) - timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() ss, err := waitForSessionToBeEstablished(timeoutCtx, defaults.Namespace, site) @@ -1785,8 +1918,8 @@ func testDisconnectScenarios(t *testing.T, suite *integrationTestSuite) { }, } - for i, tc := range testCases { - t.Run(fmt.Sprintf("Test %d", i), func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { runDisconnectTest(t, suite, tc) }) } @@ -2005,8 +2138,8 @@ func testInvalidLogins(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) err = tc.SSH(context.Background(), cmd, false) - require.True(t, trace.IsConnectionProblem(err)) - require.Contains(t, err.Error(), `unknown cluster "wrong-site"`) + require.True(t, trace.IsNotFound(err)) + require.Contains(t, err.Error(), `cluster "wrong-site" is not found`) } // TestTwoClustersTunnel creates two teleport clusters: "a" and "b" and creates a @@ -7240,13 +7373,6 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) { require.NoError(t, teleInst.StopAll()) }) - tc, err := teleInst.NewClient(helpers.ClientConfig{ - Login: suite.Me.Username, - Cluster: helpers.Site, - Host: Host, - }) - require.NoError(t, err) - // get OpenSSH CA public key and 
create host certs ctx := context.Background() authClient := teleInst.Process.GetAuthServer() @@ -7325,18 +7451,32 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) { } require.NoError(t, w.Close()) - // connect to node - proxyClient, err := tc.ConnectToProxy(ctx) + // create client + tc, err := teleInst.NewClient(helpers.ClientConfig{ + Login: suite.Me.Username, + Cluster: helpers.Site, + Host: Host, + }) + require.NoError(t, err) + + // connect to cluster + clt, err := tc.ConnectToCluster(ctx) require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, proxyClient.Close()) + require.NoError(t, clt.Close()) }) - nodeClient, err := tc.ConnectToNode(ctx, proxyClient, client.NodeDetails{ - Addr: sshAddr, - Namespace: tc.Namespace, - Cluster: helpers.Site, - }, tc.Username) + // connect to node + nodeClient, err := tc.ConnectToNode( + ctx, + clt, + client.NodeDetails{ + Addr: sshAddr, + Namespace: tc.Namespace, + Cluster: helpers.Site, + }, + tc.Username, + ) require.NoError(t, err) // forward SSH agent diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index d36326729cdd6..442db7f33e89c 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -1053,12 +1053,14 @@ loop: func testKubeDisconnect(t *testing.T, suite *KubeSuite) { testCases := []disconnectTestCase{ { + name: "idle timeout", options: types.RoleOptions{ ClientIdleTimeout: types.NewDuration(500 * time.Millisecond), }, disconnectTimeout: 2 * time.Second, }, { + name: "expired cert", options: types.RoleOptions{ DisconnectExpiredCert: types.NewBool(true), MaxSessionTTL: types.NewDuration(3 * time.Second), @@ -1066,12 +1068,15 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) { disconnectTimeout: 6 * time.Second, }, } + for i := 0; i < utils.GetIterations(); i++ { - for j, tc := range testCases { - t.Run(fmt.Sprintf("#%02d_iter_%d", j, i), func(t *testing.T) { - runKubeDisconnectTest(t, suite, tc) - 
}) - } + t.Run(fmt.Sprintf("Iteration=%d", i), func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + runKubeDisconnectTest(t, suite, tc) + }) + } + }) } } diff --git a/lib/agentless/agentless.go b/lib/agentless/agentless.go index baf099121caf8..cb2037027823c 100644 --- a/lib/agentless/agentless.go +++ b/lib/agentless/agentless.go @@ -27,18 +27,48 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/utils/sshutils" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/utils" ) -// SiteClientGetter returns an auth client to a given cluster. -type SiteClientGetter interface { - GetSiteClient(ctx context.Context, clusterName string) (auth.ClientI, error) +// CertGenerator generates certificates from a certificate request. +type CertGenerator interface { + GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCertRequest) (*proto.OpenSSHCert, error) } -// CreateAuthSigner attempts to create a [ssh.Signer] that is signed with +// SignerFromSSHCertificate returns a function that attempts to +// create a [ssh.Signer] for the Identity in the provided [ssh.Certificate] +// that is signed with the OpenSSH CA and can be used to authenticate to agentless nodes. 
+func SignerFromSSHCertificate(certificate *ssh.Certificate, generator CertGenerator) func(context.Context) (ssh.Signer, error) { + return func(ctx context.Context) (ssh.Signer, error) { + validBefore := time.Unix(int64(certificate.ValidBefore), 0) + ttl := time.Until(validBefore) + + clusterName := certificate.Permissions.Extensions[utils.CertExtensionAuthority] + user := certificate.Permissions.Extensions[utils.CertTeleportUser] + + signer, err := createAuthSigner(ctx, user, clusterName, ttl, generator) + return signer, trace.Wrap(err) + } +} + +// SignerFromAuthzContext returns a function that attempts to +// create a [ssh.Signer] for the [tlsca.Identity] in the provided [authz.Context] +// that is signed with the OpenSSH CA and can be used to authenticate to agentless nodes. +func SignerFromAuthzContext(authzCtx *authz.Context, generator CertGenerator) func(context.Context) (ssh.Signer, error) { + return func(ctx context.Context) (ssh.Signer, error) { + identity := authzCtx.Identity.GetIdentity() + ttl := time.Until(identity.Expires) + + signer, err := createAuthSigner(ctx, authzCtx.User.GetName(), identity.TeleportCluster, ttl, generator) + return signer, trace.Wrap(err) + } +} + +// createAuthSigner creates a [ssh.Signer] that is signed with // OpenSSH CA and can be used to authenticate to agentless nodes. 
-func CreateAuthSigner(ctx context.Context, username, clusterName string, ttl time.Duration, clientGetter SiteClientGetter) (ssh.Signer, error) { +func createAuthSigner(ctx context.Context, username, clusterName string, ttl time.Duration, generator CertGenerator) (ssh.Signer, error) { // generate a new key pair priv, err := native.GeneratePrivateKey() if err != nil { @@ -46,11 +76,7 @@ func CreateAuthSigner(ctx context.Context, username, clusterName string, ttl tim } // sign new public key with OpenSSH CA - client, err := clientGetter.GetSiteClient(ctx, clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - reply, err := client.GenerateOpenSSHCert(ctx, &proto.OpenSSHCertRequest{ + reply, err := generator.GenerateOpenSSHCert(ctx, &proto.OpenSSHCertRequest{ Username: username, PublicKey: priv.MarshalSSHPublicKey(), TTL: proto.Duration(ttl), diff --git a/lib/client/api.go b/lib/client/api.go index 8ec87bfcab4b5..999b98b467fb5 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -45,10 +45,12 @@ import ( "golang.org/x/crypto/ssh/agent" "golang.org/x/net/http2" "golang.org/x/sync/errgroup" + "google.golang.org/grpc" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + proxyclient "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" @@ -365,9 +367,6 @@ type Config struct { // MockSSOLogin is used in tests for mocking the SSO login response. MockSSOLogin SSOLoginFunc - // MockConnectToProxy is used in tests to override connection to proxy - MockConnectToProxy ConnectToProxyFunc - // HomePath is where tsh stores profiles HomePath string @@ -415,6 +414,10 @@ type Config struct { // the ssh, scp, and ls commands can use headless login. 
Other commands will ignore // headless auth connector and default to local instead. AllowHeadless bool + + // DialOpts used by the api.client.proxy.Client when establishing a connection to + // the proxy server. Used by tests. + DialOpts []grpc.DialOption } // CachePolicy defines cache policy for local clients @@ -1181,7 +1184,7 @@ func (tc *TeleportClient) RootClusterName(ctx context.Context) (string, error) { // getTargetNodes returns a list of node addresses this SSH command needs to // operate on. -func (tc *TeleportClient) getTargetNodes(ctx context.Context, proxy *ProxyClient) ([]string, error) { +func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListResourcesClient) ([]string, error) { ctx, span := tc.Tracer.Start( ctx, "teleportClient/getTargetNodes", @@ -1203,15 +1206,17 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, proxy *ProxyClient } // find the nodes matching the labels that were provided - nodes, err := proxy.FindNodesByFilters(ctx, *tc.DefaultResourceFilter()) + filter := tc.DefaultResourceFilter() + filter.ResourceType = types.KindNode + resources, err := client.GetResourcesWithFilters(ctx, clt, *filter) if err != nil { return nil, trace.Wrap(err) } - retval := make([]string, 0, len(nodes)) - for i := 0; i < len(nodes); i++ { + retval := make([]string, 0, len(resources)) + for _, resource := range resources { // always dial nodes by UUID - retval = append(retval, fmt.Sprintf("%s:0", nodes[i].GetName())) + retval = append(retval, fmt.Sprintf("%s:0", resource.GetName())) } return retval, nil @@ -1427,7 +1432,8 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally "teleportClient/SSH", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("proxy", tc.Config.WebProxyAddr), + attribute.String("proxy_web", tc.Config.WebProxyAddr), + attribute.String("proxy_ssh", tc.Config.SSHProxyAddr), ), ) defer span.End() @@ -1436,14 +1442,15 @@ func (tc 
*TeleportClient) SSH(ctx context.Context, command []string, runLocally if !tc.Config.ProxySpecified() { return trace.BadParameter("proxy server is not specified") } - proxyClient, err := tc.ConnectToProxy(ctx) + + clt, err := tc.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } - defer proxyClient.Close() + defer clt.Close() // which nodes are we executing this commands on? - nodeAddrs, err := tc.getTargetNodes(ctx, proxyClient) + nodeAddrs, err := tc.getTargetNodes(ctx, clt.AuthClient) if err != nil { return trace.Wrap(err) } @@ -1452,9 +1459,9 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally } if len(nodeAddrs) > 1 { - return tc.runShellOrCommandOnMultipleNodes(ctx, nodeAddrs, proxyClient, command) + return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command) } - return tc.runShellOrCommandOnSingleNode(ctx, nodeAddrs[0], proxyClient, command, runLocally) + return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], command, runLocally) } // ConnectToNode attempts to establish a connection to the node resolved to by the provided @@ -1463,41 +1470,39 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally // ceremony is performed and another connection is attempted with the freshly minted // certificates. If it is not required, then the original Access Denied error from the node // is returned. 
-func (tc *TeleportClient) ConnectToNode(ctx context.Context, proxyClient *ProxyClient, nodeDetails NodeDetails, user string) (*NodeClient, error) { +func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) { node := nodeName(nodeDetails.Addr) ctx, span := tc.Tracer.Start( ctx, "teleportClient/ConnectToNode", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("site", nodeDetails.Cluster), + attribute.String("cluster", nodeDetails.Cluster), attribute.String("node", node), ), ) defer span.End() - // attempt to use the existing credentials first - authMethods := proxyClient.authMethods + sshConfig := clt.ProxyClient.SSHConfig(user) - // if per-session mfa is required, perform the mfa ceremony to get - // new certificates and use them instead + // if mfa is required generate new config after + // performing the mfa ceremony if nodeDetails.MFACheck != nil && nodeDetails.MFACheck.Required { - am, err := proxyClient.sessionSSHCertificate(ctx, nodeDetails) + cfg, err := clt.SessionSSHConfig(ctx, user, nodeDetails) if err != nil { return nil, trace.Wrap(err) } - authMethods = am + sshConfig = cfg } - // grab the cluster details - details, err := proxyClient.clusterDetails(ctx) + // try connecting to the node + conn, details, err := clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent) if err != nil { return nil, trace.Wrap(err) } - // try connecting to the node - nodeClient, connectErr := proxyClient.ConnectToNode(ctx, nodeDetails, user, details, authMethods) + nodeClient, connectErr := NewNodeClient(ctx, sshConfig, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS) switch { case connectErr == nil: // no error return client return nodeClient, nil @@ -1508,73 +1513,57 @@ func (tc *TeleportClient) ConnectToNode(ctx context.Context, proxyClient *ProxyC } // access was denied, determine if it 
was because per-session mfa is required - clt, err := proxyClient.ConnectToCluster(ctx, nodeDetails.Cluster) - if err != nil { - // return the connection error instead of any errors from connecting to auth - return nil, trace.Wrap(connectErr) - } - - check, err := clt.IsMFARequired(ctx, &proto.IsMFARequiredRequest{ + nodeDetails.MFACheck, err = clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{ Target: &proto.IsMFARequiredRequest_Node{ Node: &proto.NodeLogin{ Node: node, - Login: proxyClient.hostLogin, + Login: tc.HostLogin, }, }, }) if err != nil { + log.Warnf("Unable to determine if session mfa is required: %v", err) return nil, trace.Wrap(connectErr) } // per-session mfa isn't required, the user simply does not // have access to the provided node - if !check.Required { + if !nodeDetails.MFACheck.Required { return nil, trace.Wrap(connectErr) } - // per-session mfa is required, perform the mfa ceremony - key, err := proxyClient.IssueUserCertsWithMFA( - ctx, - ReissueParams{ - NodeName: node, - RouteToCluster: nodeDetails.Cluster, - MFACheck: check, - AuthClient: clt, - }, - func(ctx context.Context, proxyAddr string, c *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return tc.PromptMFAChallenge(ctx, proxyAddr, c, nil /* applyOpts */) - }, - ) + // generate new config after performing the mfa ceremony + cfg, err := clt.SessionSSHConfig(ctx, user, nodeDetails) if err != nil { return nil, trace.Wrap(err) } - // try connecting to the node again with the newly acquired certificates - newAuthMethods, err := key.AsAuthMethod() + conn, details, err = clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent) if err != nil { return nil, trace.Wrap(err) } - nodeClient, err = proxyClient.ConnectToNode(ctx, nodeDetails, user, details, []ssh.AuthMethod{newAuthMethods}) + nodeClient, err = NewNodeClient(ctx, cfg, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS) return nodeClient, 
trace.Wrap(err) } -func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, nodeAddr string, proxyClient *ProxyClient, command []string, runLocally bool) error { +func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, runLocally bool) error { + cluster := clt.ClusterName() ctx, span := tc.Tracer.Start( ctx, "teleportClient/runShellOrCommandOnSingleNode", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("site", tc.SiteName), attribute.String("node", nodeAddr), + attribute.String("cluster", cluster), ), ) defer span.End() nodeClient, err := tc.ConnectToNode( ctx, - proxyClient, - NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: tc.SiteName}, + clt, + NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: cluster}, tc.Config.HostLogin, ) if err != nil { @@ -1622,18 +1611,19 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, nod if len(command) > 0 { // Reuse the existing nodeClient we connected above. 
- return tc.runCommand(ctx, nodeClient, command) + return nodeClient.RunCommand(ctx, command) } return tc.runShell(ctx, nodeClient, types.SessionPeerMode, nil, nil) } -func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, nodeAddrs []string, proxyClient *ProxyClient, command []string) error { +func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodeAddrs []string, command []string) error { + cluster := clt.ClusterName() ctx, span := tc.Tracer.Start( ctx, "teleportClient/runShellOrCommandOnMultipleNodes", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("site", tc.SiteName), + attribute.String("cluster", cluster), attribute.StringSlice("node", nodeAddrs), ), ) @@ -1642,15 +1632,15 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, // There was a command provided, run a non-interactive session against each match if len(command) > 0 { fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes matched label selector, running command on all.\n") - return tc.runCommandOnNodes(ctx, tc.SiteName, nodeAddrs, proxyClient, command) + return tc.runCommandOnNodes(ctx, clt, nodeAddrs, command) } // Issue "shell" request to the first matching node. 
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0]) nodeClient, err := tc.ConnectToNode( ctx, - proxyClient, - NodeDetails{Addr: nodeAddrs[0], Namespace: tc.Namespace, Cluster: tc.SiteName}, + clt, + NodeDetails{Addr: nodeAddrs[0], Namespace: tc.Namespace, Cluster: cluster}, tc.Config.HostLogin, ) if err != nil { @@ -1706,15 +1696,14 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan if !tc.Config.ProxySpecified() { return trace.BadParameter("proxy server is not specified") } - proxyClient, err := tc.ConnectToProxy(ctx) + clt, err := tc.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } - defer proxyClient.Close() - site := proxyClient.CurrentCluster() + defer clt.Close() // Session joining is not supported in proxy recording mode - if recConfig, err := site.GetSessionRecordingConfig(ctx); err != nil { + if recConfig, err := clt.AuthClient.GetSessionRecordingConfig(ctx); err != nil { // If the user can't see the recording mode, just let them try joining below if !trace.IsAccessDenied(err) { return trace.Wrap(err) @@ -1723,7 +1712,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan return trace.BadParameter("session joining is not supported in proxy recording mode") } - session, err := site.GetSessionTracker(ctx, string(sessionID)) + session, err := clt.AuthClient.GetSessionTracker(ctx, string(sessionID)) if err != nil { if trace.IsNotFound(err) { return trace.NotFound("session %q not found or it has ended", sessionID) @@ -1737,8 +1726,8 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan // connect to server: nc, err := tc.ConnectToNode(ctx, - proxyClient, - NodeDetails{Addr: session.GetAddress() + ":0", Namespace: tc.Namespace, Cluster: tc.SiteName}, + clt, + NodeDetails{Addr: session.GetAddress() + ":0", Namespace: tc.Namespace, Cluster: clt.ClusterName()}, tc.Config.HostLogin, ) if err != nil { 
@@ -1758,7 +1747,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan if mode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { nc.OnMFA = func() { - runPresenceTask(presenceCtx, out, site, tc, session.GetSessionID()) + runPresenceTask(presenceCtx, out, clt.AuthClient, tc, session.GetSessionID()) } } } @@ -1907,17 +1896,17 @@ func (tc *TeleportClient) ExecuteSCP(ctx context.Context, serverAddr string, cmd return trace.BadParameter("proxy server is not specified") } - proxyClient, err := tc.ConnectToProxy(ctx) + clt, err := tc.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } - defer proxyClient.Close() + defer clt.Close() nodeClient, err := tc.ConnectToNode( ctx, - proxyClient, + clt, // We append the ":0" to tell the server to figure out the port for us. - NodeDetails{Addr: serverAddr + ":0", Namespace: tc.Namespace, Cluster: tc.SiteName}, + NodeDetails{Addr: serverAddr + ":0", Namespace: tc.Namespace, Cluster: clt.ClusterName()}, tc.Config.HostLogin, ) if err != nil { @@ -2063,16 +2052,16 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr if !tc.Config.ProxySpecified() { return trace.BadParameter("proxy server is not specified") } - proxyClient, err := tc.ConnectToProxy(ctx) + clt, err := tc.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } - defer proxyClient.Close() + defer clt.Close() client, err := tc.ConnectToNode( ctx, - proxyClient, - NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: tc.SiteName}, + clt, + NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: clt.ClusterName()}, hostLogin, ) if err != nil { @@ -2102,7 +2091,9 @@ func (tc *TeleportClient) ListNodesWithFilters(ctx context.Context) ([]types.Ser } defer proxyClient.Close() - servers, err := proxyClient.FindNodesByFilters(ctx, *tc.DefaultResourceFilter()) + filter := tc.DefaultResourceFilter() + filter.ResourceType = types.KindNode + servers, err := 
proxyClient.FindNodesByFilters(ctx, *filter) if err != nil { return nil, trace.Wrap(err) } @@ -2491,28 +2482,26 @@ func commandLimit(ctx context.Context, getter roleGetter, mfaRequired bool) int } // runCommandOnNodes executes a given bash command on a bunch of remote nodes. -func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, siteName string, nodeAddresses []string, proxyClient *ProxyClient, command []string) error { +func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterClient, nodeAddresses []string, command []string) error { + cluster := clt.ClusterName() ctx, span := tc.Tracer.Start( ctx, "teleportClient/runCommandOnNodes", oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("cluster", cluster), + ), ) defer span.End() - clt, err := proxyClient.ConnectToCluster(ctx, siteName) - if err != nil { - return trace.Wrap(err) - } - defer clt.Close() - // Let's check if the first node requires mfa. // If it's required, run commands sequentially to avoid // race conditions and weird ux during mfa. 
- mfaRequiredCheck, err := clt.IsMFARequired(ctx, &proto.IsMFARequiredRequest{ + mfaRequiredCheck, err := clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{ Target: &proto.IsMFARequiredRequest_Node{ Node: &proto.NodeLogin{ Node: nodeName(nodeAddresses[0]), - Login: proxyClient.hostLogin, + Login: tc.Config.HostLogin, }, }, }) @@ -2521,7 +2510,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, siteName string } g, gctx := errgroup.WithContext(ctx) - g.SetLimit(commandLimit(ctx, clt, mfaRequiredCheck.Required)) + g.SetLimit(commandLimit(ctx, clt.AuthClient, mfaRequiredCheck.Required)) for _, address := range nodeAddresses { address := address g.Go(func() error { @@ -2535,11 +2524,11 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, siteName string nodeClient, err := tc.ConnectToNode( ctx, - proxyClient, + clt, NodeDetails{ Addr: address, Namespace: tc.Namespace, - Cluster: siteName, + Cluster: cluster, MFACheck: mfaRequiredCheck, }, tc.Config.HostLogin, @@ -2552,47 +2541,13 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, siteName string fmt.Printf("Running command on %v:\n", nodeName(address)) - return trace.Wrap(tc.runCommand(ctx, nodeClient, command)) + return trace.Wrap(nodeClient.RunCommand(ctx, command)) }) } return trace.Wrap(g.Wait()) } -// runCommand executes a given bash command on an established NodeClient. 
-func (tc *TeleportClient) runCommand(ctx context.Context, nodeClient *NodeClient, command []string) error { - ctx, span := tc.Tracer.Start( - ctx, - "teleportClient/runCommand", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - ) - defer span.End() - - nodeSession, err := newSession(ctx, nodeClient, nil, tc.newSessionEnv(), tc.Stdin, tc.Stdout, tc.Stderr, tc.EnableEscapeSequences) - if err != nil { - return trace.Wrap(err) - } - defer nodeSession.Close() - if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, tc.OnShellCreated, tc.Config.InteractiveCommand); err != nil { - originErr := trace.Unwrap(err) - exitErr, ok := originErr.(*ssh.ExitError) - if ok { - tc.ExitStatus = exitErr.ExitStatus() - } else { - // if an error occurs, but no exit status is passed back, GoSSH returns - // a generic error like this. in this case the error message is printed - // to stderr by the remote process so we have to quietly return 1: - if strings.Contains(originErr.Error(), "exited without exit status") { - tc.ExitStatus = 1 - } - } - - return trace.Wrap(err) - } - - return nil -} - func (tc *TeleportClient) newSessionEnv() map[string]string { env := map[string]string{ teleport.SSHSessionWebproxyAddr: tc.WebProxyAddr, @@ -2702,52 +2657,72 @@ func formatConnectToProxyErr(err error) error { return err } -// ConnectToProxyFunc is used in tests to override connection to proxy function. -type ConnectToProxyFunc func(ctx context.Context) (*ProxyClient, error) - -// ConnectToProxy will dial to the proxy server and return a ProxyClient when -// successful. If the passed in context is canceled, this function will return -// a trace.ConnectionProblem right away. -func (tc *TeleportClient) ConnectToProxy(ctx context.Context) (*ProxyClient, error) { - if tc.Config.MockConnectToProxy != nil { - return tc.Config.MockConnectToProxy(ctx) - } - +// ConnectToCluster will dial the auth and proxy server and return a ClusterClient when +// successful. 
+func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, error) { ctx, span := tc.Tracer.Start( ctx, - "teleportClient/ConnectToProxy", + "teleportClient/ConnectToCluster", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("proxy", tc.Config.WebProxyAddr), + attribute.String("proxy_web", tc.Config.WebProxyAddr), + attribute.String("proxy_ssh", tc.Config.SSHProxyAddr), ), ) defer span.End() - var err error - var proxyClient *ProxyClient + cfg, err := tc.generateClientConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } - // Use connectContext and the cancel function to signal when a response is - // returned from connectToProxy. - connectContext, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - proxyClient, err = tc.connectToProxy(ctx) - }() + tlsConfig, err := tc.LoadTLSConfig() + if err != nil { + return nil, trace.Wrap(err) + } - select { - // ConnectToProxy returned a result, return that back to the caller. - case <-connectContext.Done(): - return proxyClient, trace.Wrap(formatConnectToProxyErr(err)) - // The passed in context timed out. This is often due to the network being - // down and the user hitting Ctrl-C. 
- case <-ctx.Done(): - return nil, trace.ConnectionProblem(ctx.Err(), "connection canceled") + pclt, err := proxyclient.NewClient(ctx, proxyclient.ClientConfig{ + ProxyWebAddress: tc.WebProxyAddr, + ProxySSHAddress: cfg.proxyAddress, + TLSRoutingEnabled: tc.TLSRoutingEnabled, + TLSConfig: tlsConfig, + DialOpts: tc.Config.DialOpts, + UnaryInterceptors: []grpc.UnaryClientInterceptor{utils.GRPCClientUnaryErrorInterceptor}, + StreamInterceptors: []grpc.StreamClientInterceptor{utils.GRPCClientStreamErrorInterceptor}, + SSHDialer: proxyclient.SSHDialerFunc(func(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { + clt, err := makeProxySSHClient(ctx, tc, config) + return clt, trace.Wrap(err) + }), + SSHConfig: cfg.ClientConfig, + }) + if err != nil { + return nil, trace.Wrap(err) } + + aclt, err := auth.NewClient(pclt.ClientConfig(ctx, pclt.ClusterName())) + if err != nil { + return nil, trace.NewAggregate(err, pclt.Close()) + } + + return &ClusterClient{ + tc: tc, + ProxyClient: pclt, + AuthClient: aclt, + Tracer: tc.Tracer, + }, nil } -// connectToProxy will dial to the proxy server and return a ProxyClient when -// successful. -func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, error) { +// clientConfig wraps ssh.ClientConfig with additional +// information about a cluster. +type clientConfig struct { + *ssh.ClientConfig + proxyAddress string + clusterName func() string +} + +// generateClientConfig returns clientConfig that can be used to establish a +// connection to a cluster. 
+func (tc *TeleportClient) generateClientConfig(ctx context.Context) (*clientConfig, error) { sshProxyAddr := tc.Config.SSHProxyAddr hostKeyCallback := tc.HostKeyCallback @@ -2800,13 +2775,64 @@ func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, err return nil, trace.BadParameter("no SSH auth methods loaded, are you logged in?") } - sshConfig := &ssh.ClientConfig{ - User: tc.getProxySSHPrincipal(), - HostKeyCallback: hostKeyCallback, - Auth: authMethods, + return &clientConfig{ + ClientConfig: &ssh.ClientConfig{ + User: tc.getProxySSHPrincipal(), + HostKeyCallback: hostKeyCallback, + Auth: authMethods, + Timeout: apidefaults.DefaultIOTimeout, + }, + proxyAddress: sshProxyAddr, + clusterName: clusterName, + }, nil +} + +// ConnectToProxy will dial to the proxy server and return a ProxyClient when +// successful. If the passed in context is canceled, this function will return +// a trace.ConnectionProblem right away. +func (tc *TeleportClient) ConnectToProxy(ctx context.Context) (*ProxyClient, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/ConnectToProxy", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("proxy_web", tc.Config.WebProxyAddr), + attribute.String("proxy_ssh", tc.Config.SSHProxyAddr), + ), + ) + defer span.End() + + var err error + var proxyClient *ProxyClient + + // Use connectContext and the cancel function to signal when a response is + // returned from connectToProxy. + connectContext, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + proxyClient, err = tc.connectToProxy(connectContext) + }() + + select { + // connectToProxy returned a result, return that back to the caller. + case <-connectContext.Done(): + return proxyClient, trace.Wrap(formatConnectToProxyErr(err)) + // The passed in context timed out. This is often due to the network being + // down and the user hitting Ctrl-C. 
+ case <-ctx.Done(): + return nil, trace.ConnectionProblem(ctx.Err(), "connection canceled") + } +} + +// connectToProxy will dial to the proxy server and return a ProxyClient when +// successful. +func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, error) { + cfg, err := tc.generateClientConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) } - sshClient, err := makeProxySSHClient(ctx, tc, sshConfig) + sshClient, err := makeProxySSHClient(ctx, tc, cfg.ClientConfig) if err != nil { return nil, trace.Wrap(err) } @@ -2814,12 +2840,12 @@ func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, err pc := &ProxyClient{ teleportClient: tc, Client: sshClient, - proxyAddress: sshProxyAddr, - proxyPrincipal: sshConfig.User, - hostKeyCallback: sshConfig.HostKeyCallback, - authMethods: sshConfig.Auth, + proxyAddress: cfg.proxyAddress, + proxyPrincipal: cfg.User, + hostKeyCallback: cfg.HostKeyCallback, + authMethods: cfg.Auth, hostLogin: tc.HostLogin, - siteName: clusterName(), + siteName: cfg.clusterName(), clientAddr: tc.ClientAddr, Tracer: tc.Tracer, } @@ -4503,7 +4529,8 @@ func (tc *TeleportClient) HeadlessApprove(ctx context.Context, headlessAuthentic "teleportClient/HeadlessApprove", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes( - attribute.String("proxy", tc.Config.WebProxyAddr), + attribute.String("proxy_web", tc.Config.WebProxyAddr), + attribute.String("proxy_ssh", tc.Config.SSHProxyAddr), ), ) defer span.End() diff --git a/lib/client/client.go b/lib/client/client.go index 7edf06e797055..cc1e6941858ae 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1258,29 +1258,6 @@ func nodeName(node string) string { return n } -// clusterDetails retrieves information about the current cluster needed to connect to nodes. 
-func (proxy *ProxyClient) clusterDetails(ctx context.Context) (sshutils.ClusterDetails, error) { - ctx, span := proxy.Tracer.Start( - ctx, - "proxyClient/clusterDetails", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - ) - defer span.End() - - var details sshutils.ClusterDetails - ok, resp, err := proxy.Client.SendRequest(ctx, teleport.ClusterDetailsReqType, true, nil) - if err != nil { - return details, trace.Wrap(err) - } - - if ok { - err = ssh.Unmarshal(resp, &details) - return details, trace.Wrap(err) - } - - return details, trace.BadParameter("failed to get cluster details") -} - // dialAuthServer returns auth server connection forwarded via proxy func (proxy *ProxyClient) dialAuthServer(ctx context.Context, clusterName string) (net.Conn, error) { ctx, span := proxy.Tracer.Start( @@ -1401,7 +1378,7 @@ func requestSubsystem(ctx context.Context, session *tracessh.Session, name strin // ConnectToNode connects to the ssh server via Proxy. // It returns connected and authenticated NodeClient -func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDetails, user string, details sshutils.ClusterDetails, authMethods []ssh.AuthMethod) (*NodeClient, error) { +func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDetails, user string, details sshutils.ClusterDetails) (*NodeClient, error) { ctx, span := proxy.Tracer.Start( ctx, "proxyClient/ConnectToNode", @@ -1416,7 +1393,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet log.Infof("Client=%v connecting to node=%v", proxy.clientAddr, nodeAddress) if len(proxy.teleportClient.JumpHosts) > 0 { - return proxy.PortForwardToNode(ctx, nodeAddress, user, details, authMethods) + return proxy.PortForwardToNode(ctx, nodeAddress, user, details, proxy.authMethods) } // parse destination first: @@ -1498,7 +1475,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet sshConfig := &ssh.ClientConfig{ User: user, - Auth: 
authMethods, + Auth: proxy.authMethods, HostKeyCallback: proxy.hostKeyCallback, } @@ -1651,6 +1628,40 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session return nil } +// RunCommand executes a given bash command on the node. +func (c *NodeClient) RunCommand(ctx context.Context, command []string) error { + ctx, span := c.Tracer.Start( + ctx, + "nodeClient/RunCommand", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + nodeSession, err := newSession(ctx, c, nil, c.TC.newSessionEnv(), c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, c.TC.EnableEscapeSequences) + if err != nil { + return trace.Wrap(err) + } + defer nodeSession.Close() + if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil { + originErr := trace.Unwrap(err) + exitErr, ok := originErr.(*ssh.ExitError) + if ok { + c.TC.ExitStatus = exitErr.ExitStatus() + } else { + // if an error occurs, but no exit status is passed back, GoSSH returns + // a generic error like this. in this case the error message is printed + // to stderr by the remote process so we have to quietly return 1: + if strings.Contains(originErr.Error(), "exited without exit status") { + c.TC.ExitStatus = 1 + } + } + + return trace.Wrap(err) + } + + return nil +} + func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.Request) { for { select { @@ -2029,37 +2040,6 @@ func (c *NodeClient) Close() error { return c.Client.Close() } -func (proxy *ProxyClient) sessionSSHCertificate(ctx context.Context, nodeAddr NodeDetails) ([]ssh.AuthMethod, error) { - if _, err := proxy.teleportClient.localAgent.GetKey(nodeAddr.Cluster); err != nil { - if trace.IsNotFound(err) { - // Either running inside the web UI in a proxy or using an identity - // file. Fall back to whatever AuthMethod we currently have. 
- return proxy.authMethods, nil - } - return nil, trace.Wrap(err) - } - - key, err := proxy.IssueUserCertsWithMFA( - ctx, - ReissueParams{ - NodeName: nodeName(nodeAddr.Addr), - RouteToCluster: proxy.ClusterName(), - MFACheck: nodeAddr.MFACheck, - }, - func(ctx context.Context, proxyAddr string, c *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return proxy.teleportClient.PromptMFAChallenge(ctx, proxyAddr, c, nil /* applyOpts */) - }, - ) - if err != nil { - return nil, trace.Wrap(err) - } - am, err := key.AsAuthMethod() - if err != nil { - return nil, trace.Wrap(err) - } - return []ssh.AuthMethod{am}, nil -} - // localAgent returns for the Teleport client's local agent. func (proxy *ProxyClient) localAgent() *LocalKeyAgent { return proxy.teleportClient.LocalAgent() diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go new file mode 100644 index 0000000000000..04d952ae1bb27 --- /dev/null +++ b/lib/client/cluster_client.go @@ -0,0 +1,365 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package client + +import ( + "context" + + "github.com/gravitational/trace" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/client/proto" + proxyclient "github.com/gravitational/teleport/api/client/proxy" + "github.com/gravitational/teleport/lib/auth" +) + +// ClusterClient facilitates communicating with both the +// Auth and Proxy services of a cluster. +type ClusterClient struct { + tc *TeleportClient + ProxyClient *proxyclient.Client + AuthClient auth.ClientI + Tracer oteltrace.Tracer +} + +// ClusterName returns the name of the cluster that the client +// is connected to. +func (c *ClusterClient) ClusterName() string { + cluster := c.ProxyClient.ClusterName() + if len(c.tc.JumpHosts) > 0 && cluster != "" { + return cluster + } + + return c.tc.SiteName +} + +// Close terminates the connections to Auth and Proxy. +func (c *ClusterClient) Close() error { + // close auth client first since it is tunneled through the proxy client + return trace.NewAggregate(c.AuthClient.Close(), c.ProxyClient.Close()) +} + +// SessionSSHConfig returns the [ssh.ClientConfig] that should be used to connected to the +// provided target for the provided user. If per session MFA is required to establish the +// connection, then the MFA ceremony will be performed. +func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, target NodeDetails) (*ssh.ClientConfig, error) { + ctx, span := c.Tracer.Start( + ctx, + "clusterClient/SessionSSHConfig", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("cluster", c.tc.SiteName), + ), + ) + defer span.End() + + sshConfig := c.ProxyClient.SSHConfig(user) + + if target.MFACheck != nil && !target.MFACheck.Required { + return sshConfig, nil + } + + key, err := c.tc.localAgent.GetKey(target.Cluster, WithAllCerts...) 
+ if err != nil { + if trace.IsNotFound(err) { + // Either running inside the web UI in a proxy or using an identity + // file. Fall back to whatever AuthMethod we currently have. + return sshConfig, nil + } + return nil, trace.Wrap(err) + } + + params := ReissueParams{ + NodeName: nodeName(target.Addr), + RouteToCluster: target.Cluster, + MFACheck: target.MFACheck, + } + + // requiredCheck passed from param can be nil. + if target.MFACheck == nil { + check, err := c.AuthClient.IsMFARequired(ctx, params.isMFARequiredRequest(c.tc.HostLogin)) + if err != nil { + return nil, trace.Wrap(err) + } + target.MFACheck = check + } + + if !target.MFACheck.Required { + log.Debug("MFA not required for access.") + // MFA is not required. + // SSH certs can be used without embedding the node name. + if params.usage() == proto.UserCertsRequest_SSH { + return sshConfig, nil + } + + // All other targets need their name embedded in the cert for routing, + // fall back to non-MFA reissue. + key, err := c.reissueUserCerts(ctx, CertCacheKeep, params) + if err != nil { + return nil, trace.Wrap(err) + } + + am, err := key.AsAuthMethod() + if err != nil { + return nil, trace.Wrap(err) + } + + sshConfig.Auth = []ssh.AuthMethod{am} + return sshConfig, nil + } + + // Always connect to root for getting new credentials, but attempt to reuse + // the existing client if possible. + rootClusterName, err := key.RootClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + mfaClt := c + if params.RouteToCluster != rootClusterName { + jumpHosts := c.tc.JumpHosts + // In case of MFA connect to root teleport proxy instead of JumpHost to request + // MFA certificates. 
+ c.tc.JumpHosts = nil + clt, err := c.tc.ConnectToCluster(ctx) + c.tc.JumpHosts = jumpHosts + if err != nil { + return nil, trace.Wrap(err) + } + + mfaClt = clt + defer clt.Close() + } + + log.Debug("Attempting to issue a single-use user certificate with an MFA check.") + key, err = performMFACeremony(ctx, mfaClt, params, key) + if err != nil { + return nil, trace.Wrap(err) + } + + log.Debug("Issued single-use user certificate after an MFA check.") + am, err := key.AsAuthMethod() + if err != nil { + return nil, trace.Wrap(err) + } + + sshConfig.Auth = []ssh.AuthMethod{am} + return sshConfig, nil +} + +// reissueUserCerts gets new user certificates from the root Auth server. +func (c *ClusterClient) reissueUserCerts(ctx context.Context, cachePolicy CertCachePolicy, params ReissueParams) (*Key, error) { + if params.RouteToCluster == "" { + params.RouteToCluster = c.tc.SiteName + } + key := params.ExistingCreds + if key == nil { + var err error + + // Don't load the certs if we're going to drop all of them all as part + // of the re-issue. If we load all of the old certs now we won't be able + // to differentiate between legacy certificates (that need to be + // deleted) and newly re-issued certs (that we definitely do *not* want + // to delete) when it comes time to drop them from the local agent. + var certOptions []CertOption + if cachePolicy == CertCacheKeep { + certOptions = WithAllCerts + } + + key, err = c.tc.localAgent.GetKey(params.RouteToCluster, certOptions...) 
+ if err != nil { + return nil, trace.Wrap(err) + } + } + + req, err := c.prepareUserCertsRequest(params, key) + if err != nil { + return nil, trace.Wrap(err) + } + + root, err := c.tc.rootClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + clt, err := auth.NewClient(c.ProxyClient.ClientConfig(ctx, root)) + if err != nil { + return nil, trace.Wrap(err) + } + defer clt.Close() + + certs, err := clt.GenerateUserCerts(ctx, *req) + if err != nil { + return nil, trace.Wrap(err) + } + + key.ClusterName = params.RouteToCluster + + // Only update the parts of key that match the usage. See the docs on + // proto.UserCertsRequest_CertUsage for which certificates match which + // usage. + // + // This prevents us from overwriting the top-level key.TLSCert with + // usage-restricted certificates. + switch params.usage() { + case proto.UserCertsRequest_All: + key.Cert = certs.SSH + key.TLSCert = certs.TLS + case proto.UserCertsRequest_SSH: + key.Cert = certs.SSH + case proto.UserCertsRequest_App: + key.AppTLSCerts[params.RouteToApp.Name] = certs.TLS + case proto.UserCertsRequest_Database: + dbCert, err := makeDatabaseClientPEM(params.RouteToDatabase.Protocol, certs.TLS, key) + if err != nil { + return nil, trace.Wrap(err) + } + key.DBTLSCerts[params.RouteToDatabase.ServiceName] = dbCert + case proto.UserCertsRequest_Kubernetes: + key.KubeTLSCerts[params.KubernetesCluster] = certs.TLS + case proto.UserCertsRequest_WindowsDesktop: + key.WindowsDesktopCerts[params.RouteToWindowsDesktop.WindowsDesktop] = certs.TLS + } + return key, nil +} + +// prepareUserCertsRequest creates a [proto.UserCertsRequest] with the fields +// set accordingly from the provided ReissueParams. 
+func (c *ClusterClient) prepareUserCertsRequest(params ReissueParams, key *Key) (*proto.UserCertsRequest, error) { + tlsCert, err := key.TeleportTLSCertificate() + if err != nil { + return nil, trace.Wrap(err) + } + + if len(params.AccessRequests) == 0 { + // Get the active access requests to include in the cert. + activeRequests, err := key.ActiveRequests() + // key.ActiveRequests can return a NotFound error if it doesn't have an + // SSH cert. That's OK, we just assume that there are no AccessRequests + // in that case. + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + params.AccessRequests = activeRequests.AccessRequests + } + + return &proto.UserCertsRequest{ + PublicKey: key.MarshalSSHPublicKey(), + Username: tlsCert.Subject.CommonName, + Expires: tlsCert.NotAfter, + RouteToCluster: params.RouteToCluster, + KubernetesCluster: params.KubernetesCluster, + AccessRequests: params.AccessRequests, + DropAccessRequests: params.DropAccessRequests, + RouteToDatabase: params.RouteToDatabase, + RouteToWindowsDesktop: params.RouteToWindowsDesktop, + RouteToApp: params.RouteToApp, + NodeName: params.NodeName, + Usage: params.usage(), + Format: c.tc.CertificateFormat, + RequesterName: params.RequesterName, + }, nil +} + +// performMFACeremony runs the mfa ceremony to completion. If successful the returned +// [Key] will be authorized to connect to the target. +func performMFACeremony(ctx context.Context, clt *ClusterClient, params ReissueParams, key *Key) (*Key, error) { + stream, err := clt.AuthClient.GenerateUserSingleUseCerts(ctx) + if err != nil { + if trace.IsNotImplemented(err) { + // Probably talking to an older server, use the old non-MFA endpoint. + log.WithError(err).Debug("Auth server does not implement GenerateUserSingleUseCerts.") + // SSH certs can be used without reissuing. 
+ if params.usage() == proto.UserCertsRequest_SSH && key.Cert != nil { + return key, nil + } + + key, err := clt.reissueUserCerts(ctx, CertCacheKeep, params) + return key, trace.Wrap(err) + } + return nil, trace.Wrap(err) + } + defer func() { + // CloseSend closes the client side of the stream + stream.CloseSend() + // Recv to wait for the server side of the stream to end, this needs to + // be called to ensure the spans are finished properly + stream.Recv() + }() + + initReq, err := clt.prepareUserCertsRequest(params, key) + if err != nil { + return nil, trace.Wrap(err) + } + + err = stream.Send(&proto.UserSingleUseCertsRequest{Request: &proto.UserSingleUseCertsRequest_Init{ + Init: initReq, + }}) + if err != nil { + return nil, trace.Wrap(err) + } + + resp, err := stream.Recv() + if err != nil { + return nil, trace.Wrap(err) + } + mfaChal := resp.GetMFAChallenge() + if mfaChal == nil { + return nil, trace.BadParameter("server sent a %T on GenerateUserSingleUseCerts, expected MFAChallenge", resp.Response) + } + mfaResp, err := clt.tc.PromptMFAChallenge(ctx, clt.tc.WebProxyAddr, mfaChal, nil /* applyOpts */) + if err != nil { + return nil, trace.Wrap(err) + } + err = stream.Send(&proto.UserSingleUseCertsRequest{Request: &proto.UserSingleUseCertsRequest_MFAResponse{MFAResponse: mfaResp}}) + if err != nil { + return nil, trace.Wrap(err) + } + + resp, err = stream.Recv() + if err != nil { + return nil, trace.Wrap(err) + } + certResp := resp.GetCert() + if certResp == nil { + return nil, trace.BadParameter("server sent a %T on GenerateUserSingleUseCerts, expected SingleUseUserCert", resp.Response) + } + switch crt := certResp.Cert.(type) { + case *proto.SingleUseUserCert_SSH: + key.Cert = crt.SSH + case *proto.SingleUseUserCert_TLS: + switch initReq.Usage { + case proto.UserCertsRequest_Kubernetes: + key.KubeTLSCerts[initReq.KubernetesCluster] = crt.TLS + case proto.UserCertsRequest_Database: + dbCert, err := makeDatabaseClientPEM(params.RouteToDatabase.Protocol, 
crt.TLS, key) + if err != nil { + return nil, trace.Wrap(err) + } + key.DBTLSCerts[params.RouteToDatabase.ServiceName] = dbCert + case proto.UserCertsRequest_WindowsDesktop: + key.WindowsDesktopCerts[params.RouteToWindowsDesktop.WindowsDesktop] = crt.TLS + default: + return nil, trace.BadParameter("server returned a TLS certificate but cert request usage was %s", initReq.Usage) + } + default: + return nil, trace.BadParameter("server sent a %T SingleUseUserCert in response", certResp.Cert) + } + key.ClusterName = params.RouteToCluster + + return key, nil +} diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 3427bd95d483e..7be0099c47954 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -191,23 +191,19 @@ func NewRouter(cfg RouterConfig) (*Router, error) { }, nil } -// SignerCreator allows the caller to configure a [ssh.Signer] if a ssh -// user agent isn't available, ie when connecting to agentless nodes. -type SignerCreator func() (ssh.Signer, error) - // DialHost dials the node that matches the provided host, port and cluster. If no matching node // is found an error is returned. If more than one matching node is found and the cluster networking // configuration is not set to route to the most recent an error is returned. 
Also returns teleport version of the // target server if it's a teleport server // DELETE IN 14.0: remove returning teleport version, it was needed for compatibility -func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter, singerCreator SignerCreator) (_ net.Conn, teleportVersion string, err error) { +func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter, signer func(context.Context) (ssh.Signer, error)) (_ net.Conn, teleportVersion string, err error) { ctx, span := r.tracer.Start( ctx, "router/DialHost", oteltrace.WithAttributes( attribute.String("host", host), attribute.String("port", port), - attribute.String("site", clusterName), + attribute.String("cluster", clusterName), ), ) defer func() { @@ -279,9 +275,9 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. // if the node is a registered openssh node, create a signer for auth // and don't set agentGetter so a SSH user agent will not be created // when connecting to the remote node - var signer ssh.Signer + var sshSigner ssh.Signer if isAgentlessNode { - signer, err = singerCreator() + sshSigner, err = signer(ctx) if err != nil { return nil, "", trace.Wrap(err) } @@ -293,7 +289,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. 
To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, OriginalClientDstAddr: clientDstAddr, GetUserAgent: agentGetter, - AgentlessSigner: signer, + AgentlessSigner: sshSigner, Address: host, Principals: principals, ServerID: serverID, @@ -316,7 +312,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check ctx, "router/getRemoteCluster", oteltrace.WithAttributes( - attribute.String("site", clusterName), + attribute.String("cluster", clusterName), ), ) defer span.End() @@ -458,7 +454,7 @@ func (r *Router) DialSite(ctx context.Context, clusterName string, clientSrcAddr ctx, "router/DialSite", oteltrace.WithAttributes( - attribute.String("site", clusterName), + attribute.String("cluster", clusterName), ), ) defer span.End() diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index 70465a32ef85c..c5102fac7e32e 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -360,7 +360,7 @@ func TestRouter_DialHost(t *testing.T) { agentGetter := func() (teleagent.Agent, error) { return nil, nil } - createSigner := func() (ssh.Signer, error) { + createSigner := func(context.Context) (ssh.Signer, error) { key, err := native.GeneratePrivateKey() if err != nil { return nil, err diff --git a/lib/service/service.go b/lib/service/service.go index 975f2554bbb87..2bf330c6ec90d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -62,9 +62,12 @@ import ( "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" kubeproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/kube/v1" + transportpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/agentless" 
"github.com/gravitational/teleport/lib/auditd" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/keygen" @@ -116,6 +119,7 @@ import ( "github.com/gravitational/teleport/lib/srv/desktop" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/srv/regular" + "github.com/gravitational/teleport/lib/srv/transport/transportv1" "github.com/gravitational/teleport/lib/system" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/cert" @@ -3844,11 +3848,16 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { ClusterName: clusterName, } + tlscfg := serverTLSConfig.Clone() + tlscfg.ClientAuth = tls.RequireAndVerifyClientCert + if lib.IsInsecureDevMode() { + tlscfg.InsecureSkipVerify = true + tlscfg.ClientAuth = tls.RequireAnyClientCert + } creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{ - TransportCredentials: credentials.NewTLS(serverTLSConfig), + TransportCredentials: credentials.NewTLS(tlscfg), UserGetter: authMiddleware, Authorizer: authorizer, - Enforcer: sessionController, }) if err != nil { return trace.Wrap(err) @@ -3866,6 +3875,33 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { grpc.Creds(creds), ) + connMonitor, err := srv.NewConnectionMonitor(srv.ConnectionMonitorConfig{ + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + Clock: process.Clock, + ServerID: serverID, + Emitter: streamEmitter, + Logger: process.log, + }) + if err != nil { + return trace.Wrap(err) + } + + transportService, err := transportv1.NewService(transportv1.ServerConfig{ + FIPS: cfg.FIPS, + Logger: process.log.WithField(trace.Component, "transport"), + Dialer: proxyRouter, + SignerFn: func(authzCtx *authz.Context) func(context.Context) (ssh.Signer, error) { + return agentless.SignerFromAuthzContext(authzCtx, conn.Client) + }, + ConnectionMonitor: connMonitor, + LocalAddr: listeners.sshGRPC.Addr(), + 
}) + if err != nil { + return trace.Wrap(err) + } + transportpb.RegisterTransportServiceServer(sshGRPCServer, transportService) + process.RegisterCriticalFunc("proxy.ssh", func() error { utils.Consolef(cfg.Console, log, teleport.ComponentProxy, "SSH proxy service %s:%s is starting on %v.", teleport.Version, teleport.Gitref, cfg.Proxy.SSHAddr.Addr) diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index fad37e5c90bcb..65bbf9e3d9bc5 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -24,7 +24,6 @@ import ( "io" "net" "strings" - "time" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -261,12 +260,14 @@ func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, clientSrc t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort()) aGetter := t.ctx.StartAgentChannel - signerCreator := func() (ssh.Signer, error) { - validBefore := time.Unix(int64(t.ctx.Identity.Certificate.ValidBefore), 0) - ttl := time.Until(validBefore) - return agentless.CreateAuthSigner(ctx, t.ctx.Identity.TeleportUser, t.clusterName, ttl, t.router) + + client, err := t.router.GetSiteClient(ctx, t.clusterName) + if err != nil { + return trace.Wrap(err) } - conn, teleportVersion, err := t.router.DialHost(ctx, clientSrcAddr, clientDstAddr, t.host, t.port, t.clusterName, t.ctx.Identity.AccessChecker, aGetter, signerCreator) + + signer := agentless.SignerFromSSHCertificate(t.ctx.Identity.Certificate, client) + conn, teleportVersion, err := t.router.DialHost(ctx, clientSrcAddr, clientDstAddr, t.host, t.port, t.clusterName, t.ctx.Identity.AccessChecker, aGetter, signer) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/transport/transportv1/transport.go b/lib/srv/transport/transportv1/transport.go index abb073b7d6e8c..a9eb17e524341 100644 --- a/lib/srv/transport/transportv1/transport.go +++ b/lib/srv/transport/transportv1/transport.go @@ -22,6 +22,7 @@ import ( 
"github.com/gravitational/trace" "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "google.golang.org/grpc/credentials" "google.golang.org/grpc/peer" @@ -29,6 +30,7 @@ import ( transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" @@ -36,24 +38,38 @@ import ( // Dialer is the interface that groups basic dialing methods. type Dialer interface { - DialSite(ctx context.Context, clusterName string) (net.Conn, error) - DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) + DialSite(ctx context.Context, cluster string, clientSrcAddr, clientDstAddr net.Addr) (net.Conn, error) + DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.Addr, host, port, cluster string, checker services.AccessChecker, agentGetter teleagent.Getter, singer func(context.Context) (ssh.Signer, error)) (_ net.Conn, teleportVersion string, err error) +} + +// ConnMonitor monitors authorized connnections and terminates them when +// session controls dictate so. +type ConnectionMonitor interface { + MonitorConn(ctx context.Context, authCtx *authz.Context, conn net.Conn) (context.Context, error) } // ServerConfig holds creation parameters for Service. type ServerConfig struct { // FIPS indicates whether the cluster if configured - // to run in FIPS mode + // to run in FIPS mode. FIPS bool - // Logger provides a mechanism to log output + // Logger provides a mechanism to log output. Logger logrus.FieldLogger - // Dialer is used to establish remote connections + // Dialer is used to establish remote connections. 
Dialer Dialer + // SignerFn is used to create an [ssh.Signer] for an authenticated connection. + SignerFn func(authzCtx *authz.Context) func(context.Context) (ssh.Signer, error) + // ConnectionMonitor is used to monitor the connection for activity and terminate it + // when conditions are met. + ConnectionMonitor ConnectionMonitor + // LocalAddr is the local address of the service. + LocalAddr net.Addr // agentGetterFn used by tests to serve the agent directly agentGetterFn func(rw io.ReadWriter) teleagent.Getter - accessCheckerFn func(info credentials.AuthInfo) (services.AccessChecker, error) + // authzContextFn used by tests to inject an access checker + authzContextFn func(info credentials.AuthInfo) (*authz.Context, error) } // CheckAndSetDefaults ensures required parameters are set @@ -63,6 +79,10 @@ func (c *ServerConfig) CheckAndSetDefaults() error { return trace.BadParameter("parameter Dialer required") } + if c.LocalAddr == nil { + return trace.BadParameter("parameter LocalAddr required") + } + if c.Logger == nil { c.Logger = utils.NewLogger().WithField(trace.Component, "transport") } @@ -75,14 +95,14 @@ func (c *ServerConfig) CheckAndSetDefaults() error { } } - if c.accessCheckerFn == nil { - c.accessCheckerFn = func(info credentials.AuthInfo) (services.AccessChecker, error) { + if c.authzContextFn == nil { + c.authzContextFn = func(info credentials.AuthInfo) (*authz.Context, error) { identityInfo, ok := info.(auth.IdentityInfo) if !ok { return nil, trace.AccessDenied("client is not authenticated") } - return identityInfo.AuthContext.Checker, nil + return identityInfo.AuthContext, nil } } @@ -120,7 +140,13 @@ func (s *Service) ProxyCluster(stream transportv1pb.TransportService_ProxyCluste return trace.Wrap(err, "failed receiving first frame") } - conn, err := s.cfg.Dialer.DialSite(stream.Context(), req.Cluster) + ctx := stream.Context() + p, ok := peer.FromContext(ctx) + if !ok { + return trace.BadParameter("unable to find peer") + } + + conn, err := 
s.cfg.Dialer.DialSite(ctx, req.Cluster, p.Addr, s.cfg.LocalAddr) if err != nil { return trace.Wrap(err, "failed dialing cluster %q", req.Cluster) } @@ -139,7 +165,7 @@ func (s *Service) ProxyCluster(stream transportv1pb.TransportService_ProxyCluste return trace.Wrap(err, "failed constructing streamer") } - return trace.Wrap(utils.ProxyConn(stream.Context(), conn, streamRW)) + return trace.Wrap(utils.ProxyConn(ctx, conn, streamRW)) } // clusterStream implements the [streamutils.Source] interface @@ -168,7 +194,7 @@ func (c clusterStream) Send(frame []byte) error { // ProxySSH establishes a connection to a host and proxies both the SSH and SSH // Agent protocol over the stream. The first request from the client must contain // a valid dial target before the connection can be established. -func (s *Service) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) error { +func (s *Service) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) (err error) { ctx := stream.Context() p, ok := peer.FromContext(ctx) @@ -176,7 +202,7 @@ func (s *Service) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) return trace.BadParameter("unable to find peer") } - checker, err := s.cfg.accessCheckerFn(p.AuthInfo) + authzContext, err := s.cfg.authzContextFn(p.AuthInfo) if err != nil { return trace.Wrap(err) } @@ -235,15 +261,34 @@ func (s *Service) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) } defer agentStreamRW.Close() - conn, err := s.cfg.Dialer.DialHost(ctx, p.Addr, host, port, req.DialTarget.Cluster, checker, s.cfg.agentGetterFn(agentStreamRW)) + // create a reader/writer for SSH protocol + sshStreamRW, err := streamutils.NewReadWriter(sshStream) + if err != nil { + return trace.Wrap(err, "failed constructing ssh streamer") + } + + signer := s.cfg.SignerFn(authzContext) + hostConn, _, err := s.cfg.Dialer.DialHost(ctx, p.Addr, s.cfg.LocalAddr, host, port, req.DialTarget.Cluster, authzContext.Checker, 
s.cfg.agentGetterFn(agentStreamRW), signer) if err != nil { return trace.Wrap(err, "failed to dial target host") } - // create a reader/writer for SSH protocol - sshStreamRW, err := streamutils.NewReadWriter(sshStream) + // ensure the connection to the target host + // gets closed when exiting + defer func() { + hostConn.Close() + }() + + targetAddr, err := utils.ParseAddr(req.DialTarget.HostPort) if err != nil { - return trace.Wrap(err, "failed constructing ssh streamer") + return trace.Wrap(err) + } + + // monitor the user connection + userConn := streamutils.NewConn(sshStreamRW, p.Addr, targetAddr) + monitorCtx, err := s.cfg.ConnectionMonitor.MonitorConn(ctx, authzContext, userConn) + if err != nil { + return trace.Wrap(err) } // send back the cluster details to alert the other side that @@ -254,8 +299,8 @@ func (s *Service) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) return trace.Wrap(err, "failed sending cluster details ") } - // copy data to/from the connection and ssh stream - return trace.Wrap(utils.ProxyConn(ctx, conn, sshStreamRW)) + // copy data to/from the host/user + return trace.Wrap(utils.ProxyConn(monitorCtx, hostConn, userConn)) } // sshStream implements the [streamutils.Source] interface diff --git a/lib/srv/transport/transportv1/transport_test.go b/lib/srv/transport/transportv1/transport_test.go index 8085d7efc712a..1a1c54f2b6b06 100644 --- a/lib/srv/transport/transportv1/transport_test.go +++ b/lib/srv/transport/transportv1/transport_test.go @@ -41,6 +41,7 @@ import ( transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" @@ -99,7 +100,7 @@ type fakeDialer struct { hostConns map[string]net.Conn } -func (f fakeDialer) DialSite(ctx 
context.Context, clusterName string) (net.Conn, error) { +func (f fakeDialer) DialSite(ctx context.Context, clusterName string, clientSrcAddr, clientDstAddr net.Addr) (net.Conn, error) { conn, ok := f.siteConns[clusterName] if !ok { return nil, trace.NotFound(clusterName) @@ -108,14 +109,14 @@ func (f fakeDialer) DialSite(ctx context.Context, clusterName string) (net.Conn, return conn, nil } -func (f fakeDialer) DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) { - key := fmt.Sprintf("%s.%s.%s", host, port, clusterName) +func (f fakeDialer) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.Addr, host, port, cluster string, checker services.AccessChecker, agentGetter teleagent.Getter, singer func(context.Context) (ssh.Signer, error)) (_ net.Conn, teleportVersion string, err error) { + key := fmt.Sprintf("%s.%s.%s", host, port, cluster) conn, ok := f.hostConns[key] if !ok { - return nil, trace.NotFound(key) + return nil, "", trace.NotFound(key) } - return conn, nil + return conn, "", nil } // testPack used to test a [Service]. @@ -178,6 +179,18 @@ func newServer(t *testing.T, cfg ServerConfig) testPack { } } +func fakeSigner(authzCtx *authz.Context) func(context.Context) (ssh.Signer, error) { + return func(context.Context) (ssh.Signer, error) { + return nil, nil + } +} + +type fakeMonitor struct{} + +func (f fakeMonitor) MonitorConn(ctx context.Context, authCtx *authz.Context, conn net.Conn) (context.Context, error) { + return ctx, nil +} + // TestService_GetClusterDetails validates that a [Service] returns // the expected cluster details. 
func TestService_GetClusterDetails(t *testing.T) { @@ -200,8 +213,12 @@ func TestService_GetClusterDetails(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() srv := newServer(t, ServerConfig{ - Dialer: fakeDialer{}, - FIPS: test.FIPS, + Dialer: fakeDialer{}, + Logger: utils.NewLoggerForTests(), + FIPS: test.FIPS, + SignerFn: fakeSigner, + ConnectionMonitor: fakeMonitor{}, + LocalAddr: &utils.NetAddr{}, }) resp, err := srv.Client.GetClusterDetails(context.Background(), &transportv1pb.GetClusterDetailsRequest{}) @@ -280,6 +297,10 @@ func TestService_ProxyCluster(t *testing.T) { cluster: conn, }, }, + Logger: utils.NewLoggerForTests(), + SignerFn: fakeSigner, + ConnectionMonitor: fakeMonitor{}, + LocalAddr: &utils.NetAddr{}, }) stream, err := srv.Client.ProxyCluster(context.Background()) @@ -407,8 +428,18 @@ func TestService_ProxySSH_Errors(t *testing.T) { fakeHost: conn, }, }, - Logger: utils.NewLoggerForTests(), - accessCheckerFn: test.checkerFn, + SignerFn: fakeSigner, + ConnectionMonitor: fakeMonitor{}, + Logger: utils.NewLoggerForTests(), + LocalAddr: &utils.NetAddr{}, + authzContextFn: func(info credentials.AuthInfo) (*authz.Context, error) { + checker, err := test.checkerFn(info) + if err != nil { + return nil, trace.Wrap(err) + } + + return &authz.Context{Checker: checker}, nil + }, }) stream, err := srv.Client.ProxySSH(context.Background()) @@ -461,7 +492,11 @@ func TestService_ProxySSH(t *testing.T) { // create a server that will open a new connection to the // ssh server created above on each dial request srv := newServer(t, ServerConfig{ - Dialer: sshSrv, + Dialer: sshSrv, + SignerFn: fakeSigner, + Logger: utils.NewLoggerForTests(), + LocalAddr: &utils.NetAddr{}, + ConnectionMonitor: fakeMonitor{}, agentGetterFn: func(rw io.ReadWriter) teleagent.Getter { return func() (teleagent.Agent, error) { srw, ok := rw.(*streamutils.ReadWriter) @@ -474,8 +509,8 @@ func TestService_ProxySSH(t *testing.T) { }, nil } }, - accessCheckerFn: func(info 
credentials.AuthInfo) (services.AccessChecker, error) { - return fakeChecker{}, nil + authzContextFn: func(info credentials.AuthInfo) (*authz.Context, error) { + return &authz.Context{Checker: fakeChecker{}}, nil }, }) @@ -619,7 +654,7 @@ type sshServer struct { } // DialSite returns a connection to the sshServer -func (s *sshServer) DialSite(context.Context, string) (net.Conn, error) { +func (s *sshServer) DialSite(ctx context.Context, clusterName string, clientSrcAddr, clientDstAddr net.Addr) (net.Conn, error) { conn, err := s.dial() if err != nil { return nil, trace.Wrap(err) @@ -632,31 +667,31 @@ func (s *sshServer) DialSite(context.Context, string) (net.Conn, error) { // nil and is of type testAgent, then the server will serve its keyring // over the underlying [streamutils.ReadWriter] so that tests can exercise // ssh agent multiplexing. -func (s *sshServer) DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) { +func (s *sshServer) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.Addr, host, port, cluster string, checker services.AccessChecker, agentGetter teleagent.Getter, singer func(context.Context) (ssh.Signer, error)) (_ net.Conn, teleportVersion string, err error) { conn, err := s.dial() if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } if agentGetter == nil { - return conn, nil + return conn, "", nil } agnt, err := agentGetter() if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } rw, ok := agnt.(testAgent) if !ok { - return conn, nil + return conn, "", nil } go func() { agent.ServeAgent(s.keyring, rw) }() - return conn, nil + return conn, "", nil } func (s *sshServer) Run() { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index bcdba808962a4..5ce2895145e68 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -6741,6 +6741,10 @@ 
func (mock authProviderMock) GenerateUserSingleUseCerts(ctx context.Context) (au return nil, nil } +func (mock authProviderMock) GenerateOpenSSHCert(ctx context.Context, req *authproto.OpenSSHCertRequest) (*authproto.OpenSSHCert, error) { + return nil, nil +} + type terminalOpt func(t *TerminalRequest) func withSessionID(sid session.ID) terminalOpt { diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 95896741e81ce..47874695870e5 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -95,6 +95,7 @@ type AuthProvider interface { GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) GenerateUserSingleUseCerts(ctx context.Context) (authproto.AuthService_GenerateUserSingleUseCertsClient, error) + GenerateOpenSSHCert(ctx context.Context, req *authproto.OpenSSHCertRequest) (*authproto.OpenSSHCert, error) } // NewTerminal creates a web-based terminal based on WebSockets and returns a @@ -633,16 +634,15 @@ func (t *TerminalHandler) streamTerminal(ws *websocket.Conn, tc *client.Teleport getAgent := func() (teleagent.Agent, error) { return teleagent.NopCloser(tc.LocalAgent()), nil } - signerCreator := func() (ssh.Signer, error) { - cert, err := t.ctx.GetSSHCertificate() - if err != nil { - return nil, trace.Wrap(err) - } - validBefore := time.Unix(int64(cert.ValidBefore), 0) - ttl := time.Until(validBefore) - return agentless.CreateAuthSigner(t.terminalContext, t.ctx.GetUser(), tc.SiteName, ttl, t.router) + cert, err := t.ctx.GetSSHCertificate() + if err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to get certificate") + t.writeError(err) + return } - conn, _, err := t.router.DialHost(ctx, ws.RemoteAddr(), ws.LocalAddr(), t.sessionData.ServerID, strconv.Itoa(t.sessionData.ServerHostPort), tc.SiteName, accessChecker, getAgent, signerCreator) + + signer := 
agentless.SignerFromSSHCertificate(cert, t.authProvider) + conn, _, err := t.router.DialHost(ctx, ws.RemoteAddr(), ws.LocalAddr(), t.sessionData.ServerID, strconv.Itoa(t.sessionData.ServerHostPort), tc.SiteName, accessChecker, getAgent, signer) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host.") @@ -715,7 +715,7 @@ func (t *TerminalHandler) streamTerminal(ws *websocket.Conn, tc *client.Teleport sshConfig.Auth = tc.AuthMethods // connect to the node again with the new certs - conn, _, err = t.router.DialHost(ctx, ws.RemoteAddr(), ws.LocalAddr(), t.sessionData.ServerID, strconv.Itoa(t.sessionData.ServerHostPort), tc.SiteName, accessChecker, getAgent, signerCreator) + conn, _, err = t.router.DialHost(ctx, ws.RemoteAddr(), ws.LocalAddr(), t.sessionData.ServerID, strconv.Itoa(t.sessionData.ServerHostPort), tc.SiteName, accessChecker, getAgent, signer) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host") t.writeError(err)