diff --git a/tool/tsh/db.go b/tool/tsh/db.go index f7918b7bfe410..d716ec3fd8f3c 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -806,7 +806,7 @@ func onDatabaseConnect(cf *CLIConf) error { // is active in profile and no labels or predicate query are given. // Otherwise, the ListDatabases endpoint is called. func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != "" if !haveSelectors { // if selectors are given, we might incur an extra ListDatabases API // call here to match against an active database. diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 4f37a0442213b..3f1168d82b694 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -58,7 +59,7 @@ func TestTshDB(t *testing.T) { t.Run("Login", testDatabaseLogin) t.Run("List", testListDatabase) t.Run("FilterActiveDatabases", testFilterActiveDatabases) - t.Run("GetDatabase", testGetDatabase) + t.Run("DatabaseInfo", testDatabaseInfo) } // testDatabaseLogin tests "tsh db login" command and verifies "tsh db @@ -931,7 +932,7 @@ func testFilterActiveDatabases(t *testing.T) { } } -func testGetDatabase(t *testing.T) { +func testDatabaseInfo(t *testing.T) { t.Parallel() alice, err := types.NewUser("alice@example.com") require.NoError(t, err) @@ -948,6 +949,13 @@ func testGetDatabase(t *testing.T) { StaticLabels: map[string]string{ "env": "local", }, + }, { + Name: "postgres-2", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + "env": "local", + }, }, { Name: "mysql", Protocol: defaults.ProtocolMySQL, @@ -996,71 +1004,102 @@ func testGetDatabase(t *testing.T) { tmpHomePath, _ := mustLogin(t, s) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - for _, db := range databases { - require.NotEmpty(t, db.Name) - require.NotEmpty(t, db.Protocol) - route := tlsca.RouteToDatabase{ - ServiceName: db.Name, - Protocol: db.Protocol, - Username: defaultDBUser, - Database: defaultDBName, - } - t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with active db cert", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, &route) - require.NoError(t, err) - require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - if route.Protocol == defaults.ProtocolDynamoDB { - // v13 specific. We remove the dynamodb schema name from the route since it's not supported. - require.Equal(t, route.ServiceName, dbInfo.ServiceName) - require.Equal(t, route.Protocol, dbInfo.Protocol) - require.Equal(t, route.Username, dbInfo.Username) - } else { - require.Equal(t, route, dbInfo.RouteToDatabase) - } - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") - }) - t.Run("without active db cert", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - DatabaseService: route.ServiceName, - DatabaseUser: route.Username, - DatabaseName: route.Database, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, nil) - require.NoError(t, err) - require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - if route.Protocol == defaults.ProtocolDynamoDB { - // v13 specific. We remove the dynamodb schema name from the route since it's not supported. - require.Equal(t, route.ServiceName, dbInfo.ServiceName) - require.Equal(t, route.Protocol, dbInfo.Protocol) - require.Equal(t, route.Username, dbInfo.Username) - } else { - require.Equal(t, route, dbInfo.RouteToDatabase) - } - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "cached database should be the same") + t.Run("newDatabaseInfo", func(t *testing.T) { + for _, db := range databases { + require.NotEmpty(t, db.Name) + require.NotEmpty(t, db.Protocol) + route := tlsca.RouteToDatabase{ + ServiceName: db.Name, + Protocol: db.Protocol, + Username: defaultDBUser, + Database: defaultDBName, + } + t.Run(route.ServiceName, func(t *testing.T) { + t.Run("with active db cert", func(t *testing.T) { + cf := &CLIConf{ + Context: ctx, + TracingProvider: tracing.NoopProvider(), + HomePath: tmpHomePath, + tracer: tracing.NoopTracer(teleport.ComponentTSH), + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := newDatabaseInfo(cf, tc, &route) + require.NoError(t, err) + require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") + db, err := dbInfo.GetDatabase(cf, tc) + require.NoError(t, err) + if route.Protocol == defaults.ProtocolDynamoDB { + // v13 specific. We remove the dynamodb schema name from the route since it's not supported. + require.Equal(t, route.ServiceName, dbInfo.ServiceName) + require.Equal(t, route.Protocol, dbInfo.Protocol) + require.Equal(t, route.Username, dbInfo.Username) + } else { + require.Equal(t, route, dbInfo.RouteToDatabase) + } + require.Equal(t, route.ServiceName, db.GetName()) + require.Equal(t, route.Protocol, db.GetProtocol()) + require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") + }) + t.Run("without active db cert", func(t *testing.T) { + cf := &CLIConf{ + Context: ctx, + TracingProvider: tracing.NoopProvider(), + HomePath: tmpHomePath, + tracer: tracing.NoopTracer(teleport.ComponentTSH), + DatabaseService: route.ServiceName, + DatabaseUser: route.Username, + DatabaseName: route.Database, + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := newDatabaseInfo(cf, tc, nil) + require.NoError(t, err) + require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") + db, err := dbInfo.GetDatabase(cf, tc) + require.NoError(t, err) + if route.Protocol == defaults.ProtocolDynamoDB { + // v13 specific. We remove the dynamodb schema name from the route since it's not supported. + require.Equal(t, route.ServiceName, dbInfo.ServiceName) + require.Equal(t, route.Protocol, dbInfo.Protocol) + require.Equal(t, route.Username, dbInfo.Username) + } else { + require.Equal(t, route, dbInfo.RouteToDatabase) + } + require.Equal(t, route.ServiceName, db.GetName()) + require.Equal(t, route.Protocol, db.GetProtocol()) + require.Equal(t, dbInfo.database, db, "cached database should be the same") + }) }) - }) - } + } + }) + t.Run("getDatabaseInfo", func(t *testing.T) { + // login to "postgres-2" db. + err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath)) + require.NoError(t, err) + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + // select the other db, "postgres", which was not logged into. + DatabaseService: "postgres", + // v13 specific: set the db name/username because it won't be + // set by default until v14+. + DatabaseUser: defaultDBUser, + DatabaseName: defaultDBName, + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := getDatabaseInfo(cf, tc) + require.NoError(t, err) + require.NotNil(t, dbInfo) + // verify that the active login route for "postgres-2" was not used + // instead of fetching info for the "postgres" db. + require.Equal(t, "postgres", dbInfo.ServiceName) + require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol) + require.Equal(t, defaultDBUser, dbInfo.Username) + require.Equal(t, defaultDBName, dbInfo.Database) + require.NotNil(t, dbInfo.database) + }) } func TestResourceSelectorsFormatting(t *testing.T) {