diff --git a/lib/vnet/admin_process_common.go b/lib/vnet/admin_process_common.go index 47c1793dbe70f..2aad6c5b1967f 100644 --- a/lib/vnet/admin_process_common.go +++ b/lib/vnet/admin_process_common.go @@ -22,12 +22,19 @@ import ( ) func newNetworkStackConfig(tun tunDevice, clt *clientApplicationServiceClient) (*networkStackConfig, error) { - sshProvider := newSSHProvider(sshProviderConfig{clt: clt}) + clock := clockwork.NewRealClock() + sshProvider, err := newSSHProvider(sshProviderConfig{ + clt: clt, + clock: clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } tcpHandlerResolver := newTCPHandlerResolver(&tcpHandlerResolverConfig{ clt: clt, appProvider: newAppProvider(clt), sshProvider: sshProvider, - clock: clockwork.NewRealClock(), + clock: clock, }) ipv6Prefix, err := newIPv6Prefix() if err != nil { diff --git a/lib/vnet/ssh_handler.go b/lib/vnet/ssh_handler.go index 627665913abaf..f8e8e831ea24b 100644 --- a/lib/vnet/ssh_handler.go +++ b/lib/vnet/ssh_handler.go @@ -18,9 +18,15 @@ package vnet import ( "context" + "crypto/rand" "net" + "strings" "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/cryptosuites" ) // sshHandler handles incoming VNet SSH connections. @@ -61,13 +67,83 @@ func (h *sshHandler) handleTCPConnectorWithTargetConn( connector func() (net.Conn, error), targetConn net.Conn, ) error { - // For now we accept the incoming TCP conn to indicate that the node exists, - // but SSH connection forwarding is not implemented yet so we immediately - // close it. + hostCert, err := h.newHostCert(ctx) + if err != nil { + return trace.Wrap(err) + } + localConn, err := connector() if err != nil { return trace.Wrap(err) } - localConn.Close() + defer localConn.Close() + + // For now we accept the incoming SSH connection but forwarding to the + // target is not implemented yet so we immediately close it. + serverConfig := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if !sshutils.KeysEqual(h.cfg.sshProvider.trustedUserPublicKey, key) { + return nil, trace.AccessDenied("SSH client public key is not trusted") + } + return nil, nil + }, + } + serverConfig.AddHostKey(hostCert) + serverConn, chans, reqs, err := ssh.NewServerConn(localConn, serverConfig) + if err != nil { + return trace.Wrap(err, "accepting incoming SSH connection") + } + // Immediately close the connection but make sure to drain the channels. + serverConn.Close() + go ssh.DiscardRequests(reqs) + go func() { + for newChan := range chans { + _ = newChan.Reject(0, "") + } + }() + target := h.cfg.target + log.DebugContext(ctx, "Accepted incoming SSH connection", + "profile", target.profile, + "cluster", target.cluster, + "host", target.host, + "user", serverConn.User(), + ) return trace.NotImplemented("VNet SSH connection forwarding is not yet implemented") } + +func (h *sshHandler) newHostCert(ctx context.Context) (ssh.Signer, error) { + // If the user typed "ssh host.com" or "ssh host.com." our DNS handler will + // only see the fully-qualified variant with the trailing "." but the SSH + // client treats them differently, we need both in the principals if we want + // the cert to be trusted in both cases. + validPrincipals := []string{ + h.cfg.target.fqdn, + strings.TrimSuffix(h.cfg.target.fqdn, "."), + } + // We generate an ephemeral key for every connection, Ed25519 is fast and + // well supported. + hostKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + if err != nil { + return nil, trace.Wrap(err, "generating SSH host key") + } + hostSigner, err := ssh.NewSignerFromSigner(hostKey) + if err != nil { + return nil, trace.Wrap(err) + } + cert := &ssh.Certificate{ + Key: hostSigner.PublicKey(), + Serial: 1, + CertType: ssh.HostCert, + ValidPrincipals: validPrincipals, + // This cert will only ever be used to handle this one SSH connection, + // the private key is held only in memory, the issuing CA is regenerated + // every time this process restarts and will only be trusted on this one + // host. The expiry doesn't matter. + ValidBefore: ssh.CertTimeInfinity, + } + if err := cert.SignCert(rand.Reader, h.cfg.sshProvider.hostCASigner); err != nil { + return nil, trace.Wrap(err, "signing SSH host cert") + } + certSigner, err := ssh.NewCertSigner(cert, hostSigner) + return certSigner, trace.Wrap(err) +} diff --git a/lib/vnet/ssh_provider.go b/lib/vnet/ssh_provider.go index ace2db9bf17a5..1e0cc00e2ebd1 100644 --- a/lib/vnet/ssh_provider.go +++ b/lib/vnet/ssh_provider.go @@ -24,19 +24,26 @@ import ( "strings" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "golang.org/x/crypto/ssh" proxyclient "github.com/gravitational/teleport/api/client/proxy" vnetv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/vnet/v1" + "github.com/gravitational/teleport/lib/cryptosuites" ) // sshProvider provides methods necessary for VNet SSH access. type sshProvider struct { cfg sshProviderConfig + // hostCASigner is the host CA key used internally in VNet to terminate + // connections from clients, it is not a Teleport CA used by any cluster. + hostCASigner ssh.Signer + trustedUserPublicKey ssh.PublicKey } type sshProviderConfig struct { - clt *clientApplicationServiceClient + clt *clientApplicationServiceClient + clock clockwork.Clock // overrideNodeDialer can be used in tests to dial SSH nodes with the real // TLS configuration but without setting up the proxy transport service. overrideNodeDialer func( @@ -45,12 +52,45 @@ type sshProviderConfig struct { tlsConfig *tls.Config, dialOpts *vnetv1.DialOptions, ) (net.Conn, error) + // hostCASigner can be used in tests to set a specific key for the SSH host CA. + hostCASigner ssh.Signer + // trustedUserPublicKey can be used in tests to set a specific trusted user + // SSH key. + trustedUserPublicKey ssh.PublicKey } -func newSSHProvider(cfg sshProviderConfig) *sshProvider { - return &sshProvider{ - cfg: cfg, +func newSSHProvider(cfg sshProviderConfig) (*sshProvider, error) { + hostCASigner := cfg.hostCASigner + if hostCASigner == nil { + // TODO(nklaassen): write host CA public key to $TELEPORT_HOME/vnet_known_hosts + hostKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + if err != nil { + return nil, trace.Wrap(err) + } + hostCASigner, err = ssh.NewSignerFromSigner(hostKey) + if err != nil { + return nil, trace.Wrap(err) + } + } + trustedUserPublicKey := cfg.trustedUserPublicKey + if trustedUserPublicKey == nil { + // TODO(nklaassen): check if $TELEPORT_HOME/id_vnet.pub exists. + // If it does, read that file and trust it. + // If not, generate the keypair and write it to $TELEPORT_HOME/id_vnet. + userKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + if err != nil { + return nil, trace.Wrap(err) + } + trustedUserPublicKey, err = ssh.NewPublicKey(userKey.Public()) + if err != nil { + return nil, trace.Wrap(err) + } } + return &sshProvider{ + cfg: cfg, + hostCASigner: hostCASigner, + trustedUserPublicKey: trustedUserPublicKey, + }, nil } // dial dials the target SSH host. @@ -142,7 +182,10 @@ func (p *sshProvider) userTLSConfig( } type dialTarget struct { - profile, cluster, host string + fqdn string + profile string + cluster string + host string } func computeDialTarget(matchedCluster *vnetv1.MatchedCluster, fqdn string) dialTarget { @@ -155,6 +198,7 @@ func computeDialTarget(matchedCluster *vnetv1.MatchedCluster, fqdn string) dialT } targetHost = targetHost + ":0" return dialTarget{ + fqdn: fqdn, profile: targetProfile, cluster: targetCluster, host: targetHost, diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index d5a64a67a3bea..3063f7d1a288a 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -43,6 +43,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "google.golang.org/grpc" grpccredentials "google.golang.org/grpc/credentials" "gvisor.dev/gvisor/pkg/tcpip" @@ -58,6 +59,7 @@ import ( "github.com/gravitational/teleport/api/types" typesvnet "github.com/gravitational/teleport/api/types/vnet" "github.com/gravitational/teleport/api/utils/grpc/interceptors" + "github.com/gravitational/teleport/api/utils/sshutils" vnetv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/vnet/v1" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cryptosuites" @@ -85,8 +87,10 @@ type testPack struct { } type testPackConfig struct { - clock clockwork.Clock - fakeClientApp *fakeClientApp + clock clockwork.Clock + fakeClientApp *fakeClientApp + sshHostCASigner ssh.Signer + sshTrustedUserPublicKey ssh.PublicKey } func newTestPack(t *testing.T, ctx context.Context, cfg testPackConfig) *testPack { @@ -145,10 +149,14 @@ func newTestPack(t *testing.T, ctx context.Context, cfg testPackConfig) *testPac // interface with fakeClientApp via the gRPC client. clt := runTestClientApplicationService(t, ctx, cfg.clock, cfg.fakeClientApp) appProvider := newAppProvider(clt) - sshProvider := newSSHProvider(sshProviderConfig{ - clt: clt, - overrideNodeDialer: cfg.fakeClientApp.dialSSHNode, + sshProvider, err := newSSHProvider(sshProviderConfig{ + clt: clt, + clock: cfg.clock, + overrideNodeDialer: cfg.fakeClientApp.dialSSHNode, + hostCASigner: cfg.sshHostCASigner, + trustedUserPublicKey: cfg.sshTrustedUserPublicKey, }) + require.NoError(t, err) tcpHandlerResolver := newTCPHandlerResolver(&tcpHandlerResolverConfig{ clt: clt, appProvider: appProvider, @@ -1088,22 +1096,52 @@ func TestSSH(t *testing.T) { signatureAlgorithmSuite: types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1, }) + sshHostKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + require.NoError(t, err) + sshHostCASigner, err := ssh.NewSignerFromSigner(sshHostKey) + require.NoError(t, err) + + sshUserKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + require.NoError(t, err) + sshUserSigner, err := ssh.NewSignerFromSigner(sshUserKey) + require.NoError(t, err) + + badUserKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.Ed25519) + require.NoError(t, err) + badUserSigner, err := ssh.NewSignerFromSigner(badUserKey) + require.NoError(t, err) + p := newTestPack(t, ctx, testPackConfig{ - fakeClientApp: clientApp, - clock: clock, + fakeClientApp: clientApp, + clock: clock, + sshHostCASigner: sshHostCASigner, + sshTrustedUserPublicKey: sshUserSigner.PublicKey(), }) for _, tc := range []struct { - dialAddr string - dialPort int - expectCIDR string - expectLookupToFail bool - expectDialToFail bool + dialAddr string + dialPort int + expectCIDR string + expectLookupToFail bool + expectDialToFail bool + sshUser string + sshUserSigner ssh.Signer + expectSSHHandshakeToFail bool }{ { - dialAddr: "node.root1.example.com", - dialPort: 22, - expectCIDR: root1CIDR, + dialAddr: "node.root1.example.com", + dialPort: 22, + expectCIDR: root1CIDR, + sshUser: "testuser", + sshUserSigner: sshUserSigner, + }, + { + // Fully-qualified hostname should also work. + dialAddr: "node.root1.example.com.", + dialPort: 22, + expectCIDR: root1CIDR, + sshUser: "testuser", + sshUserSigner: sshUserSigner, }, { // Dial should fail on non-standard SSH port. @@ -1113,19 +1151,33 @@ func TestSSH(t *testing.T) { expectDialToFail: true, }, { - dialAddr: "node.leaf1.example.com.root1.example.com", - dialPort: 22, - expectCIDR: leaf1CIDR, + dialAddr: "node.root1.example.com", + dialPort: 22, + expectCIDR: root1CIDR, + sshUser: "baduser", + sshUserSigner: badUserSigner, + expectSSHHandshakeToFail: true, }, { - dialAddr: "node.root2.example.com", - dialPort: 22, - expectCIDR: root2CIDR, + dialAddr: "node.leaf1.example.com.root1.example.com", + dialPort: 22, + expectCIDR: leaf1CIDR, + sshUser: "testuser", + sshUserSigner: sshUserSigner, }, { - dialAddr: "node.leaf2.example.com.root2.example.com", - dialPort: 22, - expectCIDR: leaf2CIDR, + dialAddr: "node.root2.example.com", + dialPort: 22, + expectCIDR: root2CIDR, + sshUser: "testuser", + sshUserSigner: sshUserSigner, + }, + { + dialAddr: "node.leaf2.example.com.root2.example.com", + dialPort: 22, + expectCIDR: leaf2CIDR, + sshUser: "testuser", + sshUserSigner: sshUserSigner, }, { // DNS lookup should fail if the FQDN doesn't match any cluster. @@ -1142,14 +1194,13 @@ func TestSSH(t *testing.T) { expectDialToFail: true, }, } { - t.Run(fmt.Sprintf("%s:%d", tc.dialAddr, tc.dialPort), func(t *testing.T) { + t.Run(fmt.Sprintf("%s@%s:%d", tc.sshUser, tc.dialAddr, tc.dialPort), func(t *testing.T) { t.Parallel() lookupCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() - // SSH access isn't fully implemented yet, at this point the DNS - // lookup for *. should resolve to an IP in the - // expected CIDR range for the cluster. + // The DNS lookup for *. should resolve to an IP in + // the expected CIDR range for the cluster. resolvedAddrs, err := p.lookupHost(lookupCtx, tc.dialAddr) if tc.expectLookupToFail { require.Error(t, err) @@ -1170,6 +1221,8 @@ func TestSSH(t *testing.T) { "expected CIDR range %s does not include resolved IP %s", expectNet, resolvedIPSuffix) } + // TCP dial the target address, it should fail if the node doesn't + // exist. dialCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() conn, err := p.dialHost(dialCtx, tc.dialAddr, tc.dialPort) @@ -1178,9 +1231,58 @@ func TestSSH(t *testing.T) { return } require.NoError(t, err) - conn.Close() + defer conn.Close() + + // Initiate an SSH connection to the target. At this point the + // handshake should complete successfully as long as the right keys + // are used, but the SSH connection will be immediately closed by + // the server. + certChecker := ssh.CertChecker{ + IsHostAuthority: func(auth ssh.PublicKey, address string) bool { + return sshutils.KeysEqual(auth, sshHostCASigner.PublicKey()) + }, + Clock: clock.Now, + } + clientConfig := &ssh.ClientConfig{ + User: tc.sshUser, + Auth: []ssh.AuthMethod{ssh.PublicKeys(tc.sshUserSigner)}, + HostKeyCallback: certChecker.CheckHostKey, + } + sshConn, _, _, err := ssh.NewClientConn(conn, fmt.Sprintf("%s:%d", tc.dialAddr, tc.dialPort), clientConfig) + if tc.expectSSHHandshakeToFail { + require.Error(t, err, "expected SSH handshake to fail") + return + } + require.NoError(t, err) + defer sshConn.Close() }) } + + // Test that a fresh SSH host cert is used on each connection. + t.Run("ephemeral certs", func(t *testing.T) { + // Set up the SSH client config to capture the host certs it sees. + var checkedHostCerts []*ssh.Certificate + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.PublicKeys(sshUserSigner)}, + HostKeyCallback: func(addr string, remote net.Addr, key ssh.PublicKey) error { + checkedHostCerts = append(checkedHostCerts, key.(*ssh.Certificate)) + return nil + }, + } + const connections = 3 + for range connections { + conn, err := p.dialHost(ctx, "node.root1.example.com", 22) + require.NoError(t, err) + sshConn, _, _, err := ssh.NewClientConn(conn, "node.root1.example.com:22", clientConfig) + require.NoError(t, err) + sshConn.Close() + } + require.Len(t, checkedHostCerts, connections) + for i := range connections - 1 { + require.NotEqual(t, checkedHostCerts[i], checkedHostCerts[i+1]) + } + }) } func randomULAAddress() (tcpip.Address, error) {