diff --git a/lib/client/db/dbcmd/dbcmd.go b/lib/client/db/dbcmd/dbcmd.go index e72279a1eb298..ab592dac338de 100644 --- a/lib/client/db/dbcmd/dbcmd.go +++ b/lib/client/db/dbcmd/dbcmd.go @@ -120,17 +120,23 @@ type CLICommandBuilder struct { // --cluster flag. Therefore profile.Cluster is not suitable for // determining the target cluster or the root cluster. Use tc.SiteName for // the target cluster and rootCluster for root cluster. - profile *client.ProfileStatus - db *tlsca.RouteToDatabase - host string - port int - options connectionCommandOpts - uid utils.UID + profile *client.ProfileStatus + db *tlsca.RouteToDatabase + host string + port int + options connectionCommandOpts + uid utils.UID + getDatabaseFunc GetDatabaseFunc } -func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, - db tlsca.RouteToDatabase, rootClusterName string, opts ...ConnectCommandFunc, -) *CLICommandBuilder { +func NewCmdBuilder( + tc *client.TeleportClient, + profile *client.ProfileStatus, + db tlsca.RouteToDatabase, + rootClusterName string, + getDatabaseFunc GetDatabaseFunc, + opts ...ConnectCommandFunc, +) (*CLICommandBuilder, error) { var options connectionCommandOpts for _, opt := range opts { opt(&options) @@ -151,16 +157,21 @@ func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, options.exe = &SystemExecer{} } - return &CLICommandBuilder{ - tc: tc, - profile: profile, - db: &db, - host: host, - port: port, - options: options, - rootCluster: rootClusterName, - uid: utils.NewRealUID(), + if getDatabaseFunc == nil { + return nil, trace.BadParameter("GetDatabaseFunc is required and cannot be nil") } + + return &CLICommandBuilder{ + tc: tc, + profile: profile, + db: &db, + host: host, + port: port, + options: options, + rootCluster: rootClusterName, + getDatabaseFunc: getDatabaseFunc, + uid: utils.NewRealUID(), + }, nil } // GetConnectCommand returns a command that can connect the user directly to the given database @@ -219,7 +230,7 @@ func (c *CLICommandBuilder) GetConnectCommand(ctx context.Context) (*exec.Cmd, e return c.getClickhouseNativeCommand() case defaults.ProtocolSpanner: - return c.getSpannerCommand() + return c.getSpannerCommand(ctx) } return nil, trace.BadParameter("unsupported database protocol: %v", c.db) @@ -546,12 +557,7 @@ func (c *CLICommandBuilder) getMongoCommand(ctx context.Context) (*exec.Cmd, err } func (c *CLICommandBuilder) getDatabase(ctx context.Context) (types.Database, error) { - // Technically, we can just use tc to get the database. But caller may have - // extra logic so rely on the callback for now. - if c.options.getDatabase == nil { - return nil, trace.NotFound("missing GetDatabaseFunc") - } - db, err := c.options.getDatabase(ctx, c.tc, c.db.ServiceName) + db, err := c.getDatabaseFunc(ctx, c.tc, c.db.ServiceName) return db, trace.Wrap(err) } @@ -749,25 +755,34 @@ func (c *CLICommandBuilder) getDynamoDBCommand() (*exec.Cmd, error) { return exec.Command(awsBin, args...), nil } -func (c *CLICommandBuilder) getSpannerCommand() (*exec.Cmd, error) { +func (c *CLICommandBuilder) getSpannerCommand(ctx context.Context) (*exec.Cmd, error) { if err := c.checkLocalProxyTunnelOnly(false); err != nil { return nil, trace.Wrap(err) } + var ( + gcp types.GCPCloudSQL project, instance, database string ) - if c.options.printFormat { - // default placeholders for a print command if not all info is available + + db, err := c.getDatabase(ctx) + switch { + case err != nil && c.options.printFormat: + // OK to continue in this case, we'll print the placeholders instead. project, instance, database = "", "", "" + case err != nil: + return nil, trace.Wrap(err) + default: + gcp = db.GetGCP() } - if c.options.gcp.ProjectID != "" { - project = c.options.gcp.ProjectID + if gcp.ProjectID != "" { + project = gcp.ProjectID } - if c.options.gcp.InstanceID != "" { - instance = c.options.gcp.InstanceID + if gcp.InstanceID != "" { + instance = gcp.InstanceID } if c.db.Database != "" { database = c.db.Database @@ -947,9 +962,7 @@ type connectionCommandOpts struct { log *logrus.Entry exe Execer password string - gcp types.GCPCloudSQL oracle oracleOpts - getDatabase GetDatabaseFunc } // ConnectCommandFunc is a type for functions returned by the "With*" functions in this package. @@ -1052,24 +1065,9 @@ func WithOracleOpts(canUseTCP bool, hasTCPServers bool) ConnectCommandFunc { } } -// WithGCP adds GCP metadata for the database command to access. -// TODO(greedy52) use GetDatabaseFunc instead. -func WithGCP(gcp types.GCPCloudSQL) ConnectCommandFunc { - return func(opts *connectionCommandOpts) { - opts.gcp = gcp - } -} - // GetDatabaseFunc is a callback to retrieve types.Database. type GetDatabaseFunc func(context.Context, *client.TeleportClient, string) (types.Database, error) -// WithGetDatabaseFunc provides a callback to retrieve types.Database. -func WithGetDatabaseFunc(f GetDatabaseFunc) ConnectCommandFunc { - return func(opts *connectionCommandOpts) { - opts.getDatabase = f - } -} - const ( // envVarMongoServerSelectionTimeoutMS is the environment variable that // controls the server selection timeout used for MongoDB clients. diff --git a/lib/client/db/dbcmd/dbcmd_test.go b/lib/client/db/dbcmd/dbcmd_test.go index 13eb5dd54ef7a..dc3bf109b073d 100644 --- a/lib/client/db/dbcmd/dbcmd_test.go +++ b/lib/client/db/dbcmd/dbcmd_test.go @@ -79,13 +79,14 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { } tests := []struct { - name string - opts []ConnectCommandFunc - dbProtocol string - databaseName string - execer *fakeExec - cmd []string - wantErr bool + name string + opts []ConnectCommandFunc + dbProtocol string + databaseName string + execer *fakeExec + cmd []string + wantErr bool + getDatabaseFunc GetDatabaseFunc }{ { name: "postgres", @@ -327,10 +328,10 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "mongodb (legacy)", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", - opts: []ConnectCommandFunc{withMongoDBAtlasDatabase()}, + name: "mongodb (legacy)", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), execer: &fakeExec{ execOutput: map[string][]byte{ "mongo": []byte("legacy"), @@ -344,10 +345,11 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "mongodb no TLS (legacy)", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", - opts: []ConnectCommandFunc{WithNoTLS(), withMongoDBAtlasDatabase()}, + name: "mongodb no TLS (legacy)", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), + opts: []ConnectCommandFunc{WithNoTLS()}, execer: &fakeExec{ execOutput: map[string][]byte{ "mongo": []byte("legacy"), @@ -359,10 +361,10 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "mongosh no CA", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", - opts: []ConnectCommandFunc{withMongoDBAtlasDatabase()}, + name: "mongosh no CA", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), execer: &fakeExec{ execOutput: map[string][]byte{ "mongosh": []byte("1.1.6"), @@ -376,12 +378,12 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { }, }, { - name: "mongosh", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", + name: "mongosh", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), opts: []ConnectCommandFunc{ WithLocalProxy("localhost", 12345, "/tmp/keys/example.com/cas/example.com.pem"), - withMongoDBAtlasDatabase(), }, execer: &fakeExec{ execOutput: map[string][]byte{ @@ -396,10 +398,11 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { }, }, { - name: "mongosh no TLS", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", - opts: []ConnectCommandFunc{WithNoTLS(), withMongoDBAtlasDatabase()}, + name: "mongosh no TLS", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), + opts: []ConnectCommandFunc{WithNoTLS()}, execer: &fakeExec{ execOutput: map[string][]byte{ "mongosh": []byte("1.1.6"), @@ -410,10 +413,11 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { }, }, { - name: "mongosh preferred", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "mydb", - opts: []ConnectCommandFunc{WithNoTLS(), withMongoDBAtlasDatabase()}, + name: "mongosh preferred", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + getDatabaseFunc: withMongoDBAtlasDatabase(), + opts: []ConnectCommandFunc{WithNoTLS()}, execer: &fakeExec{ execOutput: map[string][]byte{}, // Cannot find either bin. }, @@ -422,10 +426,11 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { }, }, { - name: "DocumentDB", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "docdb", - opts: []ConnectCommandFunc{WithNoTLS(), withDocumentDBDatabase()}, + name: "DocumentDB", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "docdb", + getDatabaseFunc: withDocumentDBDatabase(), + opts: []ConnectCommandFunc{WithNoTLS()}, execer: &fakeExec{ execOutput: map[string][]byte{ // When both are available, legacy mongo is preferred. @@ -439,10 +444,11 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "DocumentDB mongosh", - dbProtocol: defaults.ProtocolMongoDB, - databaseName: "docdb", - opts: []ConnectCommandFunc{WithNoTLS(), withDocumentDBDatabase()}, + name: "DocumentDB mongosh", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "docdb", + getDatabaseFunc: withDocumentDBDatabase(), + opts: []ConnectCommandFunc{WithNoTLS()}, execer: &fakeExec{ execOutput: map[string][]byte{ "mongosh": []byte("1.1.6"), @@ -678,12 +684,12 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "Spanner for exec is ok", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner for exec is ok", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), opts: []ConnectCommandFunc{ WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", @@ -691,13 +697,13 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "Spanner with print format is ok", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner with print format is ok", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), opts: []ConnectCommandFunc{ WithPrintFormat(), WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", @@ -705,13 +711,13 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "Spanner with print format and placeholders is ok", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner with print format and placeholders is ok", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: getDatabaseFuncWithError, // When format is set the command can accept error when fetching the database. opts: []ConnectCommandFunc{ WithPrintFormat(), WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{}), }, execer: &fakeExec{}, databaseName: "", @@ -719,48 +725,48 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { wantErr: false, }, { - name: "Spanner for exec without GCP project is an error", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner for exec without GCP project is an error", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{InstanceID: "bar-instance"}), opts: []ConnectCommandFunc{ WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{InstanceID: "bar-instance"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", wantErr: true, }, { - name: "Spanner for exec without GCP instance is an error", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner for exec without GCP instance is an error", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{ProjectID: "foo-proj"}), opts: []ConnectCommandFunc{ WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{ProjectID: "foo-proj"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", wantErr: true, }, { - name: "Spanner for exec without database name is an error", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner for exec without database name is an error", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{ProjectID: "foo-proj"}), opts: []ConnectCommandFunc{ WithLocalProxy("localhost", 12345, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{ProjectID: "foo-proj"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", wantErr: true, }, { - name: "Spanner without a local proxy is an error", - dbProtocol: defaults.ProtocolSpanner, + name: "Spanner without a local proxy is an error", + dbProtocol: defaults.ProtocolSpanner, + getDatabaseFunc: withSpannerDatabase(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), opts: []ConnectCommandFunc{ WithLocalProxy("", 0, ""), WithNoTLS(), - WithGCP(types.GCPCloudSQL{ProjectID: "foo-proj", InstanceID: "bar-instance"}), }, execer: &fakeExec{}, databaseName: "googlesql-db", @@ -792,7 +798,13 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { WithExecer(tt.execer), }, tt.opts...) - c := NewCmdBuilder(tc, profile, database, "root", opts...) + getDatabaseFunc := tt.getDatabaseFunc + if getDatabaseFunc == nil { + getDatabaseFunc = getDatabaseFuncWithError + } + + c, err := NewCmdBuilder(tc, profile, database, "root", getDatabaseFunc, opts...) + require.NoError(t, err) c.uid = utils.NewFakeUID() got, err := c.GetConnectCommand(context.Background()) if tt.wantErr { @@ -825,13 +837,14 @@ func TestCLICommandBuilderGetConnectCommandAlternatives(t *testing.T) { } tests := []struct { - name string - opts []ConnectCommandFunc - dbProtocol string - databaseName string - execer *fakeExec - cmd map[string][]string - wantErr bool + name string + opts []ConnectCommandFunc + dbProtocol string + databaseName string + execer *fakeExec + cmd map[string][]string + wantErr bool + getDatabaseFunc GetDatabaseFunc }{ { name: "postgres no TLS", @@ -955,7 +968,13 @@ func TestCLICommandBuilderGetConnectCommandAlternatives(t *testing.T) { WithExecer(tt.execer), }, tt.opts...) - c := NewCmdBuilder(tc, profile, database, "root", opts...) + getDatabaseFunc := tt.getDatabaseFunc + if getDatabaseFunc == nil { + getDatabaseFunc = getDatabaseFuncWithError + } + + c, err := NewCmdBuilder(tc, profile, database, "root", getDatabaseFunc, opts...) + require.NoError(t, err) c.uid = utils.NewFakeUID() commandOptions, err := c.GetConnectCommandAlternatives(context.Background()) @@ -995,12 +1014,13 @@ func TestConvertCommandError(t *testing.T) { } tests := []struct { - desc string - dbProtocol string - execer *fakeExec - stderr []byte - wantBin string - wantStdErr string + desc string + dbProtocol string + execer *fakeExec + stderr []byte + wantBin string + wantStdErr string + getDatabaseFunc GetDatabaseFunc }{ { desc: "converts access denied to helpful message", @@ -1041,7 +1061,14 @@ func TestConvertCommandError(t *testing.T) { WithNoTLS(), WithExecer(tt.execer), } - c := NewCmdBuilder(tc, profile, database, "root", opts...) + + getDatabaseFunc := tt.getDatabaseFunc + if getDatabaseFunc == nil { + getDatabaseFunc = getDatabaseFuncWithError + } + + c, err := NewCmdBuilder(tc, profile, database, "root", getDatabaseFunc, opts...) + require.NoError(t, err) c.uid = utils.NewFakeUID() cmd, err := c.GetConnectCommand(context.Background()) @@ -1060,8 +1087,8 @@ func TestConvertCommandError(t *testing.T) { } } -func withMongoDBAtlasDatabase() ConnectCommandFunc { - return WithGetDatabaseFunc(func(context.Context, *client.TeleportClient, string) (types.Database, error) { +func withMongoDBAtlasDatabase() GetDatabaseFunc { + return func(context.Context, *client.TeleportClient, string) (types.Database, error) { db, err := types.NewDatabaseV3( types.Metadata{ Name: "mongodb-atlas", @@ -1072,11 +1099,11 @@ func withMongoDBAtlasDatabase() ConnectCommandFunc { }, ) return db, trace.Wrap(err) - }) + } } -func withDocumentDBDatabase() ConnectCommandFunc { - return WithGetDatabaseFunc(func(context.Context, *client.TeleportClient, string) (types.Database, error) { +func withDocumentDBDatabase() GetDatabaseFunc { + return func(context.Context, *client.TeleportClient, string) (types.Database, error) { db, err := types.NewDatabaseV3( types.Metadata{ Name: "docdb", @@ -1087,5 +1114,28 @@ func withDocumentDBDatabase() ConnectCommandFunc { }, ) return db, trace.Wrap(err) - }) + } +} + +func withSpannerDatabase(gcp types.GCPCloudSQL) GetDatabaseFunc { + return func(context.Context, *client.TeleportClient, string) (types.Database, error) { + db, err := types.NewDatabaseV3( + types.Metadata{ + Name: "docdb", + }, + types.DatabaseSpecV3{ + Protocol: types.DatabaseTypeSpanner, + URI: "spanner.googleapis.com:443", + GCP: gcp, + }, + ) + return db, trace.Wrap(err) + } +} + +// getDatabaseFuncWithError provides a non-nil function that returns error when +// retrieving the database. This can be used in tests that don't retrieve +// databases. +func getDatabaseFuncWithError(ctx context.Context, tc *client.TeleportClient, s string) (types.Database, error) { + return nil, trace.NotImplemented("unexpected call to getDatabase function") } diff --git a/lib/client/db/dbcmd/exec_test.go b/lib/client/db/dbcmd/exec_test.go index b8885f4ea03d2..c5504792522a9 100644 --- a/lib/client/db/dbcmd/exec_test.go +++ b/lib/client/db/dbcmd/exec_test.go @@ -58,11 +58,12 @@ func TestCLICommandBuilderGetExecCommand(t *testing.T) { } tests := []struct { - name string - opts []ConnectCommandFunc - protocol string - cmd []string - wantErr bool + name string + opts []ConnectCommandFunc + protocol string + cmd []string + wantErr bool + getDatabaseFunc GetDatabaseFunc }{ { name: "not authenticated tunnel", @@ -105,7 +106,13 @@ func TestCLICommandBuilderGetExecCommand(t *testing.T) { WithExecer(fakeExec), }, tt.opts...) - c := NewCmdBuilder(tc, profile, database, "root", opts...) + getDatabaseFunc := tt.getDatabaseFunc + if getDatabaseFunc == nil { + getDatabaseFunc = getDatabaseFuncWithError + } + + c, err := NewCmdBuilder(tc, profile, database, "root", getDatabaseFunc, opts...) + require.NoError(t, err) c.uid = utils.NewFakeUID() got, err := c.GetExecCommand(context.Background(), "select 1") if tt.wantErr { diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index 7260598213529..33d0e3f200563 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -141,7 +141,7 @@ type GetDatabasesResponse struct { // NewDBCLICmdBuilder creates a dbcmd.CLICommandBuilder with provided cluster, // db route, and options. -func NewDBCLICmdBuilder(cluster *Cluster, routeToDb tlsca.RouteToDatabase, options ...dbcmd.ConnectCommandFunc) *dbcmd.CLICommandBuilder { +func NewDBCLICmdBuilder(cluster *Cluster, routeToDb tlsca.RouteToDatabase, getDatabaseFunc dbcmd.GetDatabaseFunc, options ...dbcmd.ConnectCommandFunc) (*dbcmd.CLICommandBuilder, error) { return dbcmd.NewCmdBuilder( cluster.clusterClient, &cluster.status, @@ -153,6 +153,7 @@ func NewDBCLICmdBuilder(cluster *Cluster, routeToDb tlsca.RouteToDatabase, optio // generating correct CA paths. We use dbcmd.WithNoTLS here which means that the CA paths aren't // included in the returned CLI command. cluster.Name, + getDatabaseFunc, options..., ) } diff --git a/lib/teleterm/cmd/db.go b/lib/teleterm/cmd/db.go index 144d386224c38..8b0f121a96970 100644 --- a/lib/teleterm/cmd/db.go +++ b/lib/teleterm/cmd/db.go @@ -65,6 +65,13 @@ func newDBCLICommandWithExecer(ctx context.Context, cluster *clusters.Cluster, g var getDatabaseError error var database types.Database + getDatabaseFunc := func(ctx context.Context, _ *client.TeleportClient, _ string) (types.Database, error) { + getDatabaseOnce.Do(func() { + database, getDatabaseError = cluster.GetDatabase(ctx, authClient, gateway.TargetURI()) + }) + return database, trace.Wrap(getDatabaseError) + } + opts := []dbcmd.ConnectCommandFunc{ dbcmd.WithLogger(gateway.Log()), dbcmd.WithLocalProxy(gateway.LocalAddress(), gateway.LocalPortInt(), ""), @@ -72,12 +79,6 @@ func newDBCLICommandWithExecer(ctx context.Context, cluster *clusters.Cluster, g dbcmd.WithTolerateMissingCLIClient(), dbcmd.WithExecer(execer), dbcmd.WithOracleOpts(true /* can use TCP */, true /* has TCP servers */), - dbcmd.WithGetDatabaseFunc(func(ctx context.Context, _ *client.TeleportClient, _ string) (types.Database, error) { - getDatabaseOnce.Do(func() { - database, getDatabaseError = cluster.GetDatabase(ctx, authClient, gateway.TargetURI()) - }) - return database, trace.Wrap(getDatabaseError) - }), } switch gateway.Protocol() { @@ -91,12 +92,20 @@ func newDBCLICommandWithExecer(ctx context.Context, cluster *clusters.Cluster, g previewOpts := append(opts, dbcmd.WithPrintFormat()) - execCmd, err := clusters.NewDBCLICmdBuilder(cluster, routeToDb, opts...).GetConnectCommand(ctx) + execCmdBuilder, err := clusters.NewDBCLICmdBuilder(cluster, routeToDb, getDatabaseFunc, opts...) + if err != nil { + return Cmds{}, trace.Wrap(err) + } + execCmd, err := execCmdBuilder.GetConnectCommand(ctx) if err != nil { return Cmds{}, trace.Wrap(err) } - previewCmd, err := clusters.NewDBCLICmdBuilder(cluster, routeToDb, previewOpts...).GetConnectCommand(ctx) + previewCmdBuilder, err := clusters.NewDBCLICmdBuilder(cluster, routeToDb, getDatabaseFunc, previewOpts...) + if err != nil { + return Cmds{}, trace.Wrap(err) + } + previewCmd, err := previewCmdBuilder.GetConnectCommand(ctx) if err != nil { return Cmds{}, trace.Wrap(err) } diff --git a/lib/teleterm/cmd/db_test.go b/lib/teleterm/cmd/db_test.go index cd165b850cdc4..b239a4f1f3cec 100644 --- a/lib/teleterm/cmd/db_test.go +++ b/lib/teleterm/cmd/db_test.go @@ -72,14 +72,6 @@ func (m fakeDatabaseGateway) LocalPortInt() int { return 8888 } func (m fakeDatabaseGateway) LocalPort() string { return "8888" } func TestNewDBCLICommand(t *testing.T) { - // TODO mock other types - authClient := &mockAuthClient{ - database: &types.DatabaseV3{ - Spec: types.DatabaseSpecV3{ - Protocol: types.DatabaseProtocolMongoDB, - }, - }, - } testCases := []struct { name string @@ -87,30 +79,55 @@ func TestNewDBCLICommand(t *testing.T) { argsCount int protocol string checkCmds func(*testing.T, fakeDatabaseGateway, Cmds) + database *types.DatabaseV3 }{ { name: "empty name", protocol: defaults.ProtocolMongoDB, targetSubresourceName: "", checkCmds: checkMongoCmds, + database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolMongoDB, + }, + }, }, { name: "with name", protocol: defaults.ProtocolMongoDB, targetSubresourceName: "bar", checkCmds: checkMongoCmds, + database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolMongoDB, + }, + }, }, { name: "custom handling of DynamoDB does not blow up", targetSubresourceName: "bar", protocol: defaults.ProtocolDynamoDB, checkCmds: checkArgsNotEmpty, + database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseTypeDynamoDB, + }, + }, }, { name: "custom handling of Spanner does not blow up", targetSubresourceName: "bar", protocol: defaults.ProtocolSpanner, checkCmds: checkArgsNotEmpty, + database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseTypeSpanner, + GCP: types.GCPCloudSQL{ + ProjectID: "proj", + InstanceID: "inst", + }, + }, + }, }, } @@ -126,6 +143,7 @@ func TestNewDBCLICommand(t *testing.T) { protocol: tc.protocol, } + authClient := &mockAuthClient{database: tc.database} cmds, err := newDBCLICommandWithExecer(context.Background(), &cluster, mockGateway, fakeExec{}, authClient) require.NoError(t, err) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index c9512ded10c23..10e8d2506e4f3 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -531,11 +531,14 @@ func onDatabaseConfig(cf *CLIConf) error { format := strings.ToLower(cf.Format) switch format { case dbFormatCommand: - cmd, err := dbcmd.NewCmdBuilder(tc, profile, *database, rootCluster, + cb, err := dbcmd.NewCmdBuilder(tc, profile, *database, rootCluster, getDatabase, dbcmd.WithPrintFormat(), dbcmd.WithLogger(log), - dbcmd.WithGetDatabaseFunc(getDatabase), - ).GetConnectCommand(cf.Context) + ) + if err != nil { + return trace.Wrap(err) + } + cmd, err := cb.GetConnectCommand(cf.Context) if err != nil { return trace.Wrap(err) } @@ -602,7 +605,10 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, requires *dbLocalProxyRequirement, ) ([]dbcmd.ConnectCommandFunc, error) { if !requires.localProxy { - return nil, nil + // Even when local proxy is not required, we need to build the base + // options for database command. + baseOpts, err := makeDatabaseCommandOptions(ctx, tc, dbInfo) + return baseOpts, trace.Wrap(err) } if requires.tunnel { log.Debugf("Starting local proxy tunnel because: %v", strings.Join(requires.tunnelReasons, ", ")) @@ -782,7 +788,10 @@ func onDatabaseConnect(cf *CLIConf) error { return trace.Wrap(err) } - bb := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootClusterName, opts...) + bb, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootClusterName, dbInfo.getDatabaseForDBCmd, opts...) + if err != nil { + return trace.Wrap(err) + } cmd, err := bb.GetConnectCommand(cf.Context) if err != nil { return trace.Wrap(err) @@ -798,7 +807,7 @@ func onDatabaseConnect(cf *CLIConf) error { peakStderr := utils.NewCaptureNBytesWriter(dbcmd.PeakStderrSize) cmd.Stderr = io.MultiWriter(os.Stderr, peakStderr) - err = cmd.Run() + err = cf.RunCommand(cmd) if err != nil { return dbcmd.ConvertCommandError(cmd, err, string(peakStderr.Bytes())) } diff --git a/tool/tsh/common/db_exec.go b/tool/tsh/common/db_exec.go index 63ff1575bde26..843acfb767308 100644 --- a/tool/tsh/common/db_exec.go +++ b/tool/tsh/common/db_exec.go @@ -516,8 +516,11 @@ func (m *databaseExecCommandMaker) makeCommand(ctx context.Context, dbInfo *data if err != nil { return nil, trace.Wrap(err) } - return dbcmd.NewCmdBuilder(m.tc, m.profile, dbInfo.RouteToDatabase, m.rootCluster, opts...). - GetExecCommand(ctx, command) + cb, err := dbcmd.NewCmdBuilder(m.tc, m.profile, dbInfo.RouteToDatabase, m.rootCluster, dbInfo.getDatabaseForDBCmd, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + return cb.GetExecCommand(ctx, command) } // ensureEachDatabase ensures one to one mapping between the provided database diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 8a65cc2f2b495..743000f4d9e66 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -24,13 +24,16 @@ import ( "encoding/pem" "fmt" "os" + "os/exec" "path/filepath" "strings" + "sync/atomic" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" @@ -53,6 +56,7 @@ import ( dbcommon "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + testserver "github.com/gravitational/teleport/tool/teleport/testenv" ) func registerFakeEnterpriseDBEngines(t *testing.T) { @@ -2011,3 +2015,67 @@ func Test_shouldRetryGetDatabaseUsingSearchAsRoles(t *testing.T) { }) } } + +// TestMongoDBSeparatePortCommandError given a MongoDB database with cluster +// using separate port mode ensures `tsh` generates the connect command without +// errors. +// +// See https://github.com/gravitational/teleport/issues/47895 +func TestMongoDBSeparatePortCommandError(t *testing.T) { + t.Parallel() + + connector := mockConnector(t) + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetDatabaseUsers([]string{"admin"}) + alice.SetDatabaseNames([]string{"default"}) + alice.SetRoles([]string{"access"}) + + process, err := testserver.NewTeleportProcess( + t.TempDir(), + testserver.WithClusterName("root"), + testserver.WithBootstrap(connector, alice), + testserver.WithConfig(func(cfg *servicecfg.Config) { + mongoPublicAddr := localListenerAddr() + cfg.Proxy.MongoAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: mongoPublicAddr} + cfg.Proxy.MongoPublicAddrs = []utils.NetAddr{{AddrNetwork: "tcp", Addr: mongoPublicAddr}} + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + { + Name: "mongo", + Protocol: defaults.ProtocolMongoDB, + URI: "external-mongo:27017", + }, + } + }), + ) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, process.Close()) + assert.NoError(t, process.Wait()) + }) + + tshHome, _ := mustLogin(t, process, alice, connector.GetName()) + + // cmdExecuted tracks that the MongoDB command was executed. + var cmdExecuted atomic.Bool + + // noopCmdRunner is a command runner that does nothing. For this test, we + // only need to ensure the command is generated without actually validating + // it. This task should be handled in the dbcmd package. + noopCmdRunner := func(_ *exec.Cmd) error { + cmdExecuted.Store(true) + return nil + } + + err = Run(context.Background(), []string{ + "db", + "connect", + "mongo", + "--insecure", + "--db-name=test", + "--db-user=alice", + }, setHomePath(tshHome), setCmdRunner(noopCmdRunner)) + require.NoError(t, err) + require.True(t, cmdExecuted.Load(), "expected the MongoDB command to have executed") +} diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 3f2c1bc7661d8..6ffcdc72396c8 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -240,9 +240,11 @@ func onProxyCommandDB(cf *CLIConf) error { return trace.Wrap(err) } - commands, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootCluster, - opts..., - ).GetConnectCommandAlternatives(cf.Context) + cb, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootCluster, dbInfo.getDatabaseForDBCmd, opts...) + if err != nil { + return trace.Wrap(err) + } + commands, err := cb.GetConnectCommandAlternatives(cf.Context) if err != nil { return trace.Wrap(err) } @@ -314,30 +316,14 @@ func makeDatabaseCommandOptions(ctx context.Context, tc *libclient.TeleportClien var err error opts := append([]dbcmd.ConnectCommandFunc{ dbcmd.WithLogger(log), - dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), }, extraOpts...) if opts, err = maybeAddDBUserPassword(ctx, tc, dbInfo, opts); err != nil { return nil, trace.Wrap(err) } - if opts, err = maybeAddGCPMetadata(ctx, tc, dbInfo, opts); err != nil { - return nil, trace.Wrap(err) - } return maybeAddOracleOptions(ctx, tc, dbInfo, opts), nil } -func maybeAddGCPMetadata(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { - if !requiresGCPMetadata(dbInfo.Protocol) { - return opts, nil - } - db, err := dbInfo.GetDatabase(ctx, tc) - if err != nil { - return nil, trace.Wrap(err) - } - gcp := db.GetGCP() - return append(opts, dbcmd.WithGCP(gcp)), nil -} - func maybeAddGCPMetadataTplArgs(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, templateArgs map[string]any) { if !requiresGCPMetadata(dbInfo.Protocol) { return