diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 69c4e9171f680..4f95c84e6fe08 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -292,7 +292,12 @@ func (l *LocalProxy) CheckDBCerts(dbRoute tlsca.RouteToDatabase) error { return trace.Wrap(err) } - // Check the subject matches. + return trace.Wrap(CheckCertSubject(cert, dbRoute)) +} + +// CheckCertSubject checks if the route to the database from the cert matches the provided route in +// terms of username and database (if present). +func CheckCertSubject(cert *x509.Certificate, dbRoute tlsca.RouteToDatabase) error { identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) if err != nil { return trace.Wrap(err) @@ -305,6 +310,7 @@ func (l *LocalProxy) CheckDBCerts(dbRoute tlsca.RouteToDatabase) error { return trace.Errorf("certificate subject is for database name %s, but need %s", identity.RouteToDatabase.Database, dbRoute.Database) } + return nil } diff --git a/lib/teleterm/api/uri/uri.go b/lib/teleterm/api/uri/uri.go index 47c2f6fc76e76..a0107606f1e51 100644 --- a/lib/teleterm/api/uri/uri.go +++ b/lib/teleterm/api/uri/uri.go @@ -25,6 +25,8 @@ import ( var pathClusters = urlpath.New("/clusters/:cluster/*") var pathLeafClusters = urlpath.New("/clusters/:cluster/leaves/:leaf/*") +var pathDbs = urlpath.New("/clusters/:cluster/dbs/:dbName") +var pathLeafDbs = urlpath.New("/clusters/:cluster/leaves/:leaf/dbs/:dbName") // New creates an instance of ResourceURI func New(path string) ResourceURI { @@ -92,6 +94,21 @@ func (r ResourceURI) GetLeafClusterName() string { return result.Params["leaf"] } +// GetDbName extracts the database name from r. Returns an empty string if path is not a database URI. +func (r ResourceURI) GetDbName() string { + result, ok := pathDbs.Match(r.path) + if ok { + return result.Params["dbName"] + } + + result, ok = pathLeafDbs.Match(r.path) + if ok { + return result.Params["dbName"] + } + + return "" +} + // AppendServer appends server segment to the URI func (r ResourceURI) AppendServer(id string) ResourceURI { r.path = fmt.Sprintf("%v/servers/%v", r.path, id) diff --git a/lib/teleterm/api/uri/uri_test.go b/lib/teleterm/api/uri/uri_test.go index 54a6756d39e40..04273a1c68c9f 100644 --- a/lib/teleterm/api/uri/uri_test.go +++ b/lib/teleterm/api/uri/uri_test.go @@ -26,7 +26,6 @@ import ( ) func TestString(t *testing.T) { - t.Parallel() tests := []struct { in uri.ResourceURI out string @@ -46,10 +45,7 @@ func TestString(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { - t.Parallel() - out := tt.in.String() require.Equal(t, tt.out, out) }) @@ -57,7 +53,6 @@ func TestString(t *testing.T) { } func TestParseClusterURI(t *testing.T) { - t.Parallel() tests := []struct { in string out uri.ResourceURI @@ -81,13 +76,56 @@ func TestParseClusterURI(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.in, func(t *testing.T) { - t.Parallel() - out, err := uri.ParseClusterURI(tt.in) require.NoError(t, err) require.Equal(t, tt.out, out) }) } } + +func TestGetDbName(t *testing.T) { + tests := []struct { + name string + in uri.ResourceURI + out string + }{ + { + name: "returns root cluster db name", + in: uri.NewClusterURI("foo").AppendDB("postgres"), + out: "postgres", + }, + { + name: "returns leaf cluster db name", + in: uri.NewClusterURI("foo").AppendLeafCluster("bar").AppendDB("postgres"), + out: "postgres", + }, + { + name: "returns empty string when given root cluster URI", + in: uri.NewClusterURI("foo"), + out: "", + }, + { + name: "returns empty string when given leaf cluster URI", + in: uri.NewClusterURI("foo").AppendLeafCluster("bar"), + out: "", + }, + { + name: "returns empty string when given root cluster non-db resource URI", + in: uri.NewClusterURI("foo").AppendKube("k8s"), + out: "", + }, + { + name: "returns empty string when given leaf cluster non-db resource URI", + in: uri.NewClusterURI("foo").AppendLeafCluster("bar").AppendKube("k8s"), + out: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := tt.in.GetDbName() + require.Equal(t, tt.out, out) + }) + } +} diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index b90ba48c62785..b946c91c48def 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -154,7 +154,7 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) return response, nil } -// ReissueDBCerts issues new certificates for specific DB access +// ReissueDBCerts issues new certificates for specific DB access and saves them to disk. func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database diff --git a/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go b/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go index 22a696a172748..81f9f6dce7fca 100644 --- a/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go +++ b/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go @@ -27,6 +27,8 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/gateway" + "github.com/gravitational/teleport/lib/teleterm/gatewaytest" + "github.com/gravitational/teleport/lib/tlsca" ) type fakeExec struct{} @@ -89,17 +91,29 @@ func TestDbcmdCLICommandProviderGetCommand(t *testing.T) { clusters: []*Cluster{&cluster}, } dbcmdCLICommandProvider := NewDbcmdCLICommandProvider(fakeStorage, fakeExec{}) + + keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + Username: "alice", + Groups: []string{"test-group"}, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: "foo", + Protocol: defaults.ProtocolPostgres, + Username: "alice", + }, + }) + gateway, err := gateway.New( gateway.Config{ TargetURI: cluster.URI.AppendDB("foo").String(), TargetName: "foo", + TargetUser: "alice", TargetSubresourceName: tc.targetSubresourceName, Protocol: defaults.ProtocolPostgres, LocalAddress: "localhost", WebProxyAddr: "localhost:1337", Insecure: true, - CertPath: "../../../fixtures/certs/proxy1.pem", - KeyPath: "../../../fixtures/certs/proxy1-key.pem", + CertPath: keyPairPaths.CertPath, + KeyPath: keyPairPaths.KeyPath, CLICommandProvider: dbcmdCLICommandProvider, TCPPortAllocator: gateway.NetTCPPortAllocator{}, }, @@ -122,6 +136,17 @@ func TestDbcmdCLICommandProviderGetCommand_ReturnsErrorIfClusterIsNotFound(t *te clusters: []*Cluster{}, } dbcmdCLICommandProvider := NewDbcmdCLICommandProvider(fakeStorage, fakeExec{}) + + keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + Username: "alice", + Groups: []string{"test-group"}, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: "foo", + Protocol: defaults.ProtocolPostgres, + Username: "alice", + }, + }) + gateway, err := gateway.New( gateway.Config{ TargetURI: uri.NewClusterURI("quux").AppendDB("foo").String(), @@ -132,8 +157,8 @@ func TestDbcmdCLICommandProviderGetCommand_ReturnsErrorIfClusterIsNotFound(t *te LocalAddress: "localhost", WebProxyAddr: "localhost:1337", Insecure: true, - CertPath: "../../../fixtures/certs/proxy1.pem", - KeyPath: "../../../fixtures/certs/proxy1-key.pem", + CertPath: keyPairPaths.CertPath, + KeyPath: keyPairPaths.KeyPath, CLICommandProvider: dbcmdCLICommandProvider, TCPPortAllocator: gateway.NetTCPPortAllocator{}, }, diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index ea9d60bb8bf29..111c21d9a58ab 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/gateway" "github.com/gravitational/teleport/lib/teleterm/gatewaytest" + "github.com/gravitational/teleport/lib/tlsca" ) type mockGatewayCreator struct { @@ -46,6 +47,18 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. hs.Close() }) + resourceURI := uri.New(params.TargetURI) + + keyPairPaths := gatewaytest.MustGenAndSaveCert(m.t, tlsca.Identity{ + Username: params.TargetUser, + Groups: []string{"test-group"}, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: resourceURI.GetDbName(), + Protocol: defaults.ProtocolPostgres, + Username: params.TargetUser, + }, + }) + gateway, err := gateway.New(gateway.Config{ LocalPort: params.LocalPort, TargetURI: params.TargetURI, @@ -53,8 +66,8 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. TargetName: params.TargetURI, TargetSubresourceName: params.TargetSubresourceName, Protocol: defaults.ProtocolPostgres, - CertPath: "../../../fixtures/certs/proxy1.pem", - KeyPath: "../../../fixtures/certs/proxy1-key.pem", + CertPath: keyPairPaths.CertPath, + KeyPath: keyPairPaths.KeyPath, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: params.CLICommandProvider, diff --git a/lib/teleterm/gateway/gateway.go b/lib/teleterm/gateway/gateway.go index 6b9832c9935c8..3ebe263ccefa0 100644 --- a/lib/teleterm/gateway/gateway.go +++ b/lib/teleterm/gateway/gateway.go @@ -30,6 +30,7 @@ import ( alpn "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -82,6 +83,11 @@ func New(cfg Config) (*Gateway, error) { return nil, trace.Wrap(err) } + if err := checkCertSubject(tlsCert, cfg.RouteToDatabase()); err != nil { + return nil, trace.Wrap(err, + "database certificate check failed, try restarting the database connection") + } + localProxyConfig := alpn.LocalProxyConfig{ InsecureSkipVerify: cfg.Insecure, RemoteProxyAddr: cfg.WebProxyAddr, @@ -226,6 +232,14 @@ func (g *Gateway) CLICommand() (string, error) { return cliCommand, nil } +// RouteToDatabase returns tlsca.RouteToDatabase based on the config of the gateway. +// +// The tlsca.RouteToDatabase.Database field is skipped, as it's an optional field and gateways can +// change their Config.TargetSubresourceName at any moment. +func (g *Gateway) RouteToDatabase() tlsca.RouteToDatabase { + return g.cfg.RouteToDatabase() +} + // ReloadCert loads the key pair from cfg.CertPath & cfg.KeyPath and updates the cert of the running // local proxy. This is typically done after the cert is reissued and saved to disk. // @@ -239,11 +253,32 @@ func (g *Gateway) ReloadCert() error { return trace.Wrap(err) } + if err := checkCertSubject(tlsCert, g.RouteToDatabase()); err != nil { + return trace.Wrap(err, + "database certificate check failed, try restarting the database connection") + } + g.localProxy.SetCerts([]tls.Certificate{tlsCert}) return nil } +// checkCertSubject checks if the cert subject matches the expected db route. +// +// Database certs are scoped per database server but not per database user or database name. +// It might happen that after we save the cert but before we load it, another process obtains a +// cert for another db user. +// +// Before using the cert for the proxy, we have to perform this check. +func checkCertSubject(tlsCert tls.Certificate, dbRoute tlsca.RouteToDatabase) error { + cert, err := utils.TLSCertToX509(tlsCert) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(alpn.CheckCertSubject(cert, dbRoute)) +} + // Gateway describes local proxy that creates a gateway to the remote Teleport resource. // // Gateway is not safe for concurrent use in itself. However, all access to gateways is gated by diff --git a/lib/teleterm/gateway/gateway_test.go b/lib/teleterm/gateway/gateway_test.go index e077ca5e890b6..b412570067997 100644 --- a/lib/teleterm/gateway/gateway_test.go +++ b/lib/teleterm/gateway/gateway_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/gatewaytest" + "github.com/gravitational/teleport/lib/tlsca" ) func TestCLICommandUsesCLICommandProvider(t *testing.T) { @@ -52,14 +53,24 @@ func TestGatewayStart(t *testing.T) { hs.Close() }) + keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + Username: "alice", + Groups: []string{"test-group"}, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: "foo", + Protocol: defaults.ProtocolPostgres, + Username: "alice", + }, + }) + gateway, err := New( Config{ TargetName: "foo", TargetURI: uri.NewClusterURI("bar").AppendDB("foo").String(), TargetUser: "alice", Protocol: defaults.ProtocolPostgres, - CertPath: "../../../fixtures/certs/proxy1.pem", - KeyPath: "../../../fixtures/certs/proxy1-key.pem", + CertPath: keyPairPaths.CertPath, + KeyPath: keyPairPaths.KeyPath, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: mockCLICommandProvider{}, @@ -148,14 +159,24 @@ func createGateway(t *testing.T, tcpPortAllocator TCPPortAllocator) *Gateway { hs.Close() }) + keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + Username: "alice", + Groups: []string{"test-group"}, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: "foo", + Protocol: defaults.ProtocolPostgres, + Username: "alice", + }, + }) + gateway, err := New( Config{ TargetName: "foo", TargetURI: uri.NewClusterURI("bar").AppendDB("foo").String(), TargetUser: "alice", Protocol: defaults.ProtocolPostgres, - CertPath: "../../../fixtures/certs/proxy1.pem", - KeyPath: "../../../fixtures/certs/proxy1-key.pem", + CertPath: keyPairPaths.CertPath, + KeyPath: keyPairPaths.KeyPath, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: mockCLICommandProvider{}, diff --git a/lib/teleterm/gatewaytest/helpers.go b/lib/teleterm/gatewaytest/helpers.go index 4f2620123383e..e8e7be6d0a17e 100644 --- a/lib/teleterm/gatewaytest/helpers.go +++ b/lib/teleterm/gatewaytest/helpers.go @@ -15,14 +15,25 @@ package gatewaytest import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" "net" + "os" "testing" "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/tlsca" ) const timeout = time.Second * 5 @@ -128,3 +139,83 @@ func (m *MockListener) Addr() net.Addr { func (m *MockListener) RealAddr() net.Addr { return m.realListener.Addr() } + +type KeyPairPaths struct { + CertPath string + KeyPath string +} + +func MustGenAndSaveCert(t *testing.T, identity tlsca.Identity) KeyPairPaths { + t.Helper() + + dir := t.TempDir() + + ca := mustGenCACert(t) + + tlsCert := mustGenCertSignedWithCA(t, ca, identity) + + privateKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey) + require.True(t, ok, "Failed to cast tlsCert.PrivateKey") + + // Save the cert. + + certFile, err := os.CreateTemp(dir, "cert") + require.NoError(t, err) + + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tlsCert.Certificate[0]}) + + _, err = certFile.Write(pemCert) + require.NoError(t, err) + require.NoError(t, certFile.Close()) + + // Save the private key. + + keyFile, err := os.CreateTemp(dir, "key") + require.NoError(t, err) + + pemPrivateKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + _, err = keyFile.Write(pemPrivateKey) + require.NoError(t, err) + require.NoError(t, keyFile.Close()) + + return KeyPairPaths{ + CertPath: certFile.Name(), + KeyPath: keyFile.Name(), + } +} + +func mustGenCACert(t *testing.T) *tlsca.CertAuthority { + caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{ + CommonName: "localhost", + }, []string{"localhost"}, defaults.CATTL) + require.NoError(t, err) + + ca, err := tlsca.FromKeys(caCert, caKey) + require.NoError(t, err) + return ca +} + +func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, identity tlsca.Identity) tls.Certificate { + clock := clockwork.NewRealClock() + subj, err := identity.Subject() + require.NoError(t, err) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + tlsCert, err := ca.GenerateCertificate(tlsca.CertificateRequest{ + Clock: clock, + PublicKey: privateKey.Public(), + Subject: subj, + NotAfter: clock.Now().UTC().Add(time.Minute), + DNSNames: []string{"localhost", "*.localhost"}, + }) + require.NoError(t, err) + + keyRaw := x509.MarshalPKCS1PrivateKey(privateKey) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyRaw}) + cert, err := tls.X509KeyPair(tlsCert, keyPEM) + require.NoError(t, err) + return cert +}