diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index e879ba3b3157b..a6b96d7ea0817 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -863,6 +863,23 @@ func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteT // checkAndSetPrincipalDefaults checks the db route (schema) name and username, // and sets them to defaults if necessary. func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + + // if either user or db name isn't given as a cli flag, try to populate + // user/db name from an active db cert. + if cf.DatabaseUser == "" || cf.DatabaseName == "" { + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + if route, ok := findActiveDatabase(d.ServiceName, routes); ok { + d.Username = route.Username + d.Database = route.Database + } + } if cf.DatabaseUser != "" { d.Username = cf.DatabaseUser } @@ -883,11 +900,6 @@ func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.Tele return nil } - profile, err := tc.ProfileStatus() - if err != nil { - return trace.Wrap(err) - } - var proxy *client.ProxyClient err = client.RetryWithRelogin(cf.Context, tc, func() error { proxy, err = tc.ConnectToProxy(cf.Context) @@ -1340,16 +1352,11 @@ func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activ } prefix := tc.DatabaseService if len(tc.Labels) == 0 && tc.PredicateExpression == "" { - if prefix == "" && len(activeRoutes) == 1 { - return activeRoutes, nil, nil - } // when we have a name but don't have label or predicate query, look for // a route that matches the name exactly to maybe avoid calling // ListDatabases API below. - for _, route := range activeRoutes { - if route.ServiceName == prefix { - return []tlsca.RouteToDatabase{route}, nil, nil - } + if route, ok := findActiveDatabase(prefix, activeRoutes); ok { + return []tlsca.RouteToDatabase{route}, nil, nil } } @@ -1386,6 +1393,20 @@ func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activ return selectedRoutes, activeDBs, nil } +// findActiveDatabase returns a database route and a bool indicating whether +// the route was found. +func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsca.RouteToDatabase, bool) { + if name == "" && len(activeRoutes) == 1 { + return activeRoutes[0], true + } + for _, r := range activeRoutes { + if r.ServiceName == name { + return r, true + } + } + return tlsca.RouteToDatabase{}, false +} + func formatDatabaseListCommand(clusterFlag string) string { if clusterFlag == "" { return "tsh db ls" diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 0657233a11250..85a050aadbd49 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -1100,8 +1100,11 @@ func testDatabaseInfo(t *testing.T) { require.NoError(t, err) defaultDBUser := "admin" defaultDBName := "default" - alice.SetDatabaseUsers([]string{defaultDBUser}) - alice.SetDatabaseNames([]string{defaultDBName}) + // add multiple allowed db names/users, to prevent default selection. + // these tests should use the db name/username from either cli flag or + // active cert only. + alice.SetDatabaseUsers([]string{defaultDBUser, "foo"}) + alice.SetDatabaseNames([]string{defaultDBName, "bar"}) alice.SetRoles([]string{"access"}) databases := []servicecfg.Database{ { @@ -1177,7 +1180,7 @@ func testDatabaseInfo(t *testing.T) { Database: defaultDBName, } t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with active db cert", func(t *testing.T) { + t.Run("with route", func(t *testing.T) { cf := &CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), @@ -1196,15 +1199,18 @@ func testDatabaseInfo(t *testing.T) { require.Equal(t, route.Protocol, db.GetProtocol()) require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") }) - t.Run("without active db cert", func(t *testing.T) { + t.Run("without route", func(t *testing.T) { + err = Run(ctx, []string{"db", "login", route.ServiceName, + "--db-user", route.Username, + "--db-name", route.Database, + }, setHomePath(tmpHomePath)) + require.NoError(t, err) cf := &CLIConf{ Context: ctx, TracingProvider: tracing.NoopProvider(), HomePath: tmpHomePath, tracer: tracing.NoopTracer(teleport.ComponentTSH), DatabaseService: route.ServiceName, - DatabaseUser: route.Username, - DatabaseName: route.Database, } tc, err := makeClient(cf) require.NoError(t, err)