diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 13900befdcb31..a857d6d6816a7 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -852,6 +852,23 @@ func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteT // checkAndSetPrincipalDefaults checks the db route (schema) name and username, // and sets them to defaults if necessary. func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + + // if either user or db name isn't given as a cli flag, try to populate + // user/db name from an active db cert. + if cf.DatabaseUser == "" || cf.DatabaseName == "" { + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + if route, ok := findActiveDatabase(d.ServiceName, routes); ok { + d.Username = route.Username + d.Database = route.Database + } + } if cf.DatabaseUser != "" { d.Username = cf.DatabaseUser } @@ -1228,16 +1245,11 @@ func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activ } prefix := tc.DatabaseService if len(tc.Labels) == 0 && tc.PredicateExpression == "" { - if prefix == "" && len(activeRoutes) == 1 { - return activeRoutes, nil, nil - } // when we have a name but don't have label or predicate query, look for // a route that matches the name exactly to maybe avoid calling // ListDatabases API below. - for _, route := range activeRoutes { - if route.ServiceName == prefix { - return []tlsca.RouteToDatabase{route}, nil, nil - } + if route, ok := findActiveDatabase(prefix, activeRoutes); ok { + return []tlsca.RouteToDatabase{route}, nil, nil } } @@ -1274,6 +1286,20 @@ func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activ return selectedRoutes, activeDBs, nil } +// findActiveDatabase returns a database route and a bool indicating whether +// the route was found. +func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsca.RouteToDatabase, bool) { + if name == "" && len(activeRoutes) == 1 { + return activeRoutes[0], true + } + for _, r := range activeRoutes { + if r.ServiceName == name { + return r, true + } + } + return tlsca.RouteToDatabase{}, false +} + func formatDatabaseListCommand(clusterFlag string) string { if clusterFlag == "" { return "tsh db ls" diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 5090a223a82f4..c6f9109cfad85 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -973,8 +973,11 @@ func testDatabaseInfo(t *testing.T) { require.NoError(t, err) defaultDBUser := "admin" defaultDBName := "default" - alice.SetDatabaseUsers([]string{defaultDBUser}) - alice.SetDatabaseNames([]string{defaultDBName}) + // add multiple allowed db names/users, to prevent default selection. + // these tests should use the db name/username from either cli flag or + // active cert only. + alice.SetDatabaseUsers([]string{defaultDBUser, "foo"}) + alice.SetDatabaseNames([]string{defaultDBName, "bar"}) alice.SetRoles([]string{"access"}) databases := []servicecfg.Database{ { @@ -1050,7 +1053,7 @@ func testDatabaseInfo(t *testing.T) { Database: defaultDBName, } t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with active db cert", func(t *testing.T) { + t.Run("with route", func(t *testing.T) { cf := &CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), @@ -1076,15 +1079,18 @@ func testDatabaseInfo(t *testing.T) { require.Equal(t, route.Protocol, db.GetProtocol()) require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") }) - t.Run("without active db cert", func(t *testing.T) { + t.Run("without route", func(t *testing.T) { + err = Run(ctx, []string{"db", "login", route.ServiceName, + "--db-user", route.Username, + "--db-name", route.Database, + }, setHomePath(tmpHomePath)) + require.NoError(t, err) cf := &CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), HomePath: tmpHomePath, tracer: tracing.NoopTracer(teleport.ComponentTSH), DatabaseService: route.ServiceName, - DatabaseUser: route.Username, - DatabaseName: route.Database, } tc, err := makeClient(cf) require.NoError(t, err)