diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 11ee198f6999b..9332c5505867d 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -251,12 +251,20 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := newDatabaseInfo(cf, tc, nil) + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } - database, err := dbInfo.GetDatabase(cf, tc) + database, err := dbInfo.GetDatabase(cf.Context, tc) if err != nil { return trace.Wrap(err) } @@ -361,7 +369,7 @@ func onDatabaseLogout(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - logout, _, err := filterActiveDatabases(cf.Context, tc, activeRoutes) + databases, err := getDatabasesForLogout(cf, tc, activeRoutes) if err != nil { return trace.Wrap(err) } @@ -370,12 +378,12 @@ func onDatabaseLogout(cf *CLIConf) error { log.Info("Note: an identity file is in use (`-i ...`); will only update database config files.") } - for _, db := range logout { + for _, db := range databases { if err := databaseLogout(tc, db, profile.IsVirtual); err != nil { return trace.Wrap(err) } } - msg, err := makeLogoutMessage(cf, logout, activeRoutes) + msg, err := makeLogoutMessage(cf, databases, activeRoutes) if err != nil { return trace.Wrap(err) } @@ -387,21 +395,16 @@ func onDatabaseLogout(cf *CLIConf) error { // result of "tsh db logout". func makeLogoutMessage(cf *CLIConf, logout, activeRoutes []tlsca.RouteToDatabase) (string, error) { switch len(logout) { - case 0: - selectors := resourceSelectors{ - kind: "database", - name: cf.DatabaseService, - labels: cf.Labels, - query: cf.PredicateExpression, - } - if selectors.IsEmpty() { - return "", trace.NotFound("Not logged into any databases") - } - return "", trace.NotFound("Not logged into %v", selectors) case 1: return fmt.Sprintf("Logged out of database %v", logout[0].ServiceName), nil case len(activeRoutes): return "Logged out of all databases", nil + case 0: + selectors := newDatabaseResourceSelectors(cf) + if selectors.IsEmpty() { + return "", trace.NotFound("Not logged into any databases") + } + return "", trace.NotFound("Not logged into %s", selectors) default: names := make([]string, 0, len(logout)) for _, route := range logout { @@ -438,7 +441,15 @@ func onDatabaseEnv(cf *CLIConf) error { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf, tc) + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + database, err := pickActiveDatabase(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -496,7 +507,11 @@ func onDatabaseConfig(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + database, err := pickActiveDatabase(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -706,7 +721,7 @@ func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConf opts = append(opts, alpnproxy.WithCheckCertsNeeded()) case defaults.ProtocolMySQL: // To set correct MySQL server version DB proxy needs additional protocol. - db, err := arg.dbInfo.GetDatabase(arg.cf, arg.tc) + db, err := arg.dbInfo.GetDatabase(arg.cf.Context, arg.tc) if err != nil { return nil, trace.Wrap(err) } @@ -725,7 +740,11 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := getDatabaseInfo(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -785,66 +804,44 @@ func onDatabaseConnect(cf *CLIConf) error { // getDatabaseInfo fetches information about the database from tsh profile if DB // is active in profile and no labels or predicate query are given. // Otherwise, the ListDatabases endpoint is called. -func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != "" - if !haveSelectors { - // if selectors are given, we might incur an extra ListDatabases API - // call here to match against an active database. - // So try to pick an active database only when we don't have - // selectors. - if route, err := pickActiveDatabase(cf, tc); err == nil { - return newDatabaseInfo(cf, tc, route) - } else if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - } - return newDatabaseInfo(cf, tc, nil) -} - -// newDatabaseInfo makes a new databaseInfo from the given route to the db. -// It checks the route and sets defaults as needed for protocol, db user, or db -// name. If the route is not given or the remote database is needed for setting -// a default, the database is retrieved by calling ListDatabases API and cached. -func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteToDatabase) (*databaseInfo, error) { - dbInfo := &databaseInfo{} - if route != nil { - dbInfo.RouteToDatabase = *route - // the only way we're going to have all this info populated is from an - // active cert. - if dbInfo.ServiceName != "" && dbInfo.Protocol != "" && - dbInfo.Username != "" && dbInfo.Database != "" { - return dbInfo, nil +func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (*databaseInfo, error) { + if route, err := maybePickActiveDatabase(cf, routes); err == nil && route != nil { + info := &databaseInfo{RouteToDatabase: *route, isActive: true} + return info, info.checkAndSetDefaults(cf, tc) + } else if err != nil { + if trace.IsNotFound(err) { + return nil, trace.BadParameter("please specify a database service by name, --labels, or --query") } + return nil, trace.Wrap(err) } - db, err := dbInfo.GetDatabase(cf, tc) + + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, routes) if err != nil { return nil, trace.Wrap(err) } - // now ensure the route name and protocol matches the db we fetched. - dbInfo.ServiceName = db.GetName() - dbInfo.Protocol = db.GetProtocol() - return dbInfo, dbInfo.checkAndSetPrincipalDefaults(cf, tc, db) -} -// checkAndSetPrincipalDefaults checks the db route (schema) name and username, -// and sets them to defaults if necessary. -func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { - profile, err := tc.ProfileStatus() - if err != nil { - return trace.Wrap(err) + info := &databaseInfo{ + database: db, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + }, } + // check for an active route now that we have the full db name. + if route, ok := findActiveDatabase(db.GetName(), routes); ok { + info.RouteToDatabase = route + info.isActive = true + } + if err := info.checkAndSetDefaults(cf, tc); err != nil { + return nil, trace.Wrap(err) + } + return info, nil +} - // if either user or db name isn't given as a cli flag, try to populate - // user/db name from an active db cert. - if cf.DatabaseUser == "" || cf.DatabaseName == "" { - routes, err := profile.DatabasesForCluster(tc.SiteName) - if err != nil { - return trace.Wrap(err) - } - if route, ok := findActiveDatabase(d.ServiceName, routes); ok { - d.Username = route.Username - d.Database = route.Database - } +// checkAndSetDefaults checks the db route, applies cli flags, and sets defaults. +func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClient) error { + if d.ServiceName == "" { + return trace.BadParameter("missing database service name") } if cf.DatabaseUser != "" { d.Username = cf.DatabaseUser @@ -852,16 +849,31 @@ func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.Tele if cf.DatabaseName != "" { d.Database = cf.DatabaseName } + db, err := d.GetDatabase(cf.Context, tc) + if err != nil { + if d.isActive && trace.IsNotFound(err) && strings.Contains(err.Error(), d.ServiceName) { + hint := formatStaleDBCert(cf.SiteName, d.ServiceName) + return trace.Wrap(err, hint) + } + return trace.Wrap(err) + } + // ensure the route protocol matches the db. + d.Protocol = db.GetProtocol() + + needDBUser := d.Username == "" && role.RequireDatabaseUserMatcher(d.Protocol) + needDBName := d.Database == "" && role.RequireDatabaseNameMatcher(d.Protocol) + if !needDBUser && !needDBName { + return nil + } + // If database has admin user defined, we're most likely using automatic // user provisioning so default to Teleport username unless database // username was provided explicitly. - if d.Username == "" && db.GetAdminUser() != "" { + if needDBUser && db.GetAdminUser() != "" { log.Debugf("Defaulting to Teleport username %q as database username.", tc.Username) d.Username = tc.Username + needDBUser = false } - // recheck to see if we can avoid fetching the roleset to set defaults. - needDBUser := d.Username == "" && role.RequireDatabaseUserMatcher(d.Protocol) - needDBName := d.Database == "" && role.RequireDatabaseNameMatcher(d.Protocol) if !needDBUser && !needDBName { return nil } @@ -876,6 +888,10 @@ func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.Tele } defer proxy.Close() + profile, err := tc.ProfileStatus() + if err != nil { + return trace.Wrap(err) + } checker, err := accessCheckerForRemoteCluster(cf.Context, profile, proxy, tc.SiteName) if err != nil { return trace.Wrap(err) @@ -908,66 +924,56 @@ type databaseInfo struct { // database corresponds to the db route and may be nil, so use GetDatabase // instead of accessing it directly. database types.Database + // isActive indicates an active database matched this db info. + isActive bool mu sync.Mutex } // GetDatabase returns the cached database or fetches it using the db route and // caches the result. -func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { - if d.ServiceName == "" && cf.DatabaseService == "" && - len(tc.Labels) == 0 && tc.PredicateExpression == "" { - return nil, trace.BadParameter("specify a database service by name, --labels, or --query") - } +func (d *databaseInfo) GetDatabase(ctx context.Context, tc *client.TeleportClient) (types.Database, error) { d.mu.Lock() defer d.mu.Unlock() if d.database != nil { return d.database, nil } // holding mutex across the api call to avoid multiple redundant api calls. - var databases types.Databases - var err error - name := d.ServiceName - if name != "" { - databases, err = listDatabasesByName(cf.Context, tc, name) - } else { - name = cf.DatabaseService - // search by prefix if the db name comes from cli flag instead of cert. - databases, err = listDatabasesByPrefix(cf.Context, tc, name) - } + database, err := getDatabase(ctx, tc, d.ServiceName) if err != nil { return nil, trace.Wrap(err) } - db, err := chooseOneDatabase(cf, name, databases) - if err != nil { - return nil, trace.Wrap(err) - } - - d.database = db + d.database = database return d.database, nil } -// chooseOneDatabase is a helper func for GetDatabase that returns either the -// only database in a list of databases or returns a database that matches the -// nameOrPrefix exactly, otherwise an error. -func chooseOneDatabase(cf *CLIConf, nameOrPrefix string, databases types.Databases) (types.Database, error) { - if len(databases) == 1 { - return databases[0], nil - } - // Check if nameOrPrefix matches any database exactly and, if so, choose +// chooseOneDatabase is a helper func that returns either the only database in a +// list of databases or returns a database that matches the selector name +// or unambiguous discovered name exactly, otherwise an error. +func chooseOneDatabase(cf *CLIConf, databases types.Databases) (types.Database, error) { + selectors := newDatabaseResourceSelectors(cf) + // Check if the name matches any database exactly and, if so, choose // that database over any others. for _, db := range databases { - if db.GetName() == nameOrPrefix { + if db.GetName() == selectors.name { + log.Debugf("Selected database %q by exact name match", db.GetName()) return db, nil } } + // look for a single database with a matching discovered name label. + if dbs := findDatabasesByDiscoveredName(databases, selectors.name); len(dbs) > 0 { + names := make([]string, 0, len(dbs)) + for _, db := range dbs { + names = append(names, db.GetName()) + } + log.Debugf("Choosing amongst databases (%v) by discovered name", names) + databases = dbs + } + if len(databases) == 1 { + log.Debugf("Selected database %q", databases[0].GetName()) + return databases[0], nil + } // error - we need exactly one database. - selectors := resourceSelectors{ - kind: "database", - name: nameOrPrefix, - labels: cf.Labels, - query: cf.PredicateExpression, - } if len(databases) == 0 { return nil, trace.NotFound( "%v not found, use '%v' to see registered databases", selectors, @@ -977,52 +983,61 @@ func chooseOneDatabase(cf *CLIConf, nameOrPrefix string, databases types.Databas return nil, trace.BadParameter(errMsg) } -// listActiveDatabases lists databases that match active (logged in) databases. -func listActiveDatabases(ctx context.Context, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (types.Databases, error) { - names := make([]string, 0, len(routes)) - for _, r := range routes { - names = append(names, fmt.Sprintf("(name == %q)", r.ServiceName)) +// findDatabasesByDiscoveredName returns all databases that have a discovered +// name label that matches the given name. +func findDatabasesByDiscoveredName(databases types.Databases, name string) types.Databases { + var out types.Databases + for _, db := range databases { + discoveredName, ok := db.GetLabel(types.DiscoveredNameLabel) + if ok && discoveredName == name { + out = append(out, db) + } } - predicate := strings.Join(names, "||") - return listDatabasesWithPredicate(ctx, tc, predicate) + return out } -// listDatabasesByName lists database that match a given name. -func listDatabasesByName(ctx context.Context, tc *client.TeleportClient, name string) (types.Databases, error) { - predicate := fmt.Sprintf("name == %q", name) - return listDatabasesWithPredicate(ctx, tc, predicate) -} - -// makePrefixPredicate returns a predicate expression that matches resources -// by prefix name. -func makePrefixPredicate(prefix string) string { - if prefix == "" { - return "" +// getDatabase gets a database using its full name. +func getDatabase(ctx context.Context, tc *client.TeleportClient, name string) (types.Database, error) { + matchName := makeNamePredicate(name) + databases, err := listDatabasesWithPredicate(ctx, tc, matchName) + if err != nil { + return nil, trace.Wrap(err) + } + if len(databases) == 0 { + return nil, trace.NotFound("database %q not found among registered databases in cluster %v", name, tc.SiteName) } - return fmt.Sprintf(`hasPrefix(name, %q)`, prefix) + return databases[0], nil } -// listDatabasesByPrefix lists databases that match a given name prefix. -func listDatabasesByPrefix(ctx context.Context, tc *client.TeleportClient, prefix string) (types.Databases, error) { - predicate := makePrefixPredicate(prefix) - databases, err := listDatabasesWithPredicate(ctx, tc, predicate) - if err == nil || !utils.IsPredicateError(err) || predicate == "" { - return databases, trace.Wrap(err) - } - // predicate error from using hasPrefix expression. - // fallback to listing without the hasPrefix predicate and filtering - // on client side for backwards compatibility. - databases, err = listDatabasesWithPredicate(ctx, tc, "") +// getDatabaseByNameOrDiscoveredName fetches a database that unambiguously +// matches a given name or a discovered name label. +func getDatabaseByNameOrDiscoveredName(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) (types.Database, error) { + predicate := makeDiscoveredNameOrNamePredicate(cf.DatabaseService) + databases, err := listDatabasesWithPredicate(cf.Context, tc, predicate) if err != nil { return nil, trace.Wrap(err) } + if activeDBs := filterActiveDatabases(activeRoutes, databases); len(activeDBs) > 0 { + names := make([]string, 0, len(activeDBs)) + for _, db := range activeDBs { + names = append(names, db.GetName()) + } + log.Debugf("Choosing a database amongst active databases (%v)", names) + // preferentially choose from active databases if any of them match. + return chooseOneDatabase(cf, activeDBs) + } + return chooseOneDatabase(cf, databases) +} + +func filterActiveDatabases(routes []tlsca.RouteToDatabase, databases types.Databases) types.Databases { + databasesByName := databases.ToMap() var out types.Databases - for _, db := range databases { - if strings.HasPrefix(db.GetName(), prefix) { + for _, route := range routes { + if db, ok := databasesByName[route.ServiceName]; ok { out = append(out, db) } } - return out, nil + return out } // listDatabasesWithPredicate is a helper func for listing databases using @@ -1032,10 +1047,12 @@ func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, var databases []types.Database err := client.RetryWithRelogin(ctx, tc, func() error { var err error + predicate := makePredicateConjunction(predicate, tc.PredicateExpression) + log.Debugf("Listing databases with predicate (%v) and labels %v", predicate, tc.Labels) databases, err = tc.ListDatabases(ctx, &proto.ListResourcesRequest{ Namespace: tc.Namespace, ResourceType: types.KindDatabaseServer, - PredicateExpression: combinePredicateExpressions(predicate, tc.PredicateExpression), + PredicateExpression: predicate, Labels: tc.Labels, }) return trace.Wrap(err) @@ -1043,9 +1060,43 @@ func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, return databases, trace.Wrap(err) } -// combinePredicateExpressions combines two predicate expressions into one +func makeDiscoveredNameOrNamePredicate(name string) string { + matchName := makeNamePredicate(name) + matchDiscoveredName := makeDiscoveredNamePredicate(name) + return makePredicateDisjunction(matchName, matchDiscoveredName) +} + +func makeDiscoveredNamePredicate(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return fmt.Sprintf(`labels[%q] == %q`, types.DiscoveredNameLabel, name) +} + +func makeNamePredicate(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return fmt.Sprintf(`name == %q`, name) +} + +// makePredicateConjunction combines two predicate expressions into one // expression as a conjunction (logical AND) of the expressions. -func combinePredicateExpressions(a, b string) string { +func makePredicateConjunction(a, b string) string { + return combinePredicateExpressions(a, b, "&&") +} + +// makePredicateDisjunction combines two predicate expressions into one +// expression as a disjunction (logical OR) of the expressions. +func makePredicateDisjunction(a, b string) string { + return combinePredicateExpressions(a, b, "||") +} + +// combinePredicateExpressions combines two predicate expressions into one +// expression with the given operator. +func combinePredicateExpressions(a, b, op string) string { a = strings.TrimSpace(a) b = strings.TrimSpace(b) switch { @@ -1056,7 +1107,7 @@ func combinePredicateExpressions(a, b string) string { case a == b: return a default: - return fmt.Sprintf("(%v) && (%v)", a, b) + return fmt.Sprintf("(%v) %v (%v)", a, op, b) } } @@ -1248,123 +1299,73 @@ func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, // // If logged into multiple databases, returns an error unless one specified // explicitly via --db flag. -func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, error) { - profile, err := tc.ProfileStatus() - if err != nil { +func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) (*tlsca.RouteToDatabase, error) { + if route, err := maybePickActiveDatabase(cf, activeRoutes); err == nil && route != nil { + return route, nil + } else if err != nil { return nil, trace.Wrap(err) } - - routes, err := profile.DatabasesForCluster(tc.SiteName) - if err != nil { - return nil, trace.Wrap(err) + // check if any active database can possibly match. + selectors := newDatabaseResourceSelectors(cf) + if routes := filterRoutesByPrefix(activeRoutes, selectors.name); len(routes) == 0 { + // no match is possible. + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) } - if len(routes) == 0 { - return nil, trace.NotFound("please login using 'tsh db login' first") - } - - routes, databases, err := filterActiveDatabases(cf.Context, tc, routes) + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, activeRoutes) if err != nil { return nil, trace.Wrap(err) } - - if len(routes) != 1 { - // error - we need exactly one route. - selectors := resourceSelectors{ - kind: "database", - name: cf.DatabaseService, - labels: cf.Labels, - query: cf.PredicateExpression, - } - if len(routes) == 0 { - return nil, trace.NotFound("not logged into %v", selectors) - } - if len(databases) == 0 { - // if not already given, try to fetch them so we can print full - // the full `tsh db ls -v` table of ambiguously matching active DBs. - databases, err = listActiveDatabases(cf.Context, tc, routes) - if err != nil { - return nil, trace.Wrap(err) + if route, ok := findActiveDatabase(db.GetName(), activeRoutes); ok { + return &route, nil + } + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) +} + +// maybePickActiveDatabase tries to pick a database automatically when selectors +// are not given, or by an exact name match of an active database when neither +// labels nor query are given. +// The route returned may be nil, indicating an active route could not be +// picked. +func maybePickActiveDatabase(cf *CLIConf, activeRoutes []tlsca.RouteToDatabase) (*tlsca.RouteToDatabase, error) { + selectors := newDatabaseResourceSelectors(cf) + if selectors.query == "" && selectors.labels == "" { + if selectors.name == "" { + switch len(activeRoutes) { + case 0: + return nil, trace.NotFound(formatDBNotLoggedIn(cf.SiteName, selectors)) + case 1: + log.Debugf("Auto-selecting the only active database %q", activeRoutes[0].ServiceName) + return &activeRoutes[0], nil + default: + return nil, trace.BadParameter(formatChooseActiveDB(activeRoutes)) } } - errMsg := formatAmbiguousDB(cf, selectors, databases) - return nil, trace.BadParameter(errMsg) - } - - route := &routes[0] - // If database user or name were provided on the CLI, - // override the default ones. - if cf.DatabaseUser != "" { - route.Username = cf.DatabaseUser - } - if cf.DatabaseName != "" { - route.Database = cf.DatabaseName - } - return route, nil -} - -// filterActiveDatabases takes a list of active database routes and returns a -// filtered list and, possibly, their corresponding types.Databases. -// Callers should therefore not assume that the types.Databases are populated. -// Filtering is done by matching on database name prefix, label, and query -// predicate selectors from the Teleport client. -// If an active database name matches exactly, all other active databases are -// filtered out - this is to avoid requiring additional selectors -// when a user gives an exact database name. -func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, types.Databases, error) { - if len(activeRoutes) == 0 { - // nothing to filter - return nil, nil, nil - } - prefix := tc.DatabaseService - if len(tc.Labels) == 0 && tc.PredicateExpression == "" { - // when we have a name but don't have label or predicate query, look for - // a route that matches the name exactly to maybe avoid calling - // ListDatabases API below. - if route, ok := findActiveDatabase(prefix, activeRoutes); ok { - return []tlsca.RouteToDatabase{route}, nil, nil + if route, ok := findActiveDatabase(selectors.name, activeRoutes); ok { + log.Debugf("Selected active database %q by name", route.ServiceName) + return &route, nil } } + return nil, nil +} - // make a ListDatabases API call filtered by prefix name - databases, err := listDatabasesByPrefix(ctx, tc, prefix) - if err != nil { - return nil, nil, trace.Wrap(err) - } - databasesByName := databases.ToMap() - - // when a database matches the prefix fully, look for a - // corresponding active route. - if db, ok := databasesByName[prefix]; ok { - for _, route := range activeRoutes { - if route.ServiceName == db.GetName() { - return []tlsca.RouteToDatabase{route}, types.Databases{db}, nil - } - } - // no active route, but return the fetched databases if the caller is - // interested. - return nil, databases, nil +// getDatabasesForLogout selects databases for logout in "tsh db logout". +func getDatabasesForLogout(cf *CLIConf, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, error) { + selectors := newDatabaseResourceSelectors(cf) + if selectors.IsEmpty() { + // if db name, labels, query was not given, logout of all databases. + return activeRoutes, nil } - - // otherwise, just filter routes to those that match the names of the - // databases. - var selectedRoutes []tlsca.RouteToDatabase - var activeDBs types.Databases - for _, route := range activeRoutes { - if db, ok := databasesByName[route.ServiceName]; ok { - selectedRoutes = append(selectedRoutes, route) - activeDBs = append(activeDBs, db) - } + route, err := pickActiveDatabase(cf, tc, activeRoutes) + if err != nil { + return nil, trace.Wrap(err) } - return selectedRoutes, activeDBs, nil + return []tlsca.RouteToDatabase{*route}, nil } // findActiveDatabase returns a database route and a bool indicating whether // the route was found. func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsca.RouteToDatabase, bool) { - if name == "" && len(activeRoutes) == 1 { - return activeRoutes[0], true - } for _, r := range activeRoutes { if r.ServiceName == name { return r, true @@ -1373,11 +1374,58 @@ func findActiveDatabase(name string, activeRoutes []tlsca.RouteToDatabase) (tlsc return tlsca.RouteToDatabase{}, false } +func filterRoutesByPrefix(routes []tlsca.RouteToDatabase, prefix string) []tlsca.RouteToDatabase { + var out []tlsca.RouteToDatabase + for _, r := range routes { + if strings.HasPrefix(r.ServiceName, prefix) { + out = append(out, r) + } + } + return out +} + +func formatStaleDBCert(clusterFlag, name string) string { + return fmt.Sprintf("you are logged into a database that no longer exists in the cluster (remove it with '%v %v')", + formatDatabaseLogoutCommand(clusterFlag), name) +} + +func formatChooseActiveDB(routes []tlsca.RouteToDatabase) string { + var services []string + for _, r := range routes { + services = append(services, r.ServiceName) + } + return fmt.Sprintf("multiple databases are available (%v), please specify one by name, --labels, or --query", + strings.Join(services, ", ")) +} + +func formatDBNotLoggedIn(clusterFlag string, selectors resourceSelectors) string { + if selectors.IsEmpty() { + return fmt.Sprintf( + "please login using '%v' first (use '%v' to see registered databases)", + formatDatabaseLoginCommand(clusterFlag), + formatDatabaseListCommand(clusterFlag), + ) + } + return fmt.Sprintf("not logged into %s", selectors) +} + +func formatDatabaseLogoutCommand(clusterFlag string) string { + return formatTSHCommand("tsh db logout", clusterFlag) +} + +func formatDatabaseLoginCommand(clusterFlag string) string { + return formatTSHCommand("tsh db login", clusterFlag) +} + func formatDatabaseListCommand(clusterFlag string) string { + return formatTSHCommand("tsh db ls", clusterFlag) +} + +func formatTSHCommand(cmd, clusterFlag string) string { if clusterFlag == "" { - return "tsh db ls" + return cmd } - return fmt.Sprintf("tsh db ls --cluster=%v", clusterFlag) + return fmt.Sprintf("%v --cluster=%v", cmd, clusterFlag) } // formatDatabaseConnectCommand formats an appropriate database connection @@ -1637,6 +1685,15 @@ func (r resourceSelectors) IsEmpty() bool { return r.name == "" && r.labels == "" && r.query == "" } +func newDatabaseResourceSelectors(cf *CLIConf) resourceSelectors { + return resourceSelectors{ + kind: "database", + name: cf.DatabaseService, + labels: cf.Labels, + query: cf.PredicateExpression, + } +} + // formatAmbiguityErrTemplate is a helper func that formats an ambiguous // resource error message. func formatAmbiguityErrTemplate(cf *CLIConf, selectors resourceSelectors, listCommand, matchTable, fullNameExample string) string { diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 85a050aadbd49..223d1069cf49a 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -59,8 +59,7 @@ func TestTshDB(t *testing.T) { testenv.WithInsecureDevMode(t, true) t.Run("Login", testDatabaseLogin) t.Run("List", testListDatabase) - t.Run("FilterActiveDatabases", testFilterActiveDatabases) - t.Run("DatabaseInfo", testDatabaseInfo) + t.Run("DatabaseSelection", testDatabaseSelection) } // testDatabaseLogin tests "tsh db login" command and verifies "tsh db @@ -86,13 +85,6 @@ func testDatabaseLogin(t *testing.T) { cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{ { - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { Name: "postgres-rds-us-west-1-123456789012", Protocol: defaults.ProtocolPostgres, URI: "localhost:5432", @@ -108,22 +100,6 @@ func testDatabaseLogin(t *testing.T) { InstanceID: "postgres", }, }, - }, { - Name: "postgres-rds-us-west-2-123456789012", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - types.DiscoveredNameLabel: "postgres", - "region": "us-west-2", - "env": "prod", - }, - AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - Region: "us-west-2", - RDS: servicecfg.DatabaseAWSRDS{ - InstanceID: "postgres", - }, - }, }, { Name: "mysql", Protocol: defaults.ProtocolMySQL, @@ -217,17 +193,20 @@ func testDatabaseLogin(t *testing.T) { expectErrForEnvCmd: true, // "tsh db env" not supported for DynamoDB. }, { - name: "postgres", - databaseName: "postgres", - // the full db name is also a prefix of other dbs, but a full name - // match should take precedence over prefix matches. + name: "by full name", + databaseName: "postgres-rds-us-west-1-123456789012", + expectCertsLen: 1, + }, + { + name: "by discovered name", + databaseName: "postgres-rds-us-west-1-123456789012", dbSelectors: []string{"postgres"}, expectCertsLen: 1, }, { name: "by labels", - databaseName: "postgres", - dbSelectors: []string{"--labels", "env=local"}, + databaseName: "postgres-rds-us-west-1-123456789012", + dbSelectors: []string{"--labels", "region=us-west-1"}, expectCertsLen: 1, }, { @@ -236,12 +215,6 @@ func testDatabaseLogin(t *testing.T) { dbSelectors: []string{"--query", `labels.env=="prod" && labels.region == "us-west-1"`}, expectCertsLen: 1, }, - { - name: "by prefix name", - databaseName: "postgres-rds-us-west-2-123456789012", - dbSelectors: []string{"postgres-rds-us-west-2"}, - expectCertsLen: 1, - }, } // Note: keystore currently races when multiple tsh clients work in the @@ -534,6 +507,18 @@ func testListDatabase(t *testing.T) { require.Contains(t, captureStdout.String(), "leaf-postgres") } +func TestFormatDatabaseLoginCommand(t *testing.T) { + t.Parallel() + + t.Run("default", func(t *testing.T) { + require.Equal(t, "tsh db login", formatDatabaseLoginCommand("")) + }) + + t.Run("with cluster flag", func(t *testing.T) { + require.Equal(t, "tsh db login --cluster=leaf", formatDatabaseLoginCommand("leaf")) + }) +} + func TestFormatDatabaseListCommand(t *testing.T) { t.Parallel() @@ -920,339 +905,51 @@ func TestGetDefaultDBNameAndUser(t *testing.T) { } } -func testFilterActiveDatabases(t *testing.T) { +func TestResourceSelectors(t *testing.T) { t.Parallel() - // setup some databases and "active" routes to test filtering - - // databases that all have a name starting with with "foo" - fooDB1, fooRoute1 := makeDBConfigAndRoute("foo", map[string]string{"env": "dev", "svc": "fooer"}) - fooDB2, fooRoute2 := makeDBConfigAndRoute("foo-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) - fooDB3, fooRoute3 := makeDBConfigAndRoute("foo-westus-11111", map[string]string{"env": "prod", "region": "westus"}) - - // databases that all have a name starting with with "bar" - barDB1, barRoute1 := makeDBConfigAndRoute("bar", map[string]string{"env": "dev", "svc": "barrer"}) - barDB2, barRoute2 := makeDBConfigAndRoute("bar-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) - - // databases that all have a name starting with with "baz" - bazDB1, bazRoute1 := makeDBConfigAndRoute("baz", map[string]string{"env": "dev", "svc": "bazzer"}) - bazDB2, bazRoute2 := makeDBConfigAndRoute("baz2", map[string]string{"env": "prod", "svc": "bazzer"}) - routes := []tlsca.RouteToDatabase{ - fooRoute1, fooRoute2, fooRoute3, - barRoute1, barRoute2, - bazRoute1, bazRoute2, - } - s := newTestSuite(t, - withRootConfigFunc(func(cfg *servicecfg.Config) { - cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - cfg.Databases.Enabled = true - cfg.Databases.Databases = []servicecfg.Database{ - fooDB1, fooDB2, fooDB3, - barDB1, barDB2, - bazDB1, bazDB2, - } - }), - ) - - // Log into Teleport cluster. - tmpHomePath, _ := mustLogin(t, s) + t.Run("formatting", testResourceSelectorsFormatting) + t.Run("IsEmpty", testResourceSelectorsIsEmpty) +} +func testResourceSelectorsIsEmpty(t *testing.T) { + t.Parallel() tests := []struct { - name, - dbNamePrefix, - labels, - query string - wantAPICall bool - overrideActiveRoutes []tlsca.RouteToDatabase - overrideAPIDatabasesCheckFn func(t *testing.T, databases types.Databases) - wantRoutes []tlsca.RouteToDatabase + desc string + selectors resourceSelectors + wantEmpty bool }{ { - name: "by exact name that is a prefix of others", - dbNamePrefix: fooRoute1.ServiceName, - wantAPICall: false, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1}, - }, - { - name: "by exact name of inactive route that is a prefix of active routes", - dbNamePrefix: fooRoute1.ServiceName, - overrideActiveRoutes: []tlsca.RouteToDatabase{ - fooRoute2, fooRoute3, - barRoute1, barRoute2, - bazRoute1, bazRoute2, - }, - wantAPICall: true, - overrideAPIDatabasesCheckFn: func(t *testing.T, databases types.Databases) { - t.Helper() - require.NotNil(t, databases) - databasesByName := databases.ToMap() - require.Contains(t, databasesByName, fooRoute1.ServiceName) - require.Contains(t, databasesByName, fooRoute2.ServiceName) - require.Contains(t, databasesByName, fooRoute3.ServiceName) - }, - // the inactive route got filtered out, but active routes shouldn't - // have been matched by prefix either. - wantRoutes: nil, - }, - { - name: "by exact name that is not a prefix of others", - dbNamePrefix: fooRoute2.ServiceName, - wantAPICall: false, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2}, - }, - { - name: "by exact name that is a prefix of others with overlapping labels", - dbNamePrefix: bazRoute1.ServiceName, - labels: "svc=bazzer", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{bazRoute1}, - }, - { - name: "by name prefix", - dbNamePrefix: "ba", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{barRoute1, barRoute2, bazRoute1, bazRoute2}, - }, - { - name: "by labels", - labels: "env=dev", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1, barRoute1, bazRoute1}, - }, - { - name: "by query", - query: `labels.env == "dev"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute1, barRoute1, bazRoute1}, + desc: "no fields set", + selectors: resourceSelectors{}, + wantEmpty: true, }, { - name: "by name prefix and labels", - dbNamePrefix: "fo", - labels: "env=prod", - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2, fooRoute3}, + desc: "kind field set", + selectors: resourceSelectors{kind: "x"}, + wantEmpty: true, }, { - name: "by name prefix and query", - dbNamePrefix: "fo", - query: `labels.region == "us-west-1"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute2}, + desc: "name field set", + selectors: resourceSelectors{name: "x"}, }, { - name: "by labels and query", - labels: "env=dev", - query: `hasPrefix(name, "baz")`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{bazRoute1}, + desc: "labels field set", + selectors: resourceSelectors{labels: "x"}, }, { - name: "by name prefix and labels and query", - dbNamePrefix: "fo", - labels: "env=prod", - query: `labels.region == "westus"`, - wantAPICall: true, - wantRoutes: []tlsca.RouteToDatabase{fooRoute3}, + desc: "query field set", + selectors: resourceSelectors{query: "x"}, }, } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - cf := &CLIConf{ - Context: ctx, - HomePath: tmpHomePath, - DatabaseService: tt.dbNamePrefix, - Labels: tt.labels, - PredicateExpression: tt.query, - } - tc, err := makeClient(cf) - require.NoError(t, err) - activeRoutes := routes - if tt.overrideActiveRoutes != nil { - activeRoutes = tt.overrideActiveRoutes - } - gotRoutes, dbs, err := filterActiveDatabases(ctx, tc, activeRoutes) - require.NoError(t, err) - require.Empty(t, cmp.Diff(tt.wantRoutes, gotRoutes)) - if tt.wantAPICall { - if tt.overrideAPIDatabasesCheckFn != nil { - tt.overrideAPIDatabasesCheckFn(t, dbs) - } else { - require.Equal(t, len(tt.wantRoutes), len(dbs), - "returned routes should have corresponding types.Databases") - for i := range tt.wantRoutes { - require.Equal(t, gotRoutes[i].ServiceName, dbs[i].GetName(), - "route %v does not match corresponding types.Database", i) - } - } - return - } - require.Zero(t, len(dbs), "unexpected API call to ListDatabases") + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + require.Equal(t, test.wantEmpty, test.selectors.IsEmpty()) }) } } -func testDatabaseInfo(t *testing.T) { +func testResourceSelectorsFormatting(t *testing.T) { t.Parallel() - alice, err := types.NewUser("alice@example.com") - require.NoError(t, err) - defaultDBUser := "admin" - defaultDBName := "default" - // add multiple allowed db names/users, to prevent default selection. - // these tests should use the db name/username from either cli flag or - // active cert only. - alice.SetDatabaseUsers([]string{defaultDBUser, "foo"}) - alice.SetDatabaseNames([]string{defaultDBName, "bar"}) - alice.SetRoles([]string{"access"}) - databases := []servicecfg.Database{ - { - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { - Name: "postgres-2", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - StaticLabels: map[string]string{ - "env": "local", - }, - }, { - Name: "mysql", - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", - }, { - Name: "cassandra", - Protocol: defaults.ProtocolCassandra, - URI: "localhost:9042", - }, { - Name: "snowflake", - Protocol: defaults.ProtocolSnowflake, - URI: "localhost.snowflakecomputing.com", - }, { - Name: "mongo", - Protocol: defaults.ProtocolMongoDB, - URI: "localhost:27017", - }, { - Name: "mssql", - Protocol: defaults.ProtocolSQLServer, - URI: "localhost:1433", - }, { - Name: "dynamodb", - Protocol: defaults.ProtocolDynamoDB, - URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. - AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - ExternalID: "123123123", - Region: "us-west-1", - }, - }} - s := newTestSuite(t, - withRootConfigFunc(func(cfg *servicecfg.Config) { - cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) - cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - // separate MySQL port with TLS routing. - // set the public address to be sure even on v2+, tsh clients will see the separate port. - mySQLAddr := localListenerAddr() - cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: mySQLAddr} - cfg.Proxy.MySQLPublicAddrs = []utils.NetAddr{{AddrNetwork: "tcp", Addr: mySQLAddr}} - cfg.Databases.Enabled = true - cfg.Databases.Databases = databases - }), - ) - s.user = alice - // Log into Teleport cluster. - tmpHomePath, _ := mustLogin(t, s) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - t.Run("newDatabaseInfo", func(t *testing.T) { - for _, db := range databases { - require.NotEmpty(t, db.Name) - require.NotEmpty(t, db.Protocol) - route := tlsca.RouteToDatabase{ - ServiceName: db.Name, - Protocol: db.Protocol, - Username: defaultDBUser, - Database: defaultDBName, - } - t.Run(route.ServiceName, func(t *testing.T) { - t.Run("with route", func(t *testing.T) { - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, &route) - require.NoError(t, err) - require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - require.Equal(t, route, dbInfo.RouteToDatabase) - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "database should have been fetched and cached") - }) - t.Run("without route", func(t *testing.T) { - err = Run(ctx, []string{"db", "login", route.ServiceName, - "--db-user", route.Username, - "--db-name", route.Database, - }, setHomePath(tmpHomePath)) - require.NoError(t, err) - cf := &CLIConf{ - Context: ctx, - TracingProvider: tracing.NoopProvider(), - HomePath: tmpHomePath, - tracer: tracing.NoopTracer(teleport.ComponentTSH), - DatabaseService: route.ServiceName, - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := newDatabaseInfo(cf, tc, nil) - require.NoError(t, err) - require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched") - db, err := dbInfo.GetDatabase(cf, tc) - require.NoError(t, err) - require.Equal(t, route, dbInfo.RouteToDatabase) - require.Equal(t, route.ServiceName, db.GetName()) - require.Equal(t, route.Protocol, db.GetProtocol()) - require.Equal(t, dbInfo.database, db, "cached database should be the same") - }) - }) - } - }) - t.Run("getDatabaseInfo", func(t *testing.T) { - // login to "postgres-2" db. - err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath)) - require.NoError(t, err) - cf := &CLIConf{ - Context: ctx, - HomePath: tmpHomePath, - // select the other db, "postgres", which was not logged into. - DatabaseService: "postgres", - } - tc, err := makeClient(cf) - require.NoError(t, err) - dbInfo, err := getDatabaseInfo(cf, tc) - require.NoError(t, err) - require.NotNil(t, dbInfo) - // verify that the active login route for "postgres-2" was not used - // instead of fetching info for the "postgres" db. - require.Equal(t, "postgres", dbInfo.ServiceName) - require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol) - require.Equal(t, defaultDBUser, dbInfo.Username) - require.Equal(t, defaultDBName, dbInfo.Database) - require.NotNil(t, dbInfo.database) - }) -} - -func TestResourceSelectorsFormatting(t *testing.T) { tests := []struct { testName string selectors resourceSelectors @@ -1316,7 +1013,12 @@ func makeDBConfigAndRoute(name string, staticLabels map[string]string) (servicec URI: "localhost:5432", StaticLabels: staticLabels, } - route := tlsca.RouteToDatabase{ServiceName: name} + route := tlsca.RouteToDatabase{ + ServiceName: name, + Protocol: defaults.ProtocolPostgres, + Username: "alice", + Database: "postgres", + } return db, route } @@ -1346,6 +1048,22 @@ func TestChooseOneDatabase(t *testing.T) { URI: "uri", }) require.NoError(t, err) + db3, err := types.NewDatabaseV3(types.Metadata{ + Name: "my-db-with-some-suffix", + Labels: map[string]string{"foo": "bar", types.DiscoveredNameLabel: "my-db"}, + }, types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) + db4, err := types.NewDatabaseV3(types.Metadata{ + Name: "my-db-with-some-other-suffix", + Labels: map[string]string{"foo": "bar", types.DiscoveredNameLabel: "my-db"}, + }, types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) tests := []struct { desc string databases types.Databases @@ -1362,6 +1080,11 @@ func TestChooseOneDatabase(t *testing.T) { databases: types.Databases{db0, db1, db2}, wantDB: db0, }, + { + desc: "multiple databases to choose from with unambiguous discovered name match", + databases: types.Databases{db1, db2, db3}, + wantDB: db3, + }, { desc: "zero databases to choose from is an error", wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) not found, use 'tsh db ls --cluster=local-site'`, @@ -1371,6 +1094,11 @@ func TestChooseOneDatabase(t *testing.T) { databases: types.Databases{db1, db2}, wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) matches multiple databases`, }, + { + desc: "ambiguous discovered name databases is an error", + databases: types.Databases{db3, db4}, + wantErrContains: `database "my-db" with labels "foo=bar" with query (hasPrefix(name, "my-db")) matches multiple databases`, + }, } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1380,11 +1108,12 @@ func TestChooseOneDatabase(t *testing.T) { Context: ctx, TracingProvider: tracing.NoopProvider(), tracer: tracing.NoopTracer(teleport.ComponentTSH), + DatabaseService: "my-db", Labels: "foo=bar", PredicateExpression: `hasPrefix(name, "my-db")`, SiteName: "local-site", } - db, err := chooseOneDatabase(cf, "my-db", test.databases) + db, err := chooseOneDatabase(cf, test.databases) if test.wantErrContains != "" { require.ErrorContains(t, err, test.wantErrContains) return @@ -1395,3 +1124,494 @@ func TestChooseOneDatabase(t *testing.T) { }) } } + +func TestMaybePickActiveDatabase(t *testing.T) { + t.Parallel() + x := tlsca.RouteToDatabase{ServiceName: "x"} + y := tlsca.RouteToDatabase{ServiceName: "y"} + z := tlsca.RouteToDatabase{ServiceName: "z"} + tests := []struct { + desc string + svcName, labels, query string + routes []tlsca.RouteToDatabase + wantRoute *tlsca.RouteToDatabase + wantErr string + }{ + { + desc: "does nothing if labels given", + routes: []tlsca.RouteToDatabase{x}, + svcName: "x", + labels: "env=dev", + }, + { + desc: "does nothing if query given", + svcName: "x", + routes: []tlsca.RouteToDatabase{x}, + query: `name == "x"`, + }, + { + desc: "picks an active route by name", + svcName: "y", + routes: []tlsca.RouteToDatabase{x, y, z}, + wantRoute: &y, + }, + { + desc: "does nothing if only unmatched name is given", + svcName: "y", + routes: []tlsca.RouteToDatabase{x, z}, + }, + { + desc: "picks the only active route without selectors", + routes: []tlsca.RouteToDatabase{x}, + wantRoute: &x, + }, + { + desc: "no routes and no selectors is an error", + routes: []tlsca.RouteToDatabase{}, + wantErr: "please login", + }, + { + desc: "many routes and no selectors is an error", + routes: []tlsca.RouteToDatabase{x, y, z}, + wantErr: "multiple databases", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + cf := &CLIConf{ + DatabaseService: test.svcName, + Labels: test.labels, + PredicateExpression: test.query, + } + route, err := maybePickActiveDatabase(cf, test.routes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, test.wantRoute, route) + }) + } +} + +func TestFindActiveDatabase(t *testing.T) { + t.Parallel() + x := tlsca.RouteToDatabase{ServiceName: "x", Protocol: "postgres", Username: "alice", Database: "postgres"} + y := tlsca.RouteToDatabase{ServiceName: "y", Protocol: "postgres", Username: "alice", Database: "postgres"} + z := tlsca.RouteToDatabase{ServiceName: "z", Protocol: "postgres", Username: "alice", Database: "postgres"} + tests := []struct { + desc string + name string + routes []tlsca.RouteToDatabase + wantOK bool + wantRoute tlsca.RouteToDatabase + }{ + { + desc: "zero routes", + name: "x", + }, + { + desc: "no name with zero routes", + }, + { + desc: "no name with one route", + routes: []tlsca.RouteToDatabase{x}, + }, + { + desc: "no name with many routes", + routes: []tlsca.RouteToDatabase{x, y}, + }, + { + desc: "name in routes", + name: "x", + routes: []tlsca.RouteToDatabase{x, y}, + wantOK: true, + wantRoute: x, + }, + { + desc: "name not in routes", + name: "x", + routes: []tlsca.RouteToDatabase{y, z}, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + route, ok := findActiveDatabase(test.name, test.routes) + require.Equal(t, test.wantOK, ok) + require.Equal(t, test.wantRoute, route) + }) + } +} + +// testDatabaseSelection tests database selection by name, prefix name, labels, +// query, etc. +func testDatabaseSelection(t *testing.T) { + t.Parallel() + // setup some databases and "active" routes to test filtering + + // databases that all have a name starting with with "foo" + fooDB1, fooRoute1 := makeDBConfigAndRoute("foo", map[string]string{"env": "dev", "svc": "fooer"}) + fooRDSDB, fooRDSRoute := makeDBConfigAndRoute("foo-rds-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "foo-rds"}) + fooRDSCustomDB, fooRDSCustomRoute := makeDBConfigAndRoute("foo-rds-custom-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "foo-rds-custom"}) + // a route that isn't registered anymore, like when a user has logged into + // a db that isn't registered in the cluster anymore. + _, staleRoute := makeDBConfigAndRoute("stale", map[string]string{"env": "dev", "svc": "fooer"}) + + // databases that all have a name starting with with "bar" + barRDSDB1, barRDSRoute1 := makeDBConfigAndRoute("bar-rds-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1", types.DiscoveredNameLabel: "bar-rds"}) + barRDSDB2, barRDSRoute2 := makeDBConfigAndRoute("bar-rds-us-west-2-123456789012", map[string]string{"env": "prod", "region": "us-west-2", types.DiscoveredNameLabel: "bar-rds"}) + + activeRoutes := []tlsca.RouteToDatabase{ + fooRoute1, fooRDSRoute, fooRDSCustomRoute, staleRoute, + barRDSRoute1, barRDSRoute2, + } + + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetDatabaseUsers([]string{"alice", "bob"}) + alice.SetDatabaseNames([]string{"postgres", "other"}) + alice.SetRoles([]string{"access"}) + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + fooDB1, fooRDSDB, fooRDSCustomDB, + barRDSDB1, barRDSDB2, + } + }), + ) + s.user = alice + + // Log into Teleport cluster. + tmpHomePath, _ := mustLogin(t, s) + + t.Run("GetDatabasesForLogout", func(t *testing.T) { + t.Parallel() + tests := []struct { + name, + svcName, + labels, + query string + wantRoutes []tlsca.RouteToDatabase + wantErr string + }{ + { + name: "by exact name", + svcName: fooRDSRoute.ServiceName, + wantRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + }, + { + name: "by exact discovered name", + svcName: "foo-rds", + wantRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + }, + { + name: "by labels", + labels: "region=us-west-2", + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute2}, + }, + { + name: "by query", + query: `labels.region == "us-west-2"`, + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute2}, + }, + { + name: "by exact name of unregistered database", + svcName: staleRoute.ServiceName, + wantRoutes: []tlsca.RouteToDatabase{staleRoute}, + }, + { + name: "by exact discovered name that is ambiguous", + svcName: "bar-rds", + wantErr: "matches multiple", + }, + { + name: "by exact discovered name with labels", + svcName: "bar-rds", + labels: "region=us-west-1", + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + }, + { + name: "by exact discovered name with query", + svcName: "bar-rds", + query: `labels.region == "us-west-1"`, + wantRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + }, + { + name: "all", + wantRoutes: activeRoutes, + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: tt.svcName, + Labels: tt.labels, + PredicateExpression: tt.query, + } + tc, err := makeClient(cf) + require.NoError(t, err) + gotRoutes, err := getDatabasesForLogout(cf, tc, activeRoutes) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + require.Empty(t, cmp.Diff(tt.wantRoutes, gotRoutes)) + }) + } + }) + + t.Run("GetDatabaseInfo", func(t *testing.T) { + t.Parallel() + tests := []struct { + desc string + svcName, labels, query string + dbUser, dbName string + activeRoutes []tlsca.RouteToDatabase + wantRoute tlsca.RouteToDatabase + wantActive bool + wantErr string + }{ + { + desc: "by exact name", + svcName: "foo", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by exact name of active db", + svcName: "foo", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by exact name of active db overriding user and schema", + svcName: "foo", + dbUser: "bob", + dbName: "other", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: tlsca.RouteToDatabase{ServiceName: "foo", Protocol: "postgres", Username: "bob", Database: "other"}, + wantActive: true, + }, + { + desc: "by exact name that is a prefix of an active db", + svcName: "foo", + dbUser: "alice", + dbName: "postgres", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantRoute: fooRoute1, + }, + { + desc: "by exact discovered name", + svcName: "foo-rds", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRDSRoute, + }, + { + desc: "by labels", + labels: "env=dev,svc=fooer", + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by labels and active route", + labels: "env=dev,svc=fooer", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by query", + query: `name=="foo" && labels.env=="dev" && labels.svc=="fooer"`, + dbUser: "alice", + dbName: "postgres", + wantRoute: fooRoute1, + }, + { + desc: "by query and active route", + query: `name == "foo" && labels.env=="dev" && labels.svc=="fooer"`, + activeRoutes: []tlsca.RouteToDatabase{fooRoute1}, + wantRoute: fooRoute1, + wantActive: true, + }, + { + desc: "by ambiguous exact discovered name", + svcName: "bar-rds", + wantErr: "matches multiple databases", + }, + { + desc: "resolves ambiguous exact discovered name by label", + svcName: "bar-rds", + labels: "region=us-west-1", + dbUser: "alice", + dbName: "postgres", + wantRoute: barRDSRoute1, + }, + { + desc: "resolves ambiguous exact discovered name by query", + svcName: "bar-rds", + query: `labels.region=="us-west-2"`, + dbUser: "alice", + dbName: "postgres", + wantRoute: barRDSRoute2, + }, + { + desc: "by name of db that does not exist", + svcName: "foo-rds-", + wantErr: `"foo-rds-" not found, use 'tsh db ls' to see registered databases`, + }, + { + desc: "by name of db that does not exist and is not active", + svcName: "foo-rds-", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantErr: `"foo-rds-" not found, use 'tsh db ls' to see registered databases`, + }, + { + desc: "by ambiguous labels", + labels: "region=us-west-1", + wantErr: "matches multiple databases", + }, + { + desc: "by ambiguous query", + query: `labels.region == "us-west-1"`, + wantErr: "matches multiple databases", + }, + { + desc: "by exact name of unregistered database", + svcName: staleRoute.ServiceName, + activeRoutes: []tlsca.RouteToDatabase{staleRoute}, + wantErr: `you are logged into a database that no longer exists in the cluster`, + }, + // cases without selectors should try choose to from active databases + { + desc: "no selectors with one active registered db", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute}, + wantRoute: fooRDSRoute, + wantActive: true, + }, + { + desc: "no selectors with zero active registered db", + activeRoutes: []tlsca.RouteToDatabase{staleRoute}, + wantErr: `you are logged into a database that no longer exists in the cluster`, + }, + { + desc: "no selectors with multiple active registered db", + activeRoutes: []tlsca.RouteToDatabase{fooRoute1, fooRDSRoute}, + wantErr: "multiple databases are available", + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: test.svcName, + Labels: test.labels, + PredicateExpression: test.query, + DatabaseUser: test.dbUser, + DatabaseName: test.dbName, + } + tc, err := makeClient(cf) + require.NoError(t, err) + info, err := getDatabaseInfo(cf, tc, test.activeRoutes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, test.wantRoute, info.RouteToDatabase) + db, err := info.GetDatabase(cf.Context, tc) + require.NoError(t, err) + require.Equal(t, info.ServiceName, db.GetName()) + require.Equal(t, info.Protocol, db.GetProtocol()) + require.Equal(t, db, info.database, "database should have been fetched and cached") + require.Equal(t, test.wantActive, info.isActive) + }) + } + }) + + t.Run("PickActiveDatabase", func(t *testing.T) { + t.Parallel() + tests := []struct { + desc string + activeRoutes []tlsca.RouteToDatabase + dbName string + wantRoute tlsca.RouteToDatabase + wantErr string + }{ + { + desc: "pick active db without selector", + activeRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + wantRoute: barRDSRoute1, + }, + { + desc: "pick active db with discovered name selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + dbName: "foo-rds", + wantRoute: fooRDSRoute, + }, + { + desc: "pick active db with exact name selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + dbName: fooRDSRoute.ServiceName, + wantRoute: fooRDSRoute, + }, + { + desc: "pick inactive db with selector", + dbName: "foo-rds", + activeRoutes: []tlsca.RouteToDatabase{barRDSRoute1}, + wantErr: `not logged into database "foo-rds"`, + }, + { + desc: "no active db", + activeRoutes: []tlsca.RouteToDatabase{}, + wantErr: "please login using 'tsh db login' first", + }, + { + desc: "multiple active db without selector", + activeRoutes: []tlsca.RouteToDatabase{fooRDSRoute, barRDSRoute1}, + wantErr: "multiple databases are available", + }, + } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: test.dbName, + } + tc, err := makeClient(cf) + require.NoError(t, err) + route, err := pickActiveDatabase(cf, tc, test.activeRoutes) + if test.wantErr != "" { + require.ErrorContains(t, err, test.wantErr) + return + } + require.NoError(t, err) + require.NotNil(t, route) + require.Equal(t, test.wantRoute, *route) + }) + } + }) +} diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 134942a4c97da..c7fd0f77e0da1 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -365,7 +365,11 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := getDatabaseInfo(cf, tc) + routes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + dbInfo, err := getDatabaseInfo(cf, tc, routes) if err != nil { return trace.Wrap(err) } @@ -490,7 +494,7 @@ func onProxyCommandDB(cf *CLIConf) error { func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { if dbInfo.Protocol == defaults.ProtocolCassandra { - db, err := dbInfo.GetDatabase(cf, tc) + db, err := dbInfo.GetDatabase(cf.Context, tc) if err != nil { return nil, trace.Wrap(err) } @@ -890,7 +894,7 @@ var dbProxyTpl = template.Must(template.New("").Parse(`Started DB proxy on {{.ad {{if .randomPort}}To avoid port randomization, you can choose the listening port using the --port flag. {{end}} ` + dbProxyConnectAd + ` -Use following credentials to connect to the {{.database}} proxy: +Use the following credentials to connect to the {{.database}} proxy: ca_file={{.ca}} cert_file={{.cert}} key_file={{.key}}