diff --git a/api/types/constants.go b/api/types/constants.go index d77c85eb116de..c81c93e017562 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -603,6 +603,12 @@ const ( // See also TeleportNamespace and TeleportInternalLabelPrefix. TeleportHiddenLabelPrefix = "teleport.hidden/" + // DiscoveredNameLabel is a resource metadata label name used to identify + // the discovered name of a resource, i.e. the name of a resource before a + // uniquely distinguishing suffix is added by the discovery service. + // See: RFD 129 - Avoid Discovery Resource Name Collisions. + DiscoveredNameLabel = TeleportInternalLabelPrefix + "discovered-name" + // BotLabel is a label used to identify a resource used by a certificate renewal bot. BotLabel = TeleportInternalLabelPrefix + "bot" diff --git a/docs/pages/reference/predicate-language.mdx b/docs/pages/reference/predicate-language.mdx index 4d70344ecb2b3..954cbeab1a47d 100644 --- a/docs/pages/reference/predicate-language.mdx +++ b/docs/pages/reference/predicate-language.mdx @@ -68,6 +68,7 @@ The language also supports the following functions: | `exists(labels["env"])` | resources with a label key `env`; label value unchecked | | `!exists(labels["env"])` | resources without a label key `env`; label value unchecked | | `search("foo", "bar", "some phrase")` | fuzzy match against common resource fields | +| `hasPrefix(name, "foo")` | resources with a name that starts with the prefix `foo` | See some [examples](cli.mdx#filter-examples) of the different ways you can filter resources. diff --git a/lib/client/db/dbcmd/dbcmd.go b/lib/client/db/dbcmd/dbcmd.go index cea69d85b3c8e..deb7d8525a6c3 100644 --- a/lib/client/db/dbcmd/dbcmd.go +++ b/lib/client/db/dbcmd/dbcmd.go @@ -330,7 +330,7 @@ func (c *CLICommandBuilder) getMySQLOracleCommand() (*exec.Cmd, error) { // We save configuration to ~/.my.cnf, but on Windows that file is not read, // see tables 4.1 and 4.2 on https://dev.mysql.com/doc/refman/8.0/en/option-files.html. // We instruct mysql client to use use that file with --defaults-extra-file. - configPath, err := mysql.DefaultConfigPath() + configPath, err := mysql.DefaultConfigPath(c.tc.HomePath) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/db/mysql/optionfile.go b/lib/client/db/mysql/optionfile.go index 7bcee08192080..e7ee7344d4159 100644 --- a/lib/client/db/mysql/optionfile.go +++ b/lib/client/db/mysql/optionfile.go @@ -43,8 +43,13 @@ type OptionFile struct { path string } -func DefaultConfigPath() (string, error) { - // Default location is .my.cnf file in the user's home directory. +// DefaultConfigPath returns the default config path, which is .my.cnf file in +// the user's home directory. Home dir is determined by environment if not +// supplied as an argument. +func DefaultConfigPath(home string) (string, error) { + if home != "" { + return filepath.Join(home, mysqlOptionFile), nil + } home, err := os.UserHomeDir() if err != nil || home == "" { usr, err := utils.CurrentUser() @@ -58,8 +63,8 @@ func DefaultConfigPath() (string, error) { } // Load loads MySQL option file from the default location. -func Load() (*OptionFile, error) { - cnfPath, err := DefaultConfigPath() +func Load(home string) (*OptionFile, error) { + cnfPath, err := DefaultConfigPath(home) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/db/postgres/servicefile.go b/lib/client/db/postgres/servicefile.go index 6361c7d91f25c..c9de6c392226c 100644 --- a/lib/client/db/postgres/servicefile.go +++ b/lib/client/db/postgres/servicefile.go @@ -43,23 +43,37 @@ type ServiceFile struct { path string } -// Load loads Postgres connection service file from the default location. -func Load() (*ServiceFile, error) { +// DefaultConfigPath returns the default config path, which is .pg_service.conf +// file in the user's home directory. Home dir is determined by environment if +// not supplied as an argument. +func defaultConfigPath(home string) (string, error) { // Default location is .pg_service.conf file in the user's home directory. // TODO(r0mant): Check PGSERVICEFILE and PGSYSCONFDIR env vars as well. + if home != "" { + return filepath.Join(home, pgServiceFile), nil + } home, err := os.UserHomeDir() if err != nil || home == "" { - user, err := utils.CurrentUser() + usr, err := utils.CurrentUser() if err != nil { - return nil, trace.ConvertSystemError(err) + return "", trace.ConvertSystemError(err) } - home = user.HomeDir + home = usr.HomeDir } - return LoadFromPath(filepath.Join(home, pgServiceFile)) + return filepath.Join(home, pgServiceFile), nil +} + +// Load loads Postgres connection service file from the default location. +func Load(home string) (*ServiceFile, error) { + cnfPath, err := defaultConfigPath(home) + if err != nil { + return nil, trace.Wrap(err) + } + return LoadFromPath(cnfPath) } -// LoadFromPath loads Posrtgres connection service file from the specified path. +// LoadFromPath loads Postgres connection service file from the specified path. func LoadFromPath(path string) (*ServiceFile, error) { // Loose load will ignore file not found error. iniFile, err := ini.LooseLoad(path) diff --git a/lib/client/db/profile.go b/lib/client/db/profile.go index af63e4c1c2a49..d37221cea9881 100644 --- a/lib/client/db/profile.go +++ b/lib/client/db/profile.go @@ -45,7 +45,7 @@ func Add(ctx context.Context, tc *client.TeleportClient, db tlsca.RouteToDatabas if !IsSupported(db) { return nil } - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return trace.Wrap(err) } @@ -98,7 +98,7 @@ func New(tc *client.TeleportClient, db tlsca.RouteToDatabase, clientProfile clie // Env returns environment variables for the specified database profile. func Env(tc *client.TeleportClient, db tlsca.RouteToDatabase) (map[string]string, error) { - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return nil, trace.Wrap(err) } @@ -114,7 +114,7 @@ func Delete(tc *client.TeleportClient, db tlsca.RouteToDatabase) error { if !IsSupported(db) { return nil } - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return trace.Wrap(err) } @@ -138,12 +138,12 @@ func IsSupported(db tlsca.RouteToDatabase) bool { } // load loads the appropriate database connection profile. -func load(db tlsca.RouteToDatabase) (profile.ConnectProfileFile, error) { +func load(tc *client.TeleportClient, db tlsca.RouteToDatabase) (profile.ConnectProfileFile, error) { switch db.Protocol { case defaults.ProtocolPostgres: - return postgres.Load() + return postgres.Load(tc.HomePath) case defaults.ProtocolMySQL: - return mysql.Load() + return mysql.Load(tc.HomePath) } return nil, trace.BadParameter("unsupported database protocol %q", db.Protocol) diff --git a/lib/services/parser.go b/lib/services/parser.go index 863d290dcb134..d2fccdcb30129 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -741,6 +741,22 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, return predicate.Equals(a, b) } } + predPrefix := func(a interface{}, prefix string) predicate.BoolPredicate { + switch aval := a.(type) { + case label: + return func() bool { + return strings.HasPrefix(aval.value, prefix) + } + case string: + return func() bool { + return strings.HasPrefix(aval, prefix) + } + default: + return func() bool { + return false + } + } + } p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ @@ -753,7 +769,8 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, }, }, Functions: map[string]interface{}{ - "equals": predEquals, + "hasPrefix": predPrefix, + "equals": predEquals, // search allows fuzzy matching against select field values. "search": func(searchVals ...string) predicate.BoolPredicate { return func() bool { diff --git a/lib/services/parser_test.go b/lib/services/parser_test.go index 55a4f1ef999c5..7d1b258cddd43 100644 --- a/lib/services/parser_test.go +++ b/lib/services/parser_test.go @@ -192,6 +192,11 @@ func TestNewResourceParser(t *testing.T) { `search("os", "mac", "prod")`, `search()`, `!search("_")`, + // Test hasPrefix. + `hasPrefix(name, "")`, + `hasPrefix(name, "test-h")`, + `!hasPrefix(name, "foo")`, + `hasPrefix(resource.metadata.labels["env"], "pro")`, // Test exists. `exists(labels.env)`, `!exists(labels.undefined)`, @@ -206,6 +211,7 @@ func TestNewResourceParser(t *testing.T) { `labels.os == "mac" && name == "test-hostname" && search("v8")`, `exists(labels.env) && labels["env"] != "qa"`, `search("does", "not", "exist") || resource.spec.addr == "_" || labels.version == "v8"`, + `hasPrefix(labels.os, "m") && !hasPrefix(labels.env, "dev") && name == "test-hostname"`, // Test operator precedence `exists(labels.env) || (exists(labels.os) && labels.os != "mac")`, `exists(labels.env) || exists(labels.os) && labels.os != "mac"`, @@ -233,6 +239,7 @@ func TestNewResourceParser(t *testing.T) { `equals(resource.metadata.labels["env"], "wrong-value")`, `equals(resource.spec.hostname, "wrong-value")`, `search("mac", "not-found")`, + `hasPrefix(name, "x")`, } for _, expr := range exprs { t.Run(expr, func(t *testing.T) { @@ -269,6 +276,10 @@ func TestNewResourceParser(t *testing.T) { `exists(labels.env, "too", "many")`, `search(1,2)`, `"just-string"`, + `hasPrefix(1, 2)`, + `hasPrefix(name)`, + `hasPrefix(name, 1)`, + `hasPrefix(name, "too", "many")`, "", } for _, expr := range exprs { diff --git a/tool/teleport/testenv/test_server.go b/tool/teleport/testenv/test_server.go index 7483e0ccfbff9..d2e92495e2934 100644 --- a/tool/teleport/testenv/test_server.go +++ b/tool/teleport/testenv/test_server.go @@ -34,6 +34,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" @@ -60,6 +61,44 @@ func init() { modules.SetModules(&cliModules{}) } +// WithInsecureDevMode is a test helper that sets insecure dev mode and resets +// it in test cleanup. +// It is NOT SAFE to use in parallel tests, because it modifies a global. +// To run insecure dev mode tests in parallel, group them together under a +// parent test and then run them as parallel subtests. +// and call WithInsecureDevMode before running all the tests in parallel. +func WithInsecureDevMode(t *testing.T, mode bool) { + originalValue := lib.IsInsecureDevMode() + lib.SetInsecureDevMode(mode) + // To detect tests that run in parallel incorrectly, call t.Setenv with a + // dummy env var - that function detects tests with parallel ancestors + // and panics, preventing improper use of this helper. + t.Setenv("WithInsecureDevMode", "1") + t.Cleanup(func() { + lib.SetInsecureDevMode(originalValue) + }) +} + +// WithResyncInterval is a test helper that sets the tunnel resync interval and +// resets it in test cleanup. +// Useful to substantially speedup test cluster setup - passing 0 for the +// interval selects a reasonably fast default of 100ms. +// It is NOT SAFE to use in parallel tests, because it modifies a global. +func WithResyncInterval(t *testing.T, interval time.Duration) { + if interval == 0 { + interval = time.Millisecond * 100 + } + oldResyncInterval := defaults.ResyncInterval + defaults.ResyncInterval = interval + // To detect tests that run in parallel incorrectly, call t.Setenv with a + // dummy env var - that function detects tests with parallel ancestors + // and panics, preventing improper use of this helper. + t.Setenv("WithResyncInterval", "1") + t.Cleanup(func() { + defaults.ResyncInterval = oldResyncInterval + }) +} + // MakeTestServer creates a Teleport Server for testing. func MakeTestServer(t *testing.T, opts ...TestServerOptFunc) (process *service.TeleportProcess) { t.Helper() diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 4d196e3f6e6b9..baf94dcc0c915 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/gravitational/teleport" @@ -250,11 +251,7 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ - ServiceName: cf.DatabaseService, - Username: cf.DatabaseUser, - Database: cf.DatabaseName, - }) + dbInfo, err := newDatabaseInfo(cf, tc, nil) if err != nil { return trace.Wrap(err) } @@ -351,7 +348,11 @@ func onDatabaseLogout(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) + activeRoutes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + logout, _, err := filterActiveDatabases(cf.Context, tc, activeRoutes) if err != nil { return trace.Wrap(err) } @@ -360,34 +361,46 @@ func onDatabaseLogout(cf *CLIConf) error { log.Info("Note: an identity file is in use (`-i ...`); will only update database config files.") } - var logout []tlsca.RouteToDatabase - // If database name wasn't given on the command line, log out of all. - if cf.DatabaseService == "" { - logout = activeDatabases - } else { - for _, db := range activeDatabases { - if db.ServiceName == cf.DatabaseService { - logout = append(logout, db) - } - } - if len(logout) == 0 { - return trace.BadParameter("Not logged into database %q", - tc.DatabaseService) - } - } for _, db := range logout { if err := databaseLogout(tc, db, profile.IsVirtual); err != nil { return trace.Wrap(err) } } - if len(logout) == 1 { - fmt.Println("Logged out of database", logout[0].ServiceName) - } else { - fmt.Println("Logged out of all databases") + msg, err := makeLogoutMessage(cf, logout, activeRoutes) + if err != nil { + return trace.Wrap(err) } + fmt.Fprintln(cf.Stdout(), msg) return nil } +// makeLogoutMessage is a helper func that returns a logout message for the +// result of "tsh db logout". +func makeLogoutMessage(cf *CLIConf, logout, activeRoutes []tlsca.RouteToDatabase) (string, error) { + switch len(logout) { + case 0: + selectors := resourceSelectors{ + kind: "database", + name: cf.DatabaseService, + labels: cf.Labels, + query: cf.PredicateExpression, + } + return "", trace.NotFound("Not logged into %v", selectors) + case 1: + return fmt.Sprintf("Logged out of database %v", logout[0].ServiceName), nil + case len(activeRoutes): + return "Logged out of all databases", nil + default: + names := make([]string, 0, len(logout)) + for _, route := range logout { + names = append(names, route.ServiceName) + } + slices.Sort(names) + nameLines := strings.Join(names, "\n") + return fmt.Sprintf("Logged out of databases:\n%v", nameLines), nil + } +} + func databaseLogout(tc *client.TeleportClient, db tlsca.RouteToDatabase, virtual bool) error { // First remove respective connection profile. err := dbprofile.Delete(tc, db) @@ -413,7 +426,7 @@ func onDatabaseEnv(cf *CLIConf) error { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf) + database, err := pickActiveDatabase(cf, tc) if err != nil { return trace.Wrap(err) } @@ -471,7 +484,7 @@ func onDatabaseConfig(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf) + database, err := pickActiveDatabase(cf, tc) if err != nil { return trace.Wrap(err) } @@ -790,96 +803,75 @@ func onDatabaseConnect(cf *CLIConf) error { } // getDatabaseInfo fetches information about the database from tsh profile if DB -// is active in profile. Otherwise, the ListDatabases endpoint is called. +// is active in profile and no labels or predicate query are given. +// Otherwise, the ListDatabases endpoint is called. func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - if route, err := pickActiveDatabase(cf); err == nil { - return newDatabaseInfo(cf, tc, *route) - } else if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - return newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ - ServiceName: cf.DatabaseService, - Username: cf.DatabaseUser, - Database: cf.DatabaseName, - }) -} - -// databaseInfo wraps a RouteToDatabase and the corresponding database. -// Its purpose is to prevent repeated fetches of the same database, by lazily -// fetching and caching the database for use as needed. -type databaseInfo struct { - tlsca.RouteToDatabase - // database corresponds to the db route and may be nil, so use GetDatabase - // instead of accessing it directly. - database types.Database - mu sync.Mutex -} - -// GetDatabase returns the cached database or fetches it using the db route and -// caches the result. -func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { - d.mu.Lock() - defer d.mu.Unlock() - if d.database != nil { - return d.database, nil - } - var databases []types.Database - // holding mutex across the api call to avoid multiple redundant api calls. - err := client.RetryWithRelogin(cf.Context, tc, func() error { - var err error - databases, err = tc.ListDatabases(cf.Context, &proto.ListResourcesRequest{ - Namespace: tc.Namespace, - ResourceType: types.KindDatabaseServer, - PredicateExpression: fmt.Sprintf(`name == "%s"`, d.ServiceName), - }) - return trace.Wrap(err) - }) - if err != nil { - return nil, trace.Wrap(err) - } - if len(databases) == 0 { - return nil, trace.NotFound( - "database %q not found, use '%v' to see registered databases", - d.ServiceName, formatDatabaseListCommand(cf.SiteName)) + haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + if !haveSelectors { + // if selectors are given, we might incur an extra ListDatabases API + // call here to match against an active database. + // So try to pick an active database only when we don't have + // selectors. + if route, err := pickActiveDatabase(cf, tc); err == nil { + return newDatabaseInfo(cf, tc, route) + } else if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } } - d.database = databases[0] - return d.database, nil + return newDatabaseInfo(cf, tc, nil) } // newDatabaseInfo makes a new databaseInfo from the given route to the db. // It checks the route and sets defaults as needed for protocol, db user, or db -// name. If the remote database is needed for setting a default, it is retrieved -// by calling ListDatabases API and cached. -func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase) (*databaseInfo, error) { - dbInfo := databaseInfo{RouteToDatabase: route} - if dbInfo.ServiceName == "" { - return nil, trace.BadParameter("missing database service name") - } - if dbInfo.Protocol != "" && dbInfo.Username != "" && dbInfo.Database != "" { - return &dbInfo, nil +// name. If the route is not given or the remote database is needed for setting +// a default, the database is retrieved by calling ListDatabases API and cached. +func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteToDatabase) (*databaseInfo, error) { + dbInfo := &databaseInfo{} + if route != nil { + dbInfo.RouteToDatabase = *route + // the only way we're going to have all this info populated is from an + // active cert. + if dbInfo.ServiceName != "" && dbInfo.Protocol != "" && + dbInfo.Username != "" && dbInfo.Database != "" { + return dbInfo, nil + } } db, err := dbInfo.GetDatabase(cf, tc) if err != nil { return nil, trace.Wrap(err) } + // now ensure the route name and protocol matches the db we fetched. + dbInfo.ServiceName = db.GetName() dbInfo.Protocol = db.GetProtocol() + return dbInfo, dbInfo.checkAndSetPrincipalDefaults(cf, tc, db) +} + +// checkAndSetPrincipalDefaults checks the db route (schema) name and username, +// and sets them to defaults if necessary. +func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + if cf.DatabaseUser != "" { + d.Username = cf.DatabaseUser + } + if cf.DatabaseName != "" { + d.Database = cf.DatabaseName + } // If database has admin user defined, we're most likely using automatic // user provisioning so default to Teleport username unless database // username was provided explicitly. - if dbInfo.Username == "" && db.GetAdminUser() != "" { + if d.Username == "" && db.GetAdminUser() != "" { log.Debugf("Defaulting to Teleport username %q as database username.", tc.Username) - dbInfo.Username = tc.Username + d.Username = tc.Username } // recheck to see if we can avoid fetching the roleset to set defaults. - needDBUser := dbInfo.Username == "" && role.RequireDatabaseUserMatcher(dbInfo.Protocol) - needDBName := dbInfo.Database == "" && role.RequireDatabaseNameMatcher(dbInfo.Protocol) + needDBUser := d.Username == "" && role.RequireDatabaseUserMatcher(d.Protocol) + needDBName := d.Database == "" && role.RequireDatabaseNameMatcher(d.Protocol) if !needDBUser && !needDBName { - return &dbInfo, nil + return nil } profile, err := tc.ProfileStatus() if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } var proxy *client.ProxyClient @@ -888,33 +880,152 @@ func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteTo return trace.Wrap(err) }) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } defer proxy.Close() checker, err := accessCheckerForRemoteCluster(cf.Context, profile, proxy, tc.SiteName) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } if needDBUser { dbUser, err := getDefaultDBUser(db, checker) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Debugf("Defaulting to the allowed database user %q\n", dbUser) - dbInfo.Username = dbUser + d.Username = dbUser } if needDBName { dbName, err := getDefaultDBName(db, checker) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Debugf("Defaulting to the allowed database name %q\n", dbName) - dbInfo.Database = dbName + d.Database = dbName } + return nil +} - return &dbInfo, nil +// databaseInfo wraps a RouteToDatabase and the corresponding database. +// Its purpose is to prevent repeated fetches of the same database, by lazily +// fetching and caching the database for use as needed. +type databaseInfo struct { + tlsca.RouteToDatabase + // database corresponds to the db route and may be nil, so use GetDatabase + // instead of accessing it directly. + database types.Database + mu sync.Mutex +} + +// GetDatabase returns the cached database or fetches it using the db route and +// caches the result. +func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { + if d.ServiceName == "" && cf.DatabaseService == "" && + len(tc.Labels) == 0 && tc.PredicateExpression == "" { + return nil, trace.BadParameter("specify a database service by name, --labels, or --query") + } + d.mu.Lock() + defer d.mu.Unlock() + if d.database != nil { + return d.database, nil + } + // holding mutex across the api call to avoid multiple redundant api calls. + var databases types.Databases + var err error + name := d.ServiceName + if name != "" { + databases, err = listDatabasesByName(cf.Context, tc, name) + } else { + name = cf.DatabaseService + // search by prefix if the db name comes from cli flag instead of cert. + databases, err = listDatabasesByPrefix(cf.Context, tc, name) + } + if err != nil { + return nil, trace.Wrap(err) + } + if len(databases) != 1 { + // error - we need exactly one database. + selectors := resourceSelectors{ + kind: "database", + name: name, + labels: cf.Labels, + query: cf.PredicateExpression, + } + if len(databases) == 0 { + return nil, trace.NotFound( + "%v not found, use '%v' to see registered databases", selectors, + formatDatabaseListCommand(cf.SiteName)) + } + errMsg := formatAmbiguousDB(cf, selectors, databases) + return nil, trace.BadParameter(errMsg) + } + + d.database = databases[0] + return d.database, nil +} + +// listActiveDatabases lists databases that match active (logged in) databases. +func listActiveDatabases(ctx context.Context, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (types.Databases, error) { + names := make([]string, 0, len(routes)) + for _, r := range routes { + names = append(names, fmt.Sprintf("(name == %q)", r.ServiceName)) + } + predicate := strings.Join(names, "||") + return listDatabasesWithPredicate(ctx, tc, predicate) +} + +// listDatabasesByName lists database that match a given name. +func listDatabasesByName(ctx context.Context, tc *client.TeleportClient, name string) (types.Databases, error) { + predicate := fmt.Sprintf("name == %s", name) + return listDatabasesWithPredicate(ctx, tc, predicate) +} + +// listDatabasesByPrefix lists databases that match a given name prefix. +func listDatabasesByPrefix(ctx context.Context, tc *client.TeleportClient, prefix string) (types.Databases, error) { + predicate := fmt.Sprintf(`hasPrefix(name, "%s")`, prefix) + databases, err := listDatabasesWithPredicate(ctx, tc, predicate) + if err == nil || !utils.IsPredicateError(err) { + return databases, trace.Wrap(err) + } + // predicate error from using hasPrefix expression. + // fallback to listing without the hasPrefix predicate and filtering + // on client side for backwards compatibility. + databases, err = listDatabasesWithPredicate(ctx, tc, "") + if err != nil { + return nil, trace.Wrap(err) + } + var out types.Databases + for _, db := range databases { + if strings.HasPrefix(db.GetName(), prefix) { + out = append(out, db) + } + } + return out, nil +} + +// listDatabasesWithPredicate is a helper func for listing databases using +// a given additional predicate expression. If the teleport client already +// has a predicate expression, the predicates are combined with a logical AND. +func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, predicate string) (types.Databases, error) { + if predicate == "" { + predicate = tc.PredicateExpression + } else if tc.PredicateExpression != "" { + predicate = fmt.Sprintf("(%v) && (%v)", predicate, tc.PredicateExpression) + } + var databases []types.Database + err := client.RetryWithRelogin(ctx, tc, func() error { + var err error + databases, err = tc.ListDatabases(ctx, &proto.ListResourcesRequest{ + Namespace: tc.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: predicate, + Labels: tc.Labels, + }) + return trace.Wrap(err) + }) + return databases, trace.Wrap(err) } // getDefaultDBUser enumerates the allowed database users for a given database @@ -1105,47 +1216,109 @@ func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, // // If logged into multiple databases, returns an error unless one specified // explicitly via --db flag. -func pickActiveDatabase(cf *CLIConf) (*tlsca.RouteToDatabase, error) { - profile, err := cf.ProfileStatus() +func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, error) { + profile, err := tc.ProfileStatus() if err != nil { return nil, trace.Wrap(err) } - activeDatabases, err := profile.DatabasesForCluster(cf.SiteName) + routes, err := profile.DatabasesForCluster(tc.SiteName) if err != nil { return nil, trace.Wrap(err) } - if len(activeDatabases) == 0 { + if len(routes) == 0 { return nil, trace.NotFound("please login using 'tsh db login' first") } - name := cf.DatabaseService - if name == "" { - if len(activeDatabases) > 1 { - var services []string - for _, database := range activeDatabases { - services = append(services, database.ServiceName) - } - return nil, trace.BadParameter("Multiple databases are available (%v), please specify one using CLI argument", - strings.Join(services, ", ")) + routes, databases, err := filterActiveDatabases(cf.Context, tc, routes) + if err != nil { + return nil, trace.Wrap(err) + } + + if len(routes) != 1 { + // error - we need exactly one route. + selectors := resourceSelectors{ + kind: "database", + name: cf.DatabaseService, + labels: cf.Labels, + query: cf.PredicateExpression, + } + if len(routes) == 0 { + return nil, trace.NotFound("not logged into %v", selectors) } - name = activeDatabases[0].ServiceName - } - for _, db := range activeDatabases { - if db.ServiceName == name { - // If database user or name were provided on the CLI, - // override the default ones. - if cf.DatabaseUser != "" { - db.Username = cf.DatabaseUser + if len(databases) == 0 { + // if not already given, try to fetch them so we can print full + // the full `tsh db ls -v` table of ambiguously matching active DBs. + databases, err = listActiveDatabases(cf.Context, tc, routes) + if err != nil { + return nil, trace.Wrap(err) } - if cf.DatabaseName != "" { - db.Database = cf.DatabaseName + } + errMsg := formatAmbiguousDB(cf, selectors, databases) + return nil, trace.BadParameter(errMsg) + } + + route := &routes[0] + // If database user or name were provided on the CLI, + // override the default ones. + if cf.DatabaseUser != "" { + route.Username = cf.DatabaseUser + } + if cf.DatabaseName != "" { + route.Database = cf.DatabaseName + } + return route, nil +} + +// filterActiveDatabases takes a list of active database routes and returns a +// filtered list and, possibly, their corresponding types.Databases. +// Callers should therefore not assume that the types.Databases are populated. +// Filtering is done by matching on database name, label, and query predicate +// selectors from the Teleport client. +// When only database name is given, filtering is done by name prefix, unless +// an active database name matches exactly, in which case all other active +// databases are filtered out - this is to avoid requiring additional selectors +// when a user gives an exact database name. +func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, types.Databases, error) { + prefix := tc.DatabaseService + if prefix == "" && len(activeRoutes) == 1 { + prefix = activeRoutes[0].ServiceName + } + + haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + var selectedRoutes []tlsca.RouteToDatabase + for _, db := range activeRoutes { + if db.ServiceName == prefix && !haveSelectors { + // short-circuit to select the exact match when we don't have + // label or predicate selectors. + return []tlsca.RouteToDatabase{db}, nil, nil + } + if strings.HasPrefix(db.ServiceName, prefix) { + selectedRoutes = append(selectedRoutes, db) + } + } + if len(selectedRoutes) == 0 || !haveSelectors { + // nothing to filter further, avoid making API call. + return selectedRoutes, nil, nil + } + + // make a ListDatabases API call and match on full database name. + databases, err := listDatabasesByPrefix(ctx, tc, prefix) + if err != nil { + return nil, nil, trace.Wrap(err) + } + selectedRoutes = nil + var activeDBs types.Databases + for _, route := range activeRoutes { + for _, db := range databases { + if db.GetName() == route.ServiceName { + selectedRoutes = append(selectedRoutes, route) + activeDBs = append(activeDBs, db) } - return &db, nil } } - return nil, trace.NotFound("Not logged into database %q", name) + return selectedRoutes, activeDBs, nil } func formatDatabaseListCommand(clusterFlag string) string { @@ -1356,6 +1529,68 @@ func getDbCmdAlternatives(clusterFlag string, route tlsca.RouteToDatabase) []str return alts } +// formatAmbiguousDB is a helper func that formats an ambiguous database error +// message. +func formatAmbiguousDB(cf *CLIConf, selectors resourceSelectors, matchedDBs types.Databases) string { + var activeDBs []tlsca.RouteToDatabase + if profile, err := cf.ProfileStatus(); err == nil { + if dbs, err := profile.DatabasesForCluster(cf.SiteName); err == nil { + activeDBs = dbs + } + } + // Pass a nil access checker to avoid making a proxy roundtrip. + // Access info isn't relevant to an ambiguity error anyway. + var checker services.AccessChecker + var sb strings.Builder + verbose := true + showDatabasesAsText(&sb, cf.SiteName, matchedDBs, activeDBs, checker, verbose) + + listCommand := formatDatabaseListCommand(cf.SiteName) + return formatAmbiguityErrTemplate(cf, selectors, listCommand, sb.String()) +} + +// resourceSelectors is a helper struct for gathering up the selectors for a +// resource, as an aggregate of name, labels, and predicate query. +type resourceSelectors struct { + kind, + name, + labels, + query string +} + +// String returns the resource selectors as a formatted string. +// Example: +// command: `tsh db connect foo --labels k1=v1 --query 'labels["k2"]=="v2"'` +// output: database "foo" with labels "k1=v1" with query (labels["k2"]=="v2") +func (r resourceSelectors) String() string { + out := r.kind + if r.name != "" { + out = fmt.Sprintf("%s %q", out, r.name) + } + if len(r.labels) > 0 { + out = fmt.Sprintf("%s with labels %q", out, r.labels) + } + if len(r.query) > 0 { + out = fmt.Sprintf("%s with query (%s)", out, r.query) + } + return strings.TrimSpace(out) +} + +// formatAmbiguityErrTemplate is a helper func that formats an ambiguous +// resource error message. +func formatAmbiguityErrTemplate(cf *CLIConf, selectors resourceSelectors, listCommand, matchTable string) string { + data := map[string]any{ + "command": cf.CommandWithBinary(), + "selectors": strings.TrimSpace(selectors.String()), + "listCommand": strings.TrimSpace(listCommand), + "kind": strings.TrimSpace(selectors.kind), + "matchTable": strings.TrimSpace(matchTable), + } + var sb strings.Builder + _ = ambiguityErrTemplate.Execute(&sb, data) + return sb.String() +} + const ( // dbFormatText prints database configuration in text format. dbFormatText = "text" @@ -1383,9 +1618,7 @@ Please use one of the following commands to connect to the database: {{- range .alternatives}} {{.}}{{end -}} {{- end}}`)) -) -var ( // dbConnectTemplate is the message printed after a successful "tsh db login" on how to connect. dbConnectTemplate = template.Must(template.New("").Parse(`Connection information for database "{{ .name }}" has been saved. @@ -1410,5 +1643,16 @@ You can start a local proxy for database GUI clients: {{ .proxyCommand }} {{end -}} +`)) + + // ambiguityErrTemplate is the error message printed when a resource is + // specified ambiguously by name prefix and/or labels. + ambiguityErrTemplate = template.Must(template.New("").Parse("{{ .selectors }} matches multiple {{ .kind }}s:" + ` + +{{ .matchTable }} + +Hint: use '{{ .listCommand }} -v' or '{{ .listCommand }} --format=[json|yaml]' to list all {{ .kind }}s with full details. +Hint: try selecting the {{ .kind }} with a more specific name (ex: {{ .command }} full-{{ .kind }}-name). +Hint: try selecting the {{ .kind }} with additional --labels or --query predicate. `)) ) diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index 1981a5bbffd95..c03dab6d973d7 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -23,22 +23,21 @@ import ( "crypto/rsa" "encoding/pem" "fmt" - "net" "os" "path/filepath" + "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" @@ -48,91 +47,130 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/teleport/testenv" ) -// TestDatabaseLogin tests "tsh db login" command and verifies "tsh db -// env/config" after login. -func TestDatabaseLogin(t *testing.T) { - tmpHomePath := t.TempDir() - - connector := mockConnector(t) +func TestTshDB(t *testing.T) { + // this speeds up test suite setup substantially, which is where + // tests spend the majority of their time, especially when leaf + // clusters are setup. + testenv.WithResyncInterval(t, 0) + // Proxy uses self-signed certificates in tests. + testenv.WithInsecureDevMode(t, true) + t.Run("Login", testDatabaseLogin) + t.Run("List", testListDatabase) + t.Run("FilterActiveDatabases", testFilterActiveDatabases) +} +// testDatabaseLogin tests "tsh db login" command and verifies "tsh db +// env/config" after login. +func testDatabaseLogin(t *testing.T) { + t.Parallel() alice, err := types.NewUser("alice@example.com") require.NoError(t, err) + // to use default --db-user and --db-name selection, make a user with just + // one of each allowed. alice.SetDatabaseUsers([]string{"admin"}) alice.SetDatabaseNames([]string{"default"}) alice.SetRoles([]string{"access"}) - - authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice), - withAuthConfig(func(cfg *servicecfg.AuthConfig) { - cfg.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - }), - withConfig(func(cfg *servicecfg.Config) { + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) // separate MySQL port with TLS routing. - cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} - })) - makeTestDatabaseServer(t, authProcess, proxyProcess, - servicecfg.Database{ - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - }, servicecfg.Database{ - Name: "mysql", - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", - }, servicecfg.Database{ - Name: "cassandra", - Protocol: defaults.ProtocolCassandra, - URI: "localhost:9042", - }, servicecfg.Database{ - Name: "snowflake", - Protocol: defaults.ProtocolSnowflake, - URI: "localhost.snowflakecomputing.com", - }, servicecfg.Database{ - Name: "mongo", - Protocol: defaults.ProtocolMongoDB, - URI: "localhost:27017", - }, servicecfg.Database{ - Name: "mssql", - Protocol: defaults.ProtocolSQLServer, - URI: "localhost:1433", - }, servicecfg.Database{ - Name: "dynamodb", - Protocol: defaults.ProtocolDynamoDB, - URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. - AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - ExternalID: "123123123", - Region: "us-west-1", - }, - }) - - authServer := authProcess.GetAuthServer() - require.NotNil(t, authServer) - - proxyAddr, err := proxyProcess.ProxyWebAddr() - require.NoError(t, err) + // set the public address to be sure even on v2+, tsh clients will see the separate port. + mySQLAddr := localListenerAddr() + cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: mySQLAddr} + cfg.Proxy.MySQLPublicAddrs = []utils.NetAddr{{AddrNetwork: "tcp", Addr: mySQLAddr}} + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + { + Name: "postgres-local", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + "env": "local", + }, + }, { + Name: "postgres-rds-us-west-1-123456789012", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: "postgres", + "region": "us-west-1", + "env": "prod", + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-1", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "postgres", + }, + }, + }, { + Name: "postgres-rds-us-west-2-123456789012", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: "postgres", + "region": "us-west-2", + "env": "prod", + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-2", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "postgres", + }, + }, + }, { + Name: "mysql", + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + }, { + Name: "cassandra", + Protocol: defaults.ProtocolCassandra, + URI: "localhost:9042", + }, { + Name: "snowflake", + Protocol: defaults.ProtocolSnowflake, + URI: "localhost.snowflakecomputing.com", + }, { + Name: "mongo", + Protocol: defaults.ProtocolMongoDB, + URI: "localhost:27017", + }, { + Name: "mssql", + Protocol: defaults.ProtocolSQLServer, + URI: "localhost:1433", + }, { + Name: "dynamodb", + Protocol: defaults.ProtocolDynamoDB, + URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + ExternalID: "123123123", + Region: "us-west-1", + }, + }} + }), + ) + s.user = alice // Log into Teleport cluster. - err = Run(context.Background(), []string{ - "login", "--insecure", "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, setHomePath(tmpHomePath), CliOption(func(cf *CLIConf) error { - cf.MockSSOLogin = mockSSOLogin(t, authServer, alice) - return nil - })) - require.NoError(t, err) + tmpHomePath, _ := mustLogin(t, s) testCases := []struct { - databaseName string + // databaseName should be the full database name. + databaseName string + // dbSelectors can be any of db name, --labels, --query predicate, + // and defaults to be databaseName if not set. + dbSelectors []string expectCertsLen int expectKeysLen int expectErrForConfigCmd bool expectErrForEnvCmd bool }{ - { - databaseName: "postgres", - expectCertsLen: 1, - }, { databaseName: "mongo", expectCertsLen: 1, @@ -169,6 +207,24 @@ func TestDatabaseLogin(t *testing.T) { expectErrForConfigCmd: true, // "tsh db config" not supported for DynamoDB. expectErrForEnvCmd: true, // "tsh db env" not supported for DynamoDB. }, + { + databaseName: "postgres-local", + // select by labels alone. + dbSelectors: []string{"--labels", "env=local"}, + expectCertsLen: 1, + }, + { + databaseName: "postgres-rds-us-west-1-123456789012", + // select by query alone. + dbSelectors: []string{"--query", `labels.env=="prod" && labels.region == "us-west-1"`}, + expectCertsLen: 1, + }, + { + databaseName: "postgres-rds-us-west-2-123456789012", + // select by uniquely identifying prefix. + dbSelectors: []string{"postgres-rds-us-west-2"}, + expectCertsLen: 1, + }, } // Note: keystore currently races when multiple tsh clients work in the @@ -179,51 +235,72 @@ func TestDatabaseLogin(t *testing.T) { // Copying the profile dir is faster than sequential login for each database. for _, test := range testCases { test := test - t.Run(fmt.Sprintf("%v/%v", "tsh db login", test.databaseName), func(t *testing.T) { + t.Run(test.databaseName, func(t *testing.T) { t.Parallel() tmpHomePath := mustCloneTempDir(t, tmpHomePath) - err := Run(context.Background(), []string{ + selectors := test.dbSelectors + if len(selectors) == 0 { + selectors = []string{test.databaseName} + } + args := append([]string{ // default --db-user and --db-name are selected from roles. - "db", "login", test.databaseName, - }, setHomePath(tmpHomePath)) + "db", "login", + }, selectors...) + err := Run(context.Background(), args, setHomePath(tmpHomePath)) require.NoError(t, err) // Fetch the active profile. clientStore := client.NewFSClientStore(tmpHomePath) - profile, err := clientStore.ReadProfileStatus(proxyAddr.Host()) + profile, err := clientStore.ReadProfileStatus(s.root.Config.Proxy.WebAddr.String()) require.NoError(t, err) - require.Equal(t, alice.GetName(), profile.Username) + require.Equal(t, s.user.GetName(), profile.Username) // Verify certificates. + // grab the certs using the actual database name to verify certs. certs, keys, err := decodePEM(profile.DatabaseCertPathForCluster("", test.databaseName)) require.NoError(t, err) require.Equal(t, test.expectCertsLen, len(certs)) // don't use require.Len, because it spams PEM bytes on fail. require.Equal(t, test.expectKeysLen, len(keys)) // don't use require.Len, because it spams PEM bytes on fail. - t.Run(fmt.Sprintf("%v/%v", "tsh db config", test.databaseName), func(t *testing.T) { - t.Parallel() - err := Run(context.Background(), []string{ - "db", "config", test.databaseName, - }, setHomePath(tmpHomePath)) - - if test.expectErrForConfigCmd { - require.Error(t, err) - } else { - require.NoError(t, err) - } + t.Run("print info", func(t *testing.T) { + // organize these as parallel subtests in a group, so we can run + // them in parallel together before the logout test runs below. + t.Run("config", func(t *testing.T) { + t.Parallel() + args := append([]string{ + "db", "config", + }, selectors...) + err := Run(context.Background(), args, setHomePath(tmpHomePath)) + + if test.expectErrForConfigCmd { + require.Error(t, err) + require.NotContains(t, err.Error(), "matches multiple", "should not be ambiguity error") + } else { + require.NoError(t, err) + } + }) + t.Run("env", func(t *testing.T) { + t.Parallel() + args := append([]string{ + "db", "env", + }, selectors...) + err := Run(context.Background(), args, setHomePath(tmpHomePath)) + + if test.expectErrForEnvCmd { + require.Error(t, err) + require.NotContains(t, err.Error(), "matches multiple", "should not be ambiguity error") + } else { + require.NoError(t, err) + } + }) }) - t.Run(fmt.Sprintf("%v/%v", "tsh db env", test.databaseName), func(t *testing.T) { - t.Parallel() - err := Run(context.Background(), []string{ - "db", "env", test.databaseName, - }, setHomePath(tmpHomePath)) - - if test.expectErrForEnvCmd { - require.Error(t, err) - } else { - require.NoError(t, err) - } + t.Run("logout", func(t *testing.T) { + args := append([]string{ + "db", "logout", + }, selectors...) + err := Run(context.Background(), args, setHomePath(tmpHomePath)) + require.NoError(t, err) }) }) } @@ -340,22 +417,34 @@ func TestLocalProxyRequirement(t *testing.T) { } } -func TestListDatabase(t *testing.T) { - lib.SetInsecureDevMode(true) - defer lib.SetInsecureDevMode(false) - +func testListDatabase(t *testing.T) { + t.Parallel() + discoveredName := "root-postgres" + fullName := "root-postgres-rds-us-west-1-123456789012" s := newTestSuite(t, withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.StorageConfig.Params["poll_stream_period"] = 50 * time.Millisecond cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{{ - Name: "root-postgres", + Name: fullName, Protocol: defaults.ProtocolPostgres, URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: discoveredName, + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-1", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "root-postgres", + }, + }, }} }), withLeafCluster(), withLeafConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.StorageConfig.Params["poll_stream_period"] = 50 * time.Millisecond cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{{ Name: "leaf-postgres", @@ -365,7 +454,7 @@ func TestListDatabase(t *testing.T) { }), ) - mustLoginSetEnv(t, s) + tshHome, _ := mustLogin(t, s) captureStdout := new(bytes.Buffer) err := Run(context.Background(), []string{ @@ -373,10 +462,34 @@ func TestListDatabase(t *testing.T) { "ls", "--insecure", "--debug", - }, setCopyStdout(captureStdout)) + }, setCopyStdout(captureStdout), setHomePath(tshHome)) require.NoError(t, err) - require.Contains(t, captureStdout.String(), "root-postgres") + lines := strings.Split(captureStdout.String(), "\n") + require.Greater(t, len(lines), 2, + "there should be two lines of header followed by data rows") + require.True(t, + strings.HasPrefix(lines[2], discoveredName), + "non-verbose listing should print the discovered db name") + require.False(t, + strings.HasPrefix(lines[2], fullName), + "non-verbose listing should not print full db name") + + captureStdout.Reset() + err = Run(context.Background(), []string{ + "db", + "ls", + "--verbose", + "--insecure", + "--debug", + }, setCopyStdout(captureStdout), setHomePath(tshHome)) + require.NoError(t, err) + lines = strings.Split(captureStdout.String(), "\n") + require.Greater(t, len(lines), 2, + "there should be two lines of header followed by data rows") + require.True(t, + strings.HasPrefix(lines[2], fullName), + "verbose listing should print full db name") captureStdout.Reset() err = Run(context.Background(), []string{ @@ -386,7 +499,7 @@ func TestListDatabase(t *testing.T) { "leaf1", "--insecure", "--debug", - }, setCopyStdout(captureStdout)) + }, setCopyStdout(captureStdout), setHomePath(tshHome)) require.NoError(t, err) require.Contains(t, captureStdout.String(), "leaf-postgres") @@ -532,38 +645,6 @@ func TestDBInfoHasChanged(t *testing.T) { } } -func makeTestDatabaseServer(t *testing.T, auth *service.TeleportProcess, proxy *service.TeleportProcess, dbs ...servicecfg.Database) (db *service.TeleportProcess) { - // Proxy uses self-signed certificates in tests. - lib.SetInsecureDevMode(true) - - cfg := servicecfg.MakeDefaultConfig() - cfg.Hostname = "localhost" - cfg.DataDir = t.TempDir() - cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() - - proxyAddr, err := proxy.ProxyWebAddr() - require.NoError(t, err) - - cfg.SetAuthServerAddress(*proxyAddr) - - token, err := proxy.Config.Token() - require.NoError(t, err) - - cfg.SetToken(token) - cfg.SSH.Enabled = false - cfg.Auth.Enabled = false - cfg.Proxy.Enabled = false - cfg.Databases.Enabled = true - cfg.Databases.Databases = dbs - cfg.Log = utils.NewLoggerForTests() - - db = runTeleport(t, cfg) - - // Wait for all databases to register to avoid races. - waitForDatabases(t, auth, dbs) - return db -} - func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -808,3 +889,184 @@ func TestGetDefaultDBNameAndUser(t *testing.T) { }) } } + +func testFilterActiveDatabases(t *testing.T) { + t.Parallel() + // setup some databases and "active" routes to test filtering + db1, route1 := makeDBConfigAndRoute("foobar", map[string]string{"env": "dev", "svc": "fooer"}) + db1AWS, route1AWS := makeDBConfigAndRoute("foobar-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) + db1Azure, route1Azure := makeDBConfigAndRoute("foobar-westus-11111", map[string]string{"env": "prod", "region": "westus"}) + db2, route2 := makeDBConfigAndRoute("bazqux", map[string]string{"env": "dev", "svc": "bazzer"}) + db2AWS, route2AWS := makeDBConfigAndRoute("bazqux-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) + db3, route3 := makeDBConfigAndRoute("some-unique-name", map[string]string{"env": "dev"}) + routes := []tlsca.RouteToDatabase{route1, route1AWS, route1Azure, route2, route2AWS, route3} + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + db1, db1AWS, db1Azure, db2, db2AWS, db3, + } + }), + ) + + // Log into Teleport cluster. + tmpHomePath, _ := mustLogin(t, s) + + tests := []struct { + name, + dbName, + labels, + query string + wantAPICall bool + wantRoutes []tlsca.RouteToDatabase + }{ + { + name: "by exact name", + dbName: route1.ServiceName, + wantAPICall: false, + wantRoutes: []tlsca.RouteToDatabase{route1}, + }, + { + name: "by name prefix", + dbName: "foo", + wantAPICall: false, + wantRoutes: []tlsca.RouteToDatabase{route1, route1AWS, route1Azure}, + }, + { + name: "by labels", + labels: "env=dev", + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1, route2, route3}, + }, + { + name: "by query", + query: `labels.env == "dev"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1, route2, route3}, + }, + { + name: "by name prefix and labels", + dbName: "foo", + labels: "env=prod", + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1AWS, route1Azure}, + }, + { + name: "by name prefix and query", + dbName: "foo", + query: `labels.region == "us-west-1"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1AWS}, + }, + { + name: "by labels and query", + labels: "env=dev", + query: `hasPrefix(name, "some-uniq")`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route3}, + }, + { + name: "by name prefix and labels and query", + dbName: "foo", + labels: "env=prod", + query: `labels.region == "westus"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1Azure}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: tt.dbName, + Labels: tt.labels, + PredicateExpression: tt.query, + } + tc, err := makeClient(cf) + require.NoError(t, err) + routes, dbs, err := filterActiveDatabases(ctx, tc, routes) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tt.wantRoutes, routes)) + if tt.wantAPICall { + require.Equal(t, len(routes), len(dbs), + "returned routes should have corresponding types.Databases") + return + } + require.Zero(t, len(dbs), "unexpected API call to ListDatabases") + }) + } +} + +func TestResourceSelectorsFormatting(t *testing.T) { + tests := []struct { + testName string + selectors resourceSelectors + want string + }{ + { + testName: "no selectors", + selectors: resourceSelectors{ + kind: "database", + }, + want: "database", + }, + { + testName: "by name", + selectors: resourceSelectors{ + kind: "database", + name: "foo", + }, + want: `database "foo"`, + }, + { + testName: "by labels", + selectors: resourceSelectors{ + kind: "database", + labels: "env=dev,region=us-west-1", + }, + want: `database with labels "env=dev,region=us-west-1"`, + }, + { + testName: "by predicate", + selectors: resourceSelectors{ + kind: "database", + query: `labels["env"]=="dev" && labels.region == "us-west-1"`, + }, + want: `database with query (labels["env"]=="dev" && labels.region == "us-west-1")`, + }, + { + testName: "by name and labels and predicate", + selectors: resourceSelectors{ + kind: "app", + name: "foo", + labels: "env=dev,region=us-west-1", + query: `labels["env"]=="dev" && labels.region == "us-west-1"`, + }, + want: `app "foo" with labels "env=dev,region=us-west-1" with query (labels["env"]=="dev" && labels.region == "us-west-1")`, + }, + } + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + require.Equal(t, tt.want, fmt.Sprintf("%v", tt.selectors)) + }) + } +} + +// makeDBConfigAndRoute is a helper func that makes a db config and +// corresponding cert encoded route to that db - protocol etc not important. +func makeDBConfigAndRoute(name string, staticLabels map[string]string) (servicecfg.Database, tlsca.RouteToDatabase) { + db := servicecfg.Database{ + Name: name, + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: staticLabels, + } + route := tlsca.RouteToDatabase{ServiceName: name} + return db, route +} diff --git a/tool/tsh/common/kube.go b/tool/tsh/common/kube.go index 3ba8cac95b0e8..b51725e13e63d 100644 --- a/tool/tsh/common/kube.go +++ b/tool/tsh/common/kube.go @@ -968,7 +968,7 @@ func formatKubeLabels(cluster types.KubeCluster) string { func (c *kubeLSCommand) run(cf *CLIConf) error { cf.SearchKeywords = c.searchKeywords - cf.UserHost = c.labels + cf.Labels = c.labels cf.PredicateExpression = c.predicateExpr cf.SiteName = c.siteName diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index cd5537bedfd3f..0860e45423507 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -271,6 +271,11 @@ type CLIConf struct { // PredicateExpression defines boolean conditions that will be matched against the resource. PredicateExpression string + // Labels is used to hold labels passed via --labels=k1=v2,k2=v2,,, flag for resource filtering. + // explicitly passed --labels overrides user@labels positional arg form. + // NOTE: no command currently supports both, try to keep it that way. + Labels string + // NoRemoteExec will not execute a remote command after connecting to a host, // will block instead. Useful when port forwarding. Equivalent of -N for OpenSSH. NoRemoteExec bool @@ -730,7 +735,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { lsApps.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) lsApps.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) lsApps.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) - lsApps.Arg("labels", labelHelp).StringVar(&cf.UserHost) + lsApps.Arg("labels", labelHelp).StringVar(&cf.Labels) lsApps.Flag("all", "List apps from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) appLogin := apps.Command("login", "Retrieve short-lived certificate for an app.") appLogin.Arg("app", "App name to retrieve credentials for. Can be obtained from `tsh apps ls` output.").Required().StringVar(&cf.AppName) @@ -765,7 +770,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { proxySSH.Arg("[user@]host", "Remote hostname and the login to use").Required().StringVar(&cf.UserHost) proxySSH.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) proxyDB := proxy.Command("db", "Start local TLS proxy for database connections when using Teleport in single-port mode.") - proxyDB.Arg("db", "The name of the database to start local proxy for").Required().StringVar(&cf.DatabaseService) + // don't require positional argument, user can select with --labels/--query alone. + proxyDB.Arg("db", "The name of the database to start local proxy for").StringVar(&cf.DatabaseService) proxyDB.Flag("port", "Specifies the source port used by proxy db listener").Short('p').StringVar(&cf.LocalProxyPort) // --cert-file and --key-file are deprecated in favor of --tunnel flag. proxyDB.Flag("cert-file", "Certificate file for proxy client TLS configuration").Hidden().StringVar(&cf.LocalProxyCertFile) @@ -774,6 +780,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { proxyDB.Flag("db-user", "Optional database user to log in as.").StringVar(&cf.DatabaseUser) proxyDB.Flag("db-name", "Optional database name to log in to.").StringVar(&cf.DatabaseName) proxyDB.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) + proxyDB.Flag("labels", labelHelp).StringVar(&cf.Labels) + proxyDB.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) proxyApp := proxy.Command("app", "Start local TLS proxy for app connection when using Teleport in single-port mode.") proxyApp.Arg("app", "The name of the application to start local proxy for").Required().StringVar(&cf.AppName) @@ -808,20 +816,29 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbList.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbList.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) dbList.Flag("all", "List databases from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) - dbList.Arg("labels", labelHelp).StringVar(&cf.UserHost) + dbList.Arg("labels", labelHelp).StringVar(&cf.Labels) dbLogin := db.Command("login", "Retrieve credentials for a database.") - dbLogin.Arg("db", "Database to retrieve credentials for. Can be obtained from 'tsh db ls' output.").Required().StringVar(&cf.DatabaseService) + // don't require positional argument, user can select with --labels/--query alone. + dbLogin.Arg("db", "Database to retrieve credentials for. Can be obtained from 'tsh db ls' output.").StringVar(&cf.DatabaseService) + dbLogin.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbLogin.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbLogin.Flag("db-user", "Optional database user to configure as default.").StringVar(&cf.DatabaseUser) dbLogin.Flag("db-name", "Optional database name to configure as default.").StringVar(&cf.DatabaseName) dbLogout := db.Command("logout", "Remove database credentials.") dbLogout.Arg("db", "Database to remove credentials for.").StringVar(&cf.DatabaseService) + dbLogout.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbLogout.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbEnv := db.Command("env", "Print environment variables for the configured database.") - dbEnv.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) dbEnv.Arg("db", "Print environment for the specified database").StringVar(&cf.DatabaseService) + dbEnv.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) + dbEnv.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbEnv.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // --db flag is deprecated in favor of positional argument for consistency with other commands. dbEnv.Flag("db", "Print environment for the specified database.").Hidden().StringVar(&cf.DatabaseService) dbConfig := db.Command("config", "Print database connection information. Useful when configuring GUI clients.") dbConfig.Arg("db", "Print information for the specified database.").StringVar(&cf.DatabaseService) + dbConfig.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbConfig.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // --db flag is deprecated in favor of positional argument for consistency with other commands. dbConfig.Flag("db", "Print information for the specified database.").Hidden().StringVar(&cf.DatabaseService) dbConfig.Flag("format", fmt.Sprintf("Print format: %q to print in table format (default), %q to print connect command, %q or %q to print in JSON or YAML.", @@ -830,6 +847,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Arg("db", "Database service name to connect to.").StringVar(&cf.DatabaseService) dbConnect.Flag("db-user", "Optional database user to log in as.").StringVar(&cf.DatabaseUser) dbConnect.Flag("db-name", "Optional database name to log in to.").StringVar(&cf.DatabaseName) + dbConnect.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // join join := app.Command("join", "Join the active SSH or Kubernetes session.") @@ -861,7 +880,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { ls.Flag("format", defaults.FormatFlagDescription( teleport.Text, teleport.JSON, teleport.YAML, teleport.Names, )).Short('f').Default(teleport.Text).EnumVar(&cf.Format, teleport.Text, teleport.JSON, teleport.YAML, teleport.Names) - ls.Arg("labels", labelHelp).StringVar(&cf.UserHost) + ls.Arg("labels", labelHelp).StringVar(&cf.Labels) ls.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) ls.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) ls.Flag("all", "List nodes from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) @@ -983,7 +1002,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { ).Required().EnumVar(&cf.ResourceKind, types.RequestableResourceKinds...) reqSearch.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) reqSearch.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) - reqSearch.Flag("labels", labelHelp).StringVar(&cf.UserHost) + reqSearch.Flag("labels", labelHelp).StringVar(&cf.Labels) reqSearch.Flag("kube-cluster", "Kubernetes Cluster to search for Pods").StringVar(&cf.KubernetesCluster) reqSearch.Flag("kube-namespace", "Kubernetes Namespace to search for Pods").Default(corev1.NamespaceDefault).StringVar(&cf.kubeNamespace) reqSearch.Flag("all-kube-namespaces", "Search Pods in every namespace").BoolVar(&cf.kubeAllNamespaces) @@ -2787,12 +2806,25 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec return fmt.Sprintf("%v, except: %v", dbUsers.Allowed, dbUsers.Denied) } +func getDiscoveredName(r types.ResourceWithLabels) (string, bool) { + name, ok := r.GetAllLabels()[types.DiscoveredNameLabel] + return name, ok +} + func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, active []tlsca.RouteToDatabase, accessChecker services.AccessChecker, verbose bool) []string { name := database.GetName() + printName := name + if d, ok := getDiscoveredName(database); ok && !verbose && d != name { + printName = d + } var connect string for _, a := range active { if a.ServiceName == name { - name = formatActiveDB(a) + a.ServiceName = printName + // format the db name with the print name + printName = formatActiveDB(a) + // then revert it for connect string + a.ServiceName = name switch a.Protocol { case defaults.ProtocolDynamoDB: // DynamoDB does not support "tsh db connect", so print the proxy command instead. @@ -2800,6 +2832,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, default: connect = formatDatabaseConnectCommand(clusterFlag, a) } + break } } @@ -2810,7 +2843,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, if verbose { row = append(row, - name, + printName, database.GetDescription(), database.GetProtocol(), database.GetType(), @@ -2821,7 +2854,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, ) } else { row = append(row, - name, + printName, database.GetDescription(), formatUsersForDB(database, accessChecker), formatDatabaseLabels(database), @@ -2878,8 +2911,9 @@ func printDatabasesWithClusters(clusterFlag string, dbListings []databaseListing func formatDatabaseLabels(database types.Database) string { labels := database.GetAllLabels() - // Hide the origin label unless printing verbose table. + // Hide the origin and discovered-name labels unless printing verbose table. delete(labels, types.OriginLabel) + delete(labels, types.DiscoveredNameLabel) return sortedLabels(labels) } @@ -3398,6 +3432,15 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err } } } + + // explicitly passed --labels overrides user@labels positional arg form. + if cf.Labels != "" { + labels, err = client.ParseLabelSpec(cf.Labels) + if err != nil { + return nil, trace.Wrap(err) + } + } + fPorts, err := client.ParsePortForwardSpec(cf.LocalForwardPorts) if err != nil { return nil, trace.Wrap(err) diff --git a/tool/tsh/common/tsh_helper_test.go b/tool/tsh/common/tsh_helper_test.go index f0ec2f62aec3c..7980a54e4cc75 100644 --- a/tool/tsh/common/tsh_helper_test.go +++ b/tool/tsh/common/tsh_helper_test.go @@ -91,12 +91,13 @@ func (s *suite) setupRootCluster(t *testing.T, options testSuiteOptions) { cfg.Proxy.DisableWebInterface = true cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ StaticTokens: []types.ProvisionTokenV1{{ - Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleNode, types.RoleTrustedCluster}, + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp}, Expires: time.Now().Add(time.Minute), Token: staticToken, }}, }) require.NoError(t, err) + cfg.SetToken(staticToken) user, err := user.Current() require.NoError(t, err) @@ -134,7 +135,6 @@ func (s *suite) setupRootCluster(t *testing.T, options testSuiteOptions) { } s.root = runTeleport(t, cfg) - t.Cleanup(func() { require.NoError(t, s.root.Close()) }) } func (s *suite) setupLeafCluster(t *testing.T, options testSuiteOptions) { @@ -182,6 +182,15 @@ func (s *suite) setupLeafCluster(t *testing.T, options testSuiteOptions) { require.NoError(t, err) cfg.Proxy.DisableWebInterface = true + cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ + StaticTokens: []types.ProvisionTokenV1{{ + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp}, + Expires: time.Now().Add(time.Minute), + Token: staticToken, + }}, + }) + require.NoError(t, err) + cfg.SetToken(staticToken) sshLoginRole, err := types.NewRole("ssh-login", types.RoleSpecV6{ Allow: types.RoleConditions{ Logins: []string{user.Username},