diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index ebcb96dc8422a..e81f023f06188 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -817,7 +817,7 @@ func onDatabaseConnect(cf *CLIConf) error { // is active in profile and no labels or predicate query are given. // Otherwise, the ListDatabases endpoint is called. func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != "" if !haveSelectors { // if selectors are given, we might incur an extra ListDatabases API // call here to match against an active database. diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 599451b756667..fc8b9f2a5e6da 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -60,7 +60,7 @@ func TestTshDB(t *testing.T) { t.Run("Login", testDatabaseLogin) t.Run("List", testListDatabase) t.Run("FilterActiveDatabases", testFilterActiveDatabases) - t.Run("GetDatabase", testGetDatabase) + t.Run("DatabaseInfo", testDatabaseInfo) } // testDatabaseLogin tests "tsh db login" command and verifies "tsh db @@ -1059,7 +1059,7 @@ func testFilterActiveDatabases(t *testing.T) { } } -func testGetDatabase(t *testing.T) { +func testDatabaseInfo(t *testing.T) { t.Parallel() alice, err := types.NewUser("alice@example.com") require.NoError(t, err) @@ -1076,6 +1076,13 @@ func testGetDatabase(t *testing.T) { StaticLabels: map[string]string{ "env": "local", }, + }, { + Name: "postgres-2", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + "env": "local", + }, }, { Name: "mysql", Protocol: defaults.ProtocolMySQL, @@ -1124,59 +1131,84 @@ func testGetDatabase(t *testing.T) { tmpHomePath, _ := mustLogin(t, s) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - for _, db := range databases { - require.NotEmpty(t, db.Name) - require.NotEmpty(t, db.Protocol) - route := tlsca.RouteToDatabase{ - ServiceName: db.Name, - Protocol: db.Protocol, - Username: defaultDBUser, - Database: defaultDBName, - } - t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with active db cert", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, &route) - require.NoError(t, err) - require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - require.Equal(t, route, dbInfo.RouteToDatabase) - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") - }) - t.Run("without active db cert", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - DatabaseService: route.ServiceName, - DatabaseUser: route.Username, - DatabaseName: route.Database, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, nil) - require.NoError(t, err) - require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - require.Equal(t, route, dbInfo.RouteToDatabase) - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "cached database should be the same") + t.Run("newDatabaseInfo", func(t *testing.T) { + for _, db := range databases { + require.NotEmpty(t, db.Name) + require.NotEmpty(t, db.Protocol) + route := tlsca.RouteToDatabase{ + ServiceName: db.Name, + Protocol: db.Protocol, + Username: defaultDBUser, + Database: defaultDBName, + } + t.Run(route.ServiceName, func(t *testing.T) { + t.Run("with active db cert", func(t *testing.T) { + cf := &CLIConf{ + Context: ctx, + TracingProvider: tracing.NoopProvider(), + HomePath: tmpHomePath, + tracer: tracing.NoopTracer(teleport.ComponentTSH), + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := newDatabaseInfo(cf, tc, &route) + require.NoError(t, err) + require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") + db, err := dbInfo.GetDatabase(cf, tc) + require.NoError(t, err) + require.Equal(t, route, dbInfo.RouteToDatabase) + require.Equal(t, route.ServiceName, db.GetName()) + require.Equal(t, route.Protocol, db.GetProtocol()) + require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") + }) + t.Run("without active db cert", func(t *testing.T) { + cf := &CLIConf{ + Context: ctx, + TracingProvider: tracing.NoopProvider(), + HomePath: tmpHomePath, + tracer: tracing.NoopTracer(teleport.ComponentTSH), + DatabaseService: route.ServiceName, + DatabaseUser: route.Username, + DatabaseName: route.Database, + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := newDatabaseInfo(cf, tc, nil) + require.NoError(t, err) + require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") + db, err := dbInfo.GetDatabase(cf, tc) + require.NoError(t, err) + require.Equal(t, route, dbInfo.RouteToDatabase) + require.Equal(t, route.ServiceName, db.GetName()) + require.Equal(t, route.Protocol, db.GetProtocol()) + require.Equal(t, dbInfo.database, db, "cached database should be the same") + }) }) - }) - } + } + }) + t.Run("getDatabaseInfo", func(t *testing.T) { + // login to "postgres-2" db. + err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath)) + require.NoError(t, err) + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + // select the other db, "postgres", which was not logged into. + DatabaseService: "postgres", + } + tc, err := makeClient(cf) + require.NoError(t, err) + dbInfo, err := getDatabaseInfo(cf, tc) + require.NoError(t, err) + require.NotNil(t, dbInfo) + // verify that the active login route for "postgres-2" was not used + // instead of fetching info for the "postgres" db. + require.Equal(t, "postgres", dbInfo.ServiceName) + require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol) + require.Equal(t, defaultDBUser, dbInfo.Username) + require.Equal(t, defaultDBName, dbInfo.Database) + require.NotNil(t, dbInfo.database) + }) } func TestResourceSelectorsFormatting(t *testing.T) {