diff --git a/api/types/constants.go b/api/types/constants.go index 6c89bde714642..dfab169a4f875 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -533,6 +533,12 @@ const ( // See also TeleportNamespace and TeleportInternalLabelPrefix. TeleportHiddenLabelPrefix = "teleport.hidden/" + // DiscoveredNameLabel is a resource metadata label name used to identify + // the discovered name of a resource, i.e. the name of a resource before a + // uniquely distinguishing suffix is added by the discovery service. + // See: RFD 129 - Avoid Discovery Resource Name Collisions. + DiscoveredNameLabel = TeleportInternalLabelPrefix + "discovered-name" + // BotLabel is a label used to identify a resource used by a certificate renewal bot. BotLabel = TeleportInternalLabelPrefix + "bot" diff --git a/docs/pages/reference/predicate-language.mdx b/docs/pages/reference/predicate-language.mdx index 4d70344ecb2b3..954cbeab1a47d 100644 --- a/docs/pages/reference/predicate-language.mdx +++ b/docs/pages/reference/predicate-language.mdx @@ -68,6 +68,7 @@ The language also supports the following functions: | `exists(labels["env"])` | resources with a label key `env`; label value unchecked | | `!exists(labels["env"])` | resources without a label key `env`; label value unchecked | | `search("foo", "bar", "some phrase")` | fuzzy match against common resource fields | +| `hasPrefix(name, "foo")` | resources with a name that starts with the prefix `foo` | See some [examples](cli.mdx#filter-examples) of the different ways you can filter resources. diff --git a/lib/client/api.go b/lib/client/api.go index 04bf52192a10d..ac55f7f26ec89 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -375,6 +375,16 @@ type Config struct { // MockHeadlessLogin is used in tests for mocking the Headless login response. MockHeadlessLogin SSHLoginFunc + // OverrideMySQLOptionFilePath overrides the MySQL option file path to use. 
+ // Useful in parallel tests so they don't all use the default path in the + // user home dir. + OverrideMySQLOptionFilePath string + + // OverridePostgresServiceFilePath overrides the Postgres service file path. + // Useful in parallel tests so they don't all use the default path in the + // user home dir. + OverridePostgresServiceFilePath string + // HomePath is where tsh stores profiles HomePath string diff --git a/lib/client/db/dbcmd/dbcmd.go b/lib/client/db/dbcmd/dbcmd.go index cea69d85b3c8e..05fbccc9030c2 100644 --- a/lib/client/db/dbcmd/dbcmd.go +++ b/lib/client/db/dbcmd/dbcmd.go @@ -330,7 +330,7 @@ func (c *CLICommandBuilder) getMySQLOracleCommand() (*exec.Cmd, error) { // We save configuration to ~/.my.cnf, but on Windows that file is not read, // see tables 4.1 and 4.2 on https://dev.mysql.com/doc/refman/8.0/en/option-files.html. // We instruct mysql client to use use that file with --defaults-extra-file. - configPath, err := mysql.DefaultConfigPath() + configPath, err := c.getMySQLOptionFilePath() if err != nil { return nil, trace.Wrap(err) } @@ -346,6 +346,15 @@ func (c *CLICommandBuilder) getMySQLOracleCommand() (*exec.Cmd, error) { return exec.Command(mysqlBin, args...), nil } +// getMySQLOptionFilePath gets the filepath to .my.cnf from the default location +// in ~/.my.cnf, unless overridden by config. +func (c *CLICommandBuilder) getMySQLOptionFilePath() (string, error) { + if c.tc.OverrideMySQLOptionFilePath != "" { + return c.tc.OverrideMySQLOptionFilePath, nil + } + return mysql.DefaultConfigPath() +} + // getMySQLCommand returns mariadb command if the binary is on the path. Otherwise, // mysql command is returned. Both mysql versions (MariaDB and Oracle) are supported. 
func (c *CLICommandBuilder) getMySQLCommand() (*exec.Cmd, error) { diff --git a/lib/client/db/mysql/optionfile.go b/lib/client/db/mysql/optionfile.go index 7bcee08192080..14499e8e02a8e 100644 --- a/lib/client/db/mysql/optionfile.go +++ b/lib/client/db/mysql/optionfile.go @@ -43,8 +43,10 @@ type OptionFile struct { path string } +// DefaultConfigPath returns the default config path, which is .my.cnf file in +// the user's home directory. Home dir is determined by environment if not +// supplied as an argument. func DefaultConfigPath() (string, error) { - // Default location is .my.cnf file in the user's home directory. home, err := os.UserHomeDir() if err != nil || home == "" { usr, err := utils.CurrentUser() diff --git a/lib/client/db/postgres/servicefile.go b/lib/client/db/postgres/servicefile.go index 6361c7d91f25c..ff95be5bde878 100644 --- a/lib/client/db/postgres/servicefile.go +++ b/lib/client/db/postgres/servicefile.go @@ -43,23 +43,33 @@ type ServiceFile struct { path string } -// Load loads Postgres connection service file from the default location. -func Load() (*ServiceFile, error) { +// DefaultConfigPath returns the default config path, which is .pg_service.conf +// file in the user's home directory. +func defaultConfigPath() (string, error) { // Default location is .pg_service.conf file in the user's home directory. // TODO(r0mant): Check PGSERVICEFILE and PGSYSCONFDIR env vars as well. home, err := os.UserHomeDir() if err != nil || home == "" { - user, err := utils.CurrentUser() + usr, err := utils.CurrentUser() if err != nil { - return nil, trace.ConvertSystemError(err) + return "", trace.ConvertSystemError(err) } - home = user.HomeDir + home = usr.HomeDir } - return LoadFromPath(filepath.Join(home, pgServiceFile)) + return filepath.Join(home, pgServiceFile), nil +} + +// Load loads Postgres connection service file from the default location. 
+func Load() (*ServiceFile, error) { + cnfPath, err := defaultConfigPath() + if err != nil { + return nil, trace.Wrap(err) + } + return LoadFromPath(cnfPath) } -// LoadFromPath loads Posrtgres connection service file from the specified path. +// LoadFromPath loads Postgres connection service file from the specified path. func LoadFromPath(path string) (*ServiceFile, error) { // Loose load will ignore file not found error. iniFile, err := ini.LooseLoad(path) diff --git a/lib/client/db/profile.go b/lib/client/db/profile.go index af63e4c1c2a49..36c2d31ae273b 100644 --- a/lib/client/db/profile.go +++ b/lib/client/db/profile.go @@ -45,7 +45,7 @@ func Add(ctx context.Context, tc *client.TeleportClient, db tlsca.RouteToDatabas if !IsSupported(db) { return nil } - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return trace.Wrap(err) } @@ -98,7 +98,7 @@ func New(tc *client.TeleportClient, db tlsca.RouteToDatabase, clientProfile clie // Env returns environment variables for the specified database profile. func Env(tc *client.TeleportClient, db tlsca.RouteToDatabase) (map[string]string, error) { - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return nil, trace.Wrap(err) } @@ -114,7 +114,7 @@ func Delete(tc *client.TeleportClient, db tlsca.RouteToDatabase) error { if !IsSupported(db) { return nil } - profileFile, err := load(db) + profileFile, err := load(tc, db) if err != nil { return trace.Wrap(err) } @@ -138,11 +138,17 @@ func IsSupported(db tlsca.RouteToDatabase) bool { } // load loads the appropriate database connection profile. 
-func load(db tlsca.RouteToDatabase) (profile.ConnectProfileFile, error) { +func load(tc *client.TeleportClient, db tlsca.RouteToDatabase) (profile.ConnectProfileFile, error) { switch db.Protocol { case defaults.ProtocolPostgres: + if tc.OverridePostgresServiceFilePath != "" { + return postgres.LoadFromPath(tc.OverridePostgresServiceFilePath) + } return postgres.Load() case defaults.ProtocolMySQL: + if tc.OverrideMySQLOptionFilePath != "" { + return mysql.LoadFromPath(tc.OverrideMySQLOptionFilePath) + } return mysql.Load() } return nil, trace.BadParameter("unsupported database protocol %q", diff --git a/lib/services/parser.go b/lib/services/parser.go index 863d290dcb134..d2fccdcb30129 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -741,6 +741,22 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, return predicate.Equals(a, b) } } + predPrefix := func(a interface{}, prefix string) predicate.BoolPredicate { + switch aval := a.(type) { + case label: + return func() bool { + return strings.HasPrefix(aval.value, prefix) + } + case string: + return func() bool { + return strings.HasPrefix(aval, prefix) + } + default: + return func() bool { + return false + } + } + } p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ @@ -753,7 +769,8 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, }, }, Functions: map[string]interface{}{ - "equals": predEquals, + "hasPrefix": predPrefix, + "equals": predEquals, // search allows fuzzy matching against select field values. "search": func(searchVals ...string) predicate.BoolPredicate { return func() bool { diff --git a/lib/services/parser_test.go b/lib/services/parser_test.go index 55a4f1ef999c5..7d1b258cddd43 100644 --- a/lib/services/parser_test.go +++ b/lib/services/parser_test.go @@ -192,6 +192,11 @@ func TestNewResourceParser(t *testing.T) { `search("os", "mac", "prod")`, `search()`, `!search("_")`, + // Test hasPrefix. 
+ `hasPrefix(name, "")`, + `hasPrefix(name, "test-h")`, + `!hasPrefix(name, "foo")`, + `hasPrefix(resource.metadata.labels["env"], "pro")`, // Test exists. `exists(labels.env)`, `!exists(labels.undefined)`, @@ -206,6 +211,7 @@ func TestNewResourceParser(t *testing.T) { `labels.os == "mac" && name == "test-hostname" && search("v8")`, `exists(labels.env) && labels["env"] != "qa"`, `search("does", "not", "exist") || resource.spec.addr == "_" || labels.version == "v8"`, + `hasPrefix(labels.os, "m") && !hasPrefix(labels.env, "dev") && name == "test-hostname"`, // Test operator precedence `exists(labels.env) || (exists(labels.os) && labels.os != "mac")`, `exists(labels.env) || exists(labels.os) && labels.os != "mac"`, @@ -233,6 +239,7 @@ func TestNewResourceParser(t *testing.T) { `equals(resource.metadata.labels["env"], "wrong-value")`, `equals(resource.spec.hostname, "wrong-value")`, `search("mac", "not-found")`, + `hasPrefix(name, "x")`, } for _, expr := range exprs { t.Run(expr, func(t *testing.T) { @@ -269,6 +276,10 @@ func TestNewResourceParser(t *testing.T) { `exists(labels.env, "too", "many")`, `search(1,2)`, `"just-string"`, + `hasPrefix(1, 2)`, + `hasPrefix(name)`, + `hasPrefix(name, 1)`, + `hasPrefix(name, "too", "many")`, "", } for _, expr := range exprs { diff --git a/tool/teleport/testenv/test_server.go b/tool/teleport/testenv/test_server.go new file mode 100644 index 0000000000000..d2e92495e2934 --- /dev/null +++ b/tool/teleport/testenv/test_server.go @@ -0,0 +1,363 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package testenv provides functions for creating test servers for testing. +package testenv + +import ( + "context" + "crypto" + "fmt" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/breaker" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/teleport/common" +) + +// used to easily join test services +const staticToken = "test-static-token" + +func init() { + // If the test is re-executing itself, execute the command that comes over + // the pipe. Used to test tsh ssh and tsh scp commands. + if srv.IsReexec() { + common.Run(common.Options{Args: os.Args[1:]}) + return + } + + modules.SetModules(&cliModules{}) +} + +// WithInsecureDevMode is a test helper that sets insecure dev mode and resets +// it in test cleanup. +// It is NOT SAFE to use in parallel tests, because it modifies a global. +// To run insecure dev mode tests in parallel, group them together under a +// parent test and then run them as parallel subtests. +// and call WithInsecureDevMode before running all the tests in parallel. 
+func WithInsecureDevMode(t *testing.T, mode bool) { + originalValue := lib.IsInsecureDevMode() + lib.SetInsecureDevMode(mode) + // To detect tests that run in parallel incorrectly, call t.Setenv with a + // dummy env var - that function detects tests with parallel ancestors + // and panics, preventing improper use of this helper. + t.Setenv("WithInsecureDevMode", "1") + t.Cleanup(func() { + lib.SetInsecureDevMode(originalValue) + }) +} + +// WithResyncInterval is a test helper that sets the tunnel resync interval and +// resets it in test cleanup. +// Useful to substantially speedup test cluster setup - passing 0 for the +// interval selects a reasonably fast default of 100ms. +// It is NOT SAFE to use in parallel tests, because it modifies a global. +func WithResyncInterval(t *testing.T, interval time.Duration) { + if interval == 0 { + interval = time.Millisecond * 100 + } + oldResyncInterval := defaults.ResyncInterval + defaults.ResyncInterval = interval + // To detect tests that run in parallel incorrectly, call t.Setenv with a + // dummy env var - that function detects tests with parallel ancestors + // and panics, preventing improper use of this helper. + t.Setenv("WithResyncInterval", "1") + t.Cleanup(func() { + defaults.ResyncInterval = oldResyncInterval + }) +} + +// MakeTestServer creates a Teleport Server for testing. +func MakeTestServer(t *testing.T, opts ...TestServerOptFunc) (process *service.TeleportProcess) { + t.Helper() + + var options TestServersOpts + for _, opt := range opts { + opt(&options) + } + + // Set up a test auth server with default config. + cfg := servicecfg.MakeDefaultConfig() + cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() + cfg.CachePolicy.Enabled = false + // Disables cloud auto-imported labels when running tests in cloud envs + // such as Github Actions. 
+ // + // This is required otherwise Teleport will import cloud instance + // labels, and use them for example as labels in Kubernetes Service and + // cause some tests to fail because the output includes unexpected + // labels. + // + // It is also found that Azure metadata client can throw "Too many + // requests" during CI which fails services.NewTeleport. + cfg.InstanceMetadataClient = cloud.NewDisabledIMDSClient() + + cfg.Hostname = "server01" + cfg.DataDir = t.TempDir() + cfg.Log = utils.NewLoggerForTests() + authAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: NewTCPListener(t, service.ListenerAuth, &cfg.FileDescriptors)} + cfg.SetToken(staticToken) + cfg.SetAuthServerAddress(authAddr) + + cfg.Auth.ListenAddr = authAddr + cfg.Auth.BootstrapResources = options.Bootstrap + cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)} + staticToken, err := types.NewStaticTokens(types.StaticTokensSpecV2{ + StaticTokens: []types.ProvisionTokenV1{{ + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp}, + Expires: time.Now().Add(time.Minute), + Token: staticToken, + }}, + }) + require.NoError(t, err) + cfg.Auth.StaticTokens = staticToken + + cfg.Proxy.WebAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: NewTCPListener(t, service.ListenerProxyWeb, &cfg.FileDescriptors)} + cfg.Proxy.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: NewTCPListener(t, service.ListenerProxySSH, &cfg.FileDescriptors)} + cfg.Proxy.ReverseTunnelListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: NewTCPListener(t, service.ListenerProxyTunnel, &cfg.FileDescriptors)} + cfg.Proxy.DisableWebInterface = true + + cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: NewTCPListener(t, service.ListenerNodeSSH, &cfg.FileDescriptors)} + cfg.SSH.DisableCreateHostUser = true + + // Apply options + for _, fn := range options.ConfigFuncs { + fn(cfg) + } + + process, err = 
service.NewTeleport(cfg) + require.NoError(t, err, trace.DebugReport(err)) + require.NoError(t, process.Start()) + t.Cleanup(func() { + require.NoError(t, process.Close()) + require.NoError(t, process.Wait()) + }) + + waitForServices(t, process, cfg) + + return process +} + +// NewTCPListener creates a new TCP listener on 127.0.0.1:0, adds it to the +// FileDescriptor slice (with the specified type) and returns its actual local +// address as a string (for use in configuration). Takes a pointer to the slice +// so that it's convenient to call in the middle of a FileConfig or Config +// struct literal. +func NewTCPListener(t *testing.T, lt service.ListenerType, fds *[]servicecfg.FileDescriptor) string { + t.Helper() + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + addr := l.Addr().String() + + // File() returns a dup of the listener's file descriptor as an *os.File, so + // the original net.Listener still needs to be closed. + lf, err := l.(*net.TCPListener).File() + require.NoError(t, err) + // If the file descriptor slice ends up being passed to a TeleportProcess + // that successfully starts, listeners will either get "imported" and used + // or discarded and closed, this is just an extra safety measure that closes + // the listener at the end of the test anyway (the finalizer would do that + // anyway, in principle). 
+ t.Cleanup(func() { lf.Close() }) + + *fds = append(*fds, servicecfg.FileDescriptor{ + Type: string(lt), + Address: addr, + File: lf, + }) + + return addr +} + +func waitForServices(t *testing.T, auth *service.TeleportProcess, cfg *servicecfg.Config) { + var serviceReadyEvents []string + if cfg.Proxy.Enabled { + serviceReadyEvents = append(serviceReadyEvents, service.ProxyWebServerReady) + } + if cfg.SSH.Enabled { + serviceReadyEvents = append(serviceReadyEvents, service.NodeSSHReady) + } + if cfg.Databases.Enabled { + serviceReadyEvents = append(serviceReadyEvents, service.DatabasesReady) + } + if cfg.Apps.Enabled { + serviceReadyEvents = append(serviceReadyEvents, service.AppsReady) + } + if cfg.Auth.Enabled { + serviceReadyEvents = append(serviceReadyEvents, service.AuthTLSReady) + } + waitForEvents(t, auth, serviceReadyEvents...) + + if cfg.Auth.Enabled && cfg.Databases.Enabled { + waitForDatabases(t, auth, cfg.Databases.Databases) + } +} + +func waitForEvents(t *testing.T, svc service.Supervisor, events ...string) { + for _, event := range events { + _, err := svc.WaitForEventTimeout(10*time.Second, event) + require.NoError(t, err, "service server didn't receive %v event after 10s", event) + } +} + +func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for { + select { + case <-time.After(500 * time.Millisecond): + all, err := auth.GetAuthServer().GetDatabaseServers(ctx, apidefaults.Namespace) + require.NoError(t, err) + + // Count how many input "dbs" are registered. 
+ var registered int + for _, db := range dbs { + for _, a := range all { + if a.GetName() == db.Name { + registered++ + break + } + } + } + + if registered == len(dbs) { + return + } + case <-ctx.Done(): + t.Fatal("Databases not registered after 10s") + } + } +} + +type TestServersOpts struct { + Bootstrap []types.Resource + ConfigFuncs []func(cfg *servicecfg.Config) +} + +type TestServerOptFunc func(o *TestServersOpts) + +func WithBootstrap(bootstrap ...types.Resource) TestServerOptFunc { + return func(o *TestServersOpts) { + o.Bootstrap = bootstrap + } +} + +func WithConfig(fn func(cfg *servicecfg.Config)) TestServerOptFunc { + return func(o *TestServersOpts) { + o.ConfigFuncs = append(o.ConfigFuncs, fn) + } +} + +func WithAuthConfig(fn func(*servicecfg.AuthConfig)) TestServerOptFunc { + return WithConfig(func(cfg *servicecfg.Config) { + fn(&cfg.Auth) + }) +} + +func WithClusterName(t *testing.T, n string) TestServerOptFunc { + return WithAuthConfig(func(cfg *servicecfg.AuthConfig) { + clusterName, err := services.NewClusterNameWithRandomID( + types.ClusterNameSpecV2{ + ClusterName: n, + }) + require.NoError(t, err) + cfg.ClusterName = clusterName + }) +} + +func WithHostname(hostname string) TestServerOptFunc { + return WithConfig(func(cfg *servicecfg.Config) { + cfg.Hostname = hostname + }) +} + +func WithSSHPublicAddrs(addrs ...string) TestServerOptFunc { + return WithConfig(func(cfg *servicecfg.Config) { + cfg.SSH.PublicAddrs = utils.MustParseAddrList(addrs...) + }) +} + +func WithSSHLabel(key, value string) TestServerOptFunc { + return WithConfig(func(cfg *servicecfg.Config) { + if cfg.SSH.Labels == nil { + cfg.SSH.Labels = make(map[string]string) + } + cfg.SSH.Labels[key] = value + }) +} + +type cliModules struct{} + +// BuildType returns build type. +func (p *cliModules) BuildType() string { + return "CLI" +} + +// PrintVersion prints the Teleport version. 
+func (p *cliModules) PrintVersion() { + fmt.Println("Teleport CLI") +} + +// Features returns supported features +func (p *cliModules) Features() modules.Features { + return modules.Features{ + Kubernetes: true, + DB: true, + App: true, + AdvancedAccessWorkflows: true, + AccessControls: true, + } +} + +// IsBoringBinary checks if the binary was compiled with BoringCrypto. +func (p *cliModules) IsBoringBinary() bool { + return false +} + +// AttestHardwareKey attests a hardware key. +func (p *cliModules) AttestHardwareKey(_ context.Context, _ interface{}, _ keys.PrivateKeyPolicy, _ *keys.AttestationStatement, _ crypto.PublicKey, _ time.Duration) (keys.PrivateKeyPolicy, error) { + return keys.PrivateKeyPolicyNone, nil +} + +func (p *cliModules) EnableRecoveryCodes() { +} + +func (p *cliModules) EnablePlugins() { +} + +func (p *cliModules) SetFeatures(f modules.Features) { +} diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 395e458b41cdf..9ba41eda6e7da 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/gravitational/teleport" @@ -250,11 +251,7 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - dbInfo, err := newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ - ServiceName: cf.DatabaseService, - Username: cf.DatabaseUser, - Database: cf.DatabaseName, - }) + dbInfo, err := newDatabaseInfo(cf, tc, nil) if err != nil { return trace.Wrap(err) } @@ -351,7 +348,11 @@ func onDatabaseLogout(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) + activeRoutes, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + logout, _, err := filterActiveDatabases(cf.Context, tc, activeRoutes) if err != nil { return trace.Wrap(err) } 
@@ -360,34 +361,46 @@ func onDatabaseLogout(cf *CLIConf) error { log.Info("Note: an identity file is in use (`-i ...`); will only update database config files.") } - var logout []tlsca.RouteToDatabase - // If database name wasn't given on the command line, log out of all. - if cf.DatabaseService == "" { - logout = activeDatabases - } else { - for _, db := range activeDatabases { - if db.ServiceName == cf.DatabaseService { - logout = append(logout, db) - } - } - if len(logout) == 0 { - return trace.BadParameter("Not logged into database %q", - tc.DatabaseService) - } - } for _, db := range logout { if err := databaseLogout(tc, db, profile.IsVirtual); err != nil { return trace.Wrap(err) } } - if len(logout) == 1 { - fmt.Println("Logged out of database", logout[0].ServiceName) - } else { - fmt.Println("Logged out of all databases") + msg, err := makeLogoutMessage(cf, logout, activeRoutes) + if err != nil { + return trace.Wrap(err) } + fmt.Fprintln(cf.Stdout(), msg) return nil } +// makeLogoutMessage is a helper func that returns a logout message for the +// result of "tsh db logout". 
+func makeLogoutMessage(cf *CLIConf, logout, activeRoutes []tlsca.RouteToDatabase) (string, error) { + switch len(logout) { + case 0: + selectors := resourceSelectors{ + kind: "database", + name: cf.DatabaseService, + labels: cf.Labels, + query: cf.PredicateExpression, + } + return "", trace.NotFound("Not logged into %v", selectors) + case 1: + return fmt.Sprintf("Logged out of database %v", logout[0].ServiceName), nil + case len(activeRoutes): + return "Logged out of all databases", nil + default: + names := make([]string, 0, len(logout)) + for _, route := range logout { + names = append(names, route.ServiceName) + } + slices.Sort(names) + nameLines := strings.Join(names, "\n") + return fmt.Sprintf("Logged out of databases:\n%v", nameLines), nil + } +} + func databaseLogout(tc *client.TeleportClient, db tlsca.RouteToDatabase, virtual bool) error { // First remove respective connection profile. err := dbprofile.Delete(tc, db) @@ -413,7 +426,7 @@ func onDatabaseEnv(cf *CLIConf) error { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf) + database, err := pickActiveDatabase(cf, tc) if err != nil { return trace.Wrap(err) } @@ -471,7 +484,7 @@ func onDatabaseConfig(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf) + database, err := pickActiveDatabase(cf, tc) if err != nil { return trace.Wrap(err) } @@ -790,18 +803,95 @@ func onDatabaseConnect(cf *CLIConf) error { } // getDatabaseInfo fetches information about the database from tsh profile if DB -// is active in profile. Otherwise, the ListDatabases endpoint is called. +// is active in profile and no labels or predicate query are given. +// Otherwise, the ListDatabases endpoint is called. 
func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { - if route, err := pickActiveDatabase(cf); err == nil { - return newDatabaseInfo(cf, tc, *route) - } else if err != nil && !trace.IsNotFound(err) { + haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + if !haveSelectors { + // if selectors are given, we might incur an extra ListDatabases API + // call here to match against an active database. + // So try to pick an active database only when we don't have + // selectors. + if route, err := pickActiveDatabase(cf, tc); err == nil { + return newDatabaseInfo(cf, tc, route) + } else if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + } + return newDatabaseInfo(cf, tc, nil) +} + +// newDatabaseInfo makes a new databaseInfo from the given route to the db. +// It checks the route and sets defaults as needed for protocol, db user, or db +// name. If the route is not given or the remote database is needed for setting +// a default, the database is retrieved by calling ListDatabases API and cached. +func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteToDatabase) (*databaseInfo, error) { + dbInfo := &databaseInfo{} + if route != nil { + dbInfo.RouteToDatabase = *route + // the only way we're going to have all this info populated is from an + // active cert. + if dbInfo.ServiceName != "" && dbInfo.Protocol != "" && + dbInfo.Username != "" && dbInfo.Database != "" { + return dbInfo, nil + } + } + db, err := dbInfo.GetDatabase(cf, tc) + if err != nil { return nil, trace.Wrap(err) } - return newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ - ServiceName: cf.DatabaseService, - Username: cf.DatabaseUser, - Database: cf.DatabaseName, - }) + // now ensure the route name and protocol matches the db we fetched. 
+ dbInfo.ServiceName = db.GetName() + dbInfo.Protocol = db.GetProtocol() + return dbInfo, dbInfo.checkAndSetPrincipalDefaults(cf, tc, db) +} + +// checkAndSetPrincipalDefaults checks the db route (schema) name and username, +// and sets them to defaults if necessary. +func (d *databaseInfo) checkAndSetPrincipalDefaults(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + if cf.DatabaseUser != "" { + d.Username = cf.DatabaseUser + } + if cf.DatabaseName != "" { + d.Database = cf.DatabaseName + } + // If database has admin user defined, we're most likely using automatic + // user provisioning so default to Teleport username unless database + // username was provided explicitly. + if d.Username == "" && db.GetAdminUser() != "" { + log.Debugf("Defaulting to Teleport username %q as database username.", tc.Username) + d.Username = tc.Username + } + if d.Username == "" { + switch d.Protocol { + // When generating certificate for MongoDB access, database username must + // be encoded into it. This is required to be able to tell which database + // user to authenticate the connection as Elasticsearch needs database username too. + case defaults.ProtocolMongoDB, defaults.ProtocolElasticsearch, defaults.ProtocolOracle, defaults.ProtocolOpenSearch: + return trace.BadParameter("please provide the database user name using the --db-user flag") + case defaults.ProtocolRedis: + // Default to "default" in the same way as Redis does. We need the username to check access on our side. 
+ // ref: https://redis.io/commands/auth + log.Debugf("Defaulting to Redis username %q as database username.", defaults.DefaultRedisUsername) + d.Username = defaults.DefaultRedisUsername + } + } + + if d.Database != "" { + switch d.Protocol { + case defaults.ProtocolDynamoDB: + log.Warnf("Database %v protocol %v does not support --db-name flag, ignoring --db-name=%v", + d.ServiceName, defaults.ReadableDatabaseProtocol(d.Protocol), d.Database) + d.Database = "" + } + } else { + switch d.Protocol { + // Always require db-name for Oracle Protocol. + case defaults.ProtocolOracle: + return trace.BadParameter("please provide the database name using the --db-name flag") + } + } + return nil } // databaseInfo wraps a RouteToDatabase and the corresponding database. @@ -818,87 +908,110 @@ type databaseInfo struct { // GetDatabase returns the cached database or fetches it using the db route and // caches the result. func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { + if d.ServiceName == "" && cf.DatabaseService == "" && + len(tc.Labels) == 0 && tc.PredicateExpression == "" { + return nil, trace.BadParameter("specify a database service by name, --labels, or --query") + } d.mu.Lock() defer d.mu.Unlock() if d.database != nil { return d.database, nil } - var databases []types.Database // holding mutex across the api call to avoid multiple redundant api calls. 
- err := client.RetryWithRelogin(cf.Context, tc, func() error { - var err error - databases, err = tc.ListDatabases(cf.Context, &proto.ListResourcesRequest{ - Namespace: tc.Namespace, - ResourceType: types.KindDatabaseServer, - PredicateExpression: fmt.Sprintf(`name == "%s"`, d.ServiceName), - }) - return trace.Wrap(err) - }) + var databases types.Databases + var err error + name := d.ServiceName + if name != "" { + databases, err = listDatabasesByName(cf.Context, tc, name) + } else { + name = cf.DatabaseService + // search by prefix if the db name comes from cli flag instead of cert. + databases, err = listDatabasesByPrefix(cf.Context, tc, name) + } if err != nil { return nil, trace.Wrap(err) } - if len(databases) == 0 { - return nil, trace.NotFound( - "database %q not found, use '%v' to see registered databases", - d.ServiceName, formatDatabaseListCommand(cf.SiteName)) + if len(databases) != 1 { + // error - we need exactly one database. + selectors := resourceSelectors{ + kind: "database", + name: name, + labels: cf.Labels, + query: cf.PredicateExpression, + } + if len(databases) == 0 { + return nil, trace.NotFound( + "%v not found, use '%v' to see registered databases", selectors, + formatDatabaseListCommand(cf.SiteName)) + } + errMsg := formatAmbiguousDB(cf, selectors, databases) + return nil, trace.BadParameter(errMsg) } + d.database = databases[0] return d.database, nil } -// newDatabaseInfo makes a new databaseInfo from the given route to the db. -// It checks the route and sets defaults as needed for protocol, db user, or db -// name. If the remote database is needed for setting a default, it is retrieved -// by calling ListDatabases API and cached. 
-func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase) (*databaseInfo, error) {
-	dbInfo := databaseInfo{RouteToDatabase: route}
-	if dbInfo.ServiceName == "" {
-		return nil, trace.BadParameter("missing database service name")
-	}
-	if dbInfo.Protocol != "" && dbInfo.Username != "" && dbInfo.Database != "" {
-		return &dbInfo, nil
+// listActiveDatabases lists databases that match active (logged in) databases.
+func listActiveDatabases(ctx context.Context, tc *client.TeleportClient, routes []tlsca.RouteToDatabase) (types.Databases, error) {
+	names := make([]string, 0, len(routes))
+	for _, r := range routes {
+		names = append(names, fmt.Sprintf("(name == %q)", r.ServiceName))
 	}
-	db, err := dbInfo.GetDatabase(cf, tc)
+	predicate := strings.Join(names, "||")
+	return listDatabasesWithPredicate(ctx, tc, predicate)
+}
+
+// listDatabasesByName lists databases that match a given name.
+func listDatabasesByName(ctx context.Context, tc *client.TeleportClient, name string) (types.Databases, error) {
+	predicate := fmt.Sprintf("name == %q", name)
+	return listDatabasesWithPredicate(ctx, tc, predicate)
+}
+
+// listDatabasesByPrefix lists databases that match a given name prefix.
+func listDatabasesByPrefix(ctx context.Context, tc *client.TeleportClient, prefix string) (types.Databases, error) {
+	predicate := fmt.Sprintf(`hasPrefix(name, "%s")`, prefix)
+	databases, err := listDatabasesWithPredicate(ctx, tc, predicate)
+	if err == nil || !utils.IsPredicateError(err) {
+		return databases, trace.Wrap(err)
+	}
+	// predicate error from using hasPrefix expression.
+	// fallback to listing without the hasPrefix predicate and filtering
+	// on client side for backwards compatibility. 
+ databases, err = listDatabasesWithPredicate(ctx, tc, "") if err != nil { return nil, trace.Wrap(err) } - dbInfo.Protocol = db.GetProtocol() - // If database has admin user defined, we're most likely using automatic - // user provisioning so default to Teleport username unless database - // username was provided explicitly. - if dbInfo.Username == "" && db.GetAdminUser() != "" { - log.Debugf("Defaulting to Teleport username %q as database username.", tc.Username) - dbInfo.Username = tc.Username - } - if dbInfo.Username == "" { - switch dbInfo.Protocol { - // When generating certificate for MongoDB access, database username must - // be encoded into it. This is required to be able to tell which database - // user to authenticate the connection as Elasticsearch needs database username too. - case defaults.ProtocolMongoDB, defaults.ProtocolElasticsearch, defaults.ProtocolOracle, defaults.ProtocolOpenSearch: - return nil, trace.BadParameter("please provide the database user name using the --db-user flag") - case defaults.ProtocolRedis: - // Default to "default" in the same way as Redis does. We need the username to check access on our side. - // ref: https://redis.io/commands/auth - log.Debugf("Defaulting to Redis username %q as database username.", defaults.DefaultRedisUsername) - dbInfo.Username = defaults.DefaultRedisUsername + var out types.Databases + for _, db := range databases { + if strings.HasPrefix(db.GetName(), prefix) { + out = append(out, db) } } - if dbInfo.Database != "" { - switch dbInfo.Protocol { - case defaults.ProtocolDynamoDB: - log.Warnf("Database %v protocol %v does not support --db-name flag, ignoring --db-name=%v", - dbInfo.ServiceName, defaults.ReadableDatabaseProtocol(dbInfo.Protocol), dbInfo.Database) - dbInfo.Database = "" - } - } else { - switch dbInfo.Protocol { - // Always require db-name for Oracle Protocol. 
- case defaults.ProtocolOracle: - return nil, trace.BadParameter("please provide the database name using the --db-name flag") - } + return out, nil +} + +// listDatabasesWithPredicate is a helper func for listing databases using +// a given additional predicate expression. If the teleport client already +// has a predicate expression, the predicates are combined with a logical AND. +func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, predicate string) (types.Databases, error) { + if predicate == "" { + predicate = tc.PredicateExpression + } else if tc.PredicateExpression != "" { + predicate = fmt.Sprintf("(%v) && (%v)", predicate, tc.PredicateExpression) } - return &dbInfo, nil + var databases []types.Database + err := client.RetryWithRelogin(ctx, tc, func() error { + var err error + databases, err = tc.ListDatabases(ctx, &proto.ListResourcesRequest{ + Namespace: tc.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: predicate, + Labels: tc.Labels, + }) + return trace.Wrap(err) + }) + return databases, trace.Wrap(err) } func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase, profile *client.ProfileStatus, requires *dbLocalProxyRequirement) (bool, error) { @@ -1023,47 +1136,109 @@ func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, // // If logged into multiple databases, returns an error unless one specified // explicitly via --db flag. 
-func pickActiveDatabase(cf *CLIConf) (*tlsca.RouteToDatabase, error) {
-	profile, err := cf.ProfileStatus()
+func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, error) {
+	profile, err := tc.ProfileStatus()
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	activeDatabases, err := profile.DatabasesForCluster(cf.SiteName)
+	routes, err := profile.DatabasesForCluster(tc.SiteName)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	if len(activeDatabases) == 0 {
+	if len(routes) == 0 {
 		return nil, trace.NotFound("please login using 'tsh db login' first")
 	}
-	name := cf.DatabaseService
-	if name == "" {
-		if len(activeDatabases) > 1 {
-			var services []string
-			for _, database := range activeDatabases {
-				services = append(services, database.ServiceName)
-			}
-			return nil, trace.BadParameter("Multiple databases are available (%v), please specify one using CLI argument",
-				strings.Join(services, ", "))
+	routes, databases, err := filterActiveDatabases(cf.Context, tc, routes)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	if len(routes) != 1 {
+		// error - we need exactly one route.
+		selectors := resourceSelectors{
+			kind:   "database",
+			name:   cf.DatabaseService,
+			labels: cf.Labels,
+			query:  cf.PredicateExpression,
 		}
-		name = activeDatabases[0].ServiceName
-	}
-	for _, db := range activeDatabases {
-		if db.ServiceName == name {
-			// If database user or name were provided on the CLI,
-			// override the default ones.
-			if cf.DatabaseUser != "" {
-				db.Username = cf.DatabaseUser
+		if len(routes) == 0 {
+			return nil, trace.NotFound("not logged into %v", selectors)
+		}
+		if len(databases) == 0 {
+			// if not already given, try to fetch them so we can print
+			// the full `tsh db ls -v` table of ambiguously matching active DBs. 
+ databases, err = listActiveDatabases(cf.Context, tc, routes) + if err != nil { + return nil, trace.Wrap(err) } - if cf.DatabaseName != "" { - db.Database = cf.DatabaseName + } + errMsg := formatAmbiguousDB(cf, selectors, databases) + return nil, trace.BadParameter(errMsg) + } + + route := &routes[0] + // If database user or name were provided on the CLI, + // override the default ones. + if cf.DatabaseUser != "" { + route.Username = cf.DatabaseUser + } + if cf.DatabaseName != "" { + route.Database = cf.DatabaseName + } + return route, nil +} + +// filterActiveDatabases takes a list of active database routes and returns a +// filtered list and, possibly, their corresponding types.Databases. +// Callers should therefore not assume that the types.Databases are populated. +// Filtering is done by matching on database name, label, and query predicate +// selectors from the Teleport client. +// When only database name is given, filtering is done by name prefix, unless +// an active database name matches exactly, in which case all other active +// databases are filtered out - this is to avoid requiring additional selectors +// when a user gives an exact database name. +func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, types.Databases, error) { + prefix := tc.DatabaseService + if prefix == "" && len(activeRoutes) == 1 { + prefix = activeRoutes[0].ServiceName + } + + haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != "" + var selectedRoutes []tlsca.RouteToDatabase + for _, db := range activeRoutes { + if db.ServiceName == prefix && !haveSelectors { + // short-circuit to select the exact match when we don't have + // label or predicate selectors. 
+ return []tlsca.RouteToDatabase{db}, nil, nil + } + if strings.HasPrefix(db.ServiceName, prefix) { + selectedRoutes = append(selectedRoutes, db) + } + } + if len(selectedRoutes) == 0 || !haveSelectors { + // nothing to filter further, avoid making API call. + return selectedRoutes, nil, nil + } + + // make a ListDatabases API call and match on full database name. + databases, err := listDatabasesByPrefix(ctx, tc, prefix) + if err != nil { + return nil, nil, trace.Wrap(err) + } + selectedRoutes = nil + var activeDBs types.Databases + for _, route := range activeRoutes { + for _, db := range databases { + if db.GetName() == route.ServiceName { + selectedRoutes = append(selectedRoutes, route) + activeDBs = append(activeDBs, db) } - return &db, nil } } - return nil, trace.NotFound("Not logged into database %q", name) + return selectedRoutes, activeDBs, nil } func formatDatabaseListCommand(clusterFlag string) string { @@ -1274,6 +1449,68 @@ func getDbCmdAlternatives(clusterFlag string, route tlsca.RouteToDatabase) []str return alts } +// formatAmbiguousDB is a helper func that formats an ambiguous database error +// message. +func formatAmbiguousDB(cf *CLIConf, selectors resourceSelectors, matchedDBs types.Databases) string { + var activeDBs []tlsca.RouteToDatabase + if profile, err := cf.ProfileStatus(); err == nil { + if dbs, err := profile.DatabasesForCluster(cf.SiteName); err == nil { + activeDBs = dbs + } + } + // Pass a nil access checker to avoid making a proxy roundtrip. + // Access info isn't relevant to an ambiguity error anyway. 
+ var checker services.AccessChecker + var sb strings.Builder + verbose := true + showDatabasesAsText(&sb, cf.SiteName, matchedDBs, activeDBs, checker, verbose) + + listCommand := formatDatabaseListCommand(cf.SiteName) + return formatAmbiguityErrTemplate(cf, selectors, listCommand, sb.String()) +} + +// resourceSelectors is a helper struct for gathering up the selectors for a +// resource, as an aggregate of name, labels, and predicate query. +type resourceSelectors struct { + kind, + name, + labels, + query string +} + +// String returns the resource selectors as a formatted string. +// Example: +// command: `tsh db connect foo --labels k1=v1 --query 'labels["k2"]=="v2"'` +// output: database "foo" with labels "k1=v1" with query (labels["k2"]=="v2") +func (r resourceSelectors) String() string { + out := r.kind + if r.name != "" { + out = fmt.Sprintf("%s %q", out, r.name) + } + if len(r.labels) > 0 { + out = fmt.Sprintf("%s with labels %q", out, r.labels) + } + if len(r.query) > 0 { + out = fmt.Sprintf("%s with query (%s)", out, r.query) + } + return strings.TrimSpace(out) +} + +// formatAmbiguityErrTemplate is a helper func that formats an ambiguous +// resource error message. +func formatAmbiguityErrTemplate(cf *CLIConf, selectors resourceSelectors, listCommand, matchTable string) string { + data := map[string]any{ + "command": cf.CommandWithBinary(), + "selectors": strings.TrimSpace(selectors.String()), + "listCommand": strings.TrimSpace(listCommand), + "kind": strings.TrimSpace(selectors.kind), + "matchTable": strings.TrimSpace(matchTable), + } + var sb strings.Builder + _ = ambiguityErrTemplate.Execute(&sb, data) + return sb.String() +} + const ( // dbFormatText prints database configuration in text format. 
dbFormatText = "text" @@ -1301,9 +1538,7 @@ Please use one of the following commands to connect to the database: {{- range .alternatives}} {{.}}{{end -}} {{- end}}`)) -) -var ( // dbConnectTemplate is the message printed after a successful "tsh db login" on how to connect. dbConnectTemplate = template.Must(template.New("").Parse(`Connection information for database "{{ .name }}" has been saved. @@ -1328,5 +1563,16 @@ You can start a local proxy for database GUI clients: {{ .proxyCommand }} {{end -}} +`)) + + // ambiguityErrTemplate is the error message printed when a resource is + // specified ambiguously by name prefix and/or labels. + ambiguityErrTemplate = template.Must(template.New("").Parse("{{ .selectors }} matches multiple {{ .kind }}s:" + ` + +{{ .matchTable }} + +Hint: use '{{ .listCommand }} -v' or '{{ .listCommand }} --format=[json|yaml]' to list all {{ .kind }}s with full details. +Hint: try selecting the {{ .kind }} with a more specific name (ex: {{ .command }} full-{{ .kind }}-name). +Hint: try selecting the {{ .kind }} with additional --labels or --query predicate. 
`)) ) diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index d6e03ef8df04e..36b446b318ffa 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -23,21 +23,20 @@ import ( "crypto/rsa" "encoding/pem" "fmt" - "net" "os" "path/filepath" + "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" @@ -46,89 +45,128 @@ import ( "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/teleport/testenv" ) -// TestDatabaseLogin tests "tsh db login" command and verifies "tsh db -// env/config" after login. -func TestDatabaseLogin(t *testing.T) { - tmpHomePath := t.TempDir() - - connector := mockConnector(t) +func TestTshDB(t *testing.T) { + // this speeds up test suite setup substantially, which is where + // tests spend the majority of their time, especially when leaf + // clusters are setup. + testenv.WithResyncInterval(t, 0) + // Proxy uses self-signed certificates in tests. + testenv.WithInsecureDevMode(t, true) + t.Run("Login", testDatabaseLogin) + t.Run("List", testListDatabase) + t.Run("FilterActiveDatabases", testFilterActiveDatabases) +} +// testDatabaseLogin tests "tsh db login" command and verifies "tsh db +// env/config" after login. 
+func testDatabaseLogin(t *testing.T) { + t.Parallel() alice, err := types.NewUser("alice@example.com") require.NoError(t, err) + alice.SetDatabaseUsers([]string{types.Wildcard}) + alice.SetDatabaseNames([]string{types.Wildcard}) alice.SetRoles([]string{"access"}) - - authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice), - withAuthConfig(func(cfg *servicecfg.AuthConfig) { - cfg.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - }), - withConfig(func(cfg *servicecfg.Config) { + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, alice) + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) // separate MySQL port with TLS routing. - cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} - })) - makeTestDatabaseServer(t, authProcess, proxyProcess, - servicecfg.Database{ - Name: "postgres", - Protocol: defaults.ProtocolPostgres, - URI: "localhost:5432", - }, servicecfg.Database{ - Name: "mysql", - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", - }, servicecfg.Database{ - Name: "cassandra", - Protocol: defaults.ProtocolCassandra, - URI: "localhost:9042", - }, servicecfg.Database{ - Name: "snowflake", - Protocol: defaults.ProtocolSnowflake, - URI: "localhost.snowflakecomputing.com", - }, servicecfg.Database{ - Name: "mongo", - Protocol: defaults.ProtocolMongoDB, - URI: "localhost:27017", - }, servicecfg.Database{ - Name: "mssql", - Protocol: defaults.ProtocolSQLServer, - URI: "localhost:1433", - }, servicecfg.Database{ - Name: "dynamodb", - Protocol: defaults.ProtocolDynamoDB, - URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. 
- AWS: servicecfg.DatabaseAWS{ - AccountID: "123456789012", - ExternalID: "123123123", - Region: "us-west-1", - }, - }) - - authServer := authProcess.GetAuthServer() - require.NotNil(t, authServer) - - proxyAddr, err := proxyProcess.ProxyWebAddr() - require.NoError(t, err) + // set the public address to be sure even on v2+, tsh clients will see the separate port. + mySQLAddr := localListenerAddr() + cfg.Proxy.MySQLAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: mySQLAddr} + cfg.Proxy.MySQLPublicAddrs = []utils.NetAddr{{AddrNetwork: "tcp", Addr: mySQLAddr}} + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + { + Name: "postgres-local", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + "env": "local", + }, + }, { + Name: "postgres-rds-us-west-1-123456789012", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: "postgres", + "region": "us-west-1", + "env": "prod", + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-1", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "postgres", + }, + }, + }, { + Name: "postgres-rds-us-west-2-123456789012", + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: "postgres", + "region": "us-west-2", + "env": "prod", + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-2", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "postgres", + }, + }, + }, { + Name: "mysql", + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + }, { + Name: "cassandra", + Protocol: defaults.ProtocolCassandra, + URI: "localhost:9042", + }, { + Name: "snowflake", + Protocol: defaults.ProtocolSnowflake, + URI: "localhost.snowflakecomputing.com", + }, { + Name: "mongo", + Protocol: defaults.ProtocolMongoDB, + URI: "localhost:27017", + }, { + Name: "mssql", + Protocol: 
defaults.ProtocolSQLServer, + URI: "localhost:1433", + }, { + Name: "dynamodb", + Protocol: defaults.ProtocolDynamoDB, + URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + ExternalID: "123123123", + Region: "us-west-1", + }, + }} + }), + ) + s.user = alice // Log into Teleport cluster. - err = Run(context.Background(), []string{ - "login", "--insecure", "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { - cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) - return nil - })) - require.NoError(t, err) + tmpHomePath, _ := mustLogin(t, s) testCases := []struct { - databaseName string + // databaseName should be the full database name. + databaseName string + // dbSelectors can be any of db name, --labels, --query predicate, + // and defaults to be databaseName if not set. + dbSelectors []string expectCertsLen int expectKeysLen int expectErrForConfigCmd bool expectErrForEnvCmd bool }{ - { - databaseName: "postgres", - expectCertsLen: 1, - }, { databaseName: "mongo", expectCertsLen: 1, @@ -165,6 +203,24 @@ func TestDatabaseLogin(t *testing.T) { expectErrForConfigCmd: true, // "tsh db config" not supported for DynamoDB. expectErrForEnvCmd: true, // "tsh db env" not supported for DynamoDB. }, + { + databaseName: "postgres-local", + // select by labels alone. + dbSelectors: []string{"--labels", "env=local"}, + expectCertsLen: 1, + }, + { + databaseName: "postgres-rds-us-west-1-123456789012", + // select by query alone. + dbSelectors: []string{"--query", `labels.env=="prod" && labels.region == "us-west-1"`}, + expectCertsLen: 1, + }, + { + databaseName: "postgres-rds-us-west-2-123456789012", + // select by uniquely identifying prefix. 
+ dbSelectors: []string{"postgres-rds-us-west-2"}, + expectCertsLen: 1, + }, } // Note: keystore currently races when multiple tsh clients work in the @@ -175,50 +231,83 @@ func TestDatabaseLogin(t *testing.T) { // Copying the profile dir is faster than sequential login for each database. for _, test := range testCases { test := test - t.Run(fmt.Sprintf("%v/%v", "tsh db login", test.databaseName), func(t *testing.T) { + t.Run(test.databaseName, func(t *testing.T) { t.Parallel() tmpHomePath := mustCloneTempDir(t, tmpHomePath) - err := Run(context.Background(), []string{ - "db", "login", "--db-user", "admin", test.databaseName, - }, setHomePath(tmpHomePath)) + selectors := test.dbSelectors + if len(selectors) == 0 { + selectors = []string{test.databaseName} + } + + // override the mysql/postgres config file paths to avoid parallel + // updates to the default location in the user home dir. + mySqlCnfPath := filepath.Join(tmpHomePath, ".my.cnf") + pgCnfPath := filepath.Join(tmpHomePath, ".pg_service.conf") + // all subsequent tsh commands need these options. + cliOpts := []cliOption{ + // set .tsh location to the temp dir for this test. + setHomePath(tmpHomePath), + setOverrideMySQLConfigPath(mySqlCnfPath), + setOverridePostgresConfigPath(pgCnfPath), + } + args := append([]string{ + "db", "login", "--db-user", "admin", + }, selectors...) + err := Run(context.Background(), args, cliOpts...) require.NoError(t, err) // Fetch the active profile. clientStore := client.NewFSClientStore(tmpHomePath) - profile, err := clientStore.ReadProfileStatus(proxyAddr.Host()) + profile, err := clientStore.ReadProfileStatus(s.root.Config.Proxy.WebAddr.String()) require.NoError(t, err) - require.Equal(t, alice.GetName(), profile.Username) + require.Equal(t, s.user.GetName(), profile.Username) // Verify certificates. + // grab the certs using the actual database name to verify certs. 
certs, keys, err := decodePEM(profile.DatabaseCertPathForCluster("", test.databaseName)) require.NoError(t, err) require.Equal(t, test.expectCertsLen, len(certs)) // don't use require.Len, because it spams PEM bytes on fail. require.Equal(t, test.expectKeysLen, len(keys)) // don't use require.Len, because it spams PEM bytes on fail. - t.Run(fmt.Sprintf("%v/%v", "tsh db config", test.databaseName), func(t *testing.T) { - t.Parallel() - err := Run(context.Background(), []string{ - "db", "config", test.databaseName, - }, setHomePath(tmpHomePath)) - - if test.expectErrForConfigCmd { - require.Error(t, err) - } else { - require.NoError(t, err) - } + t.Run("print info", func(t *testing.T) { + // organize these as parallel subtests in a group, so we can run + // them in parallel together before the logout test runs below. + t.Run("config", func(t *testing.T) { + t.Parallel() + args := append([]string{ + "db", "config", + }, selectors...) + err := Run(context.Background(), args, cliOpts...) + + if test.expectErrForConfigCmd { + require.Error(t, err) + require.NotContains(t, err.Error(), "matches multiple", "should not be ambiguity error") + } else { + require.NoError(t, err) + } + }) + t.Run("env", func(t *testing.T) { + t.Parallel() + args := append([]string{ + "db", "env", + }, selectors...) + err := Run(context.Background(), args, cliOpts...) + + if test.expectErrForEnvCmd { + require.Error(t, err) + require.NotContains(t, err.Error(), "matches multiple", "should not be ambiguity error") + } else { + require.NoError(t, err) + } + }) }) - t.Run(fmt.Sprintf("%v/%v", "tsh db env", test.databaseName), func(t *testing.T) { - t.Parallel() - err := Run(context.Background(), []string{ - "db", "env", test.databaseName, - }, setHomePath(tmpHomePath)) - - if test.expectErrForEnvCmd { - require.Error(t, err) - } else { - require.NoError(t, err) - } + t.Run("logout", func(t *testing.T) { + args := append([]string{ + "db", "logout", + }, selectors...) 
+ err := Run(context.Background(), args, cliOpts...) + require.NoError(t, err) }) }) } @@ -334,22 +423,34 @@ func TestLocalProxyRequirement(t *testing.T) { } } -func TestListDatabase(t *testing.T) { - lib.SetInsecureDevMode(true) - defer lib.SetInsecureDevMode(false) - +func testListDatabase(t *testing.T) { + t.Parallel() + discoveredName := "root-postgres" + fullName := "root-postgres-rds-us-west-1-123456789012" s := newTestSuite(t, withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.StorageConfig.Params["poll_stream_period"] = 50 * time.Millisecond cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{{ - Name: "root-postgres", + Name: fullName, Protocol: defaults.ProtocolPostgres, URI: "localhost:5432", + StaticLabels: map[string]string{ + types.DiscoveredNameLabel: discoveredName, + }, + AWS: servicecfg.DatabaseAWS{ + AccountID: "123456789012", + Region: "us-west-1", + RDS: servicecfg.DatabaseAWSRDS{ + InstanceID: "root-postgres", + }, + }, }} }), withLeafCluster(), withLeafConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.StorageConfig.Params["poll_stream_period"] = 50 * time.Millisecond cfg.Databases.Enabled = true cfg.Databases.Databases = []servicecfg.Database{{ Name: "leaf-postgres", @@ -359,7 +460,7 @@ func TestListDatabase(t *testing.T) { }), ) - mustLoginSetEnv(t, s) + tshHome, _ := mustLogin(t, s) captureStdout := new(bytes.Buffer) err := Run(context.Background(), []string{ @@ -367,9 +468,34 @@ func TestListDatabase(t *testing.T) { "ls", "--insecure", "--debug", - }, setCopyStdout(captureStdout)) + }, setCopyStdout(captureStdout), setHomePath(tshHome)) + + require.NoError(t, err) + lines := strings.Split(captureStdout.String(), "\n") + require.Greater(t, len(lines), 2, + "there should be two lines of header followed by data rows") + require.True(t, + strings.HasPrefix(lines[2], discoveredName), + "non-verbose listing should print the 
discovered db name") + require.False(t, + strings.HasPrefix(lines[2], fullName), + "non-verbose listing should not print full db name") + + captureStdout.Reset() + err = Run(context.Background(), []string{ + "db", + "ls", + "--verbose", + "--insecure", + "--debug", + }, setCopyStdout(captureStdout), setHomePath(tshHome)) require.NoError(t, err) - require.Contains(t, captureStdout.String(), "root-postgres") + lines = strings.Split(captureStdout.String(), "\n") + require.Greater(t, len(lines), 2, + "there should be two lines of header followed by data rows") + require.True(t, + strings.HasPrefix(lines[2], fullName), + "verbose listing should print full db name") captureStdout.Reset() err = Run(context.Background(), []string{ @@ -379,7 +505,8 @@ func TestListDatabase(t *testing.T) { "leaf1", "--insecure", "--debug", - }, setCopyStdout(captureStdout)) + }, setCopyStdout(captureStdout), setHomePath(tshHome)) + require.NoError(t, err) require.Contains(t, captureStdout.String(), "leaf-postgres") } @@ -524,38 +651,6 @@ func TestDBInfoHasChanged(t *testing.T) { } } -func makeTestDatabaseServer(t *testing.T, auth *service.TeleportProcess, proxy *service.TeleportProcess, dbs ...servicecfg.Database) (db *service.TeleportProcess) { - // Proxy uses self-signed certificates in tests. - lib.SetInsecureDevMode(true) - - cfg := servicecfg.MakeDefaultConfig() - cfg.Hostname = "localhost" - cfg.DataDir = t.TempDir() - cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() - - proxyAddr, err := proxy.ProxyWebAddr() - require.NoError(t, err) - - cfg.SetAuthServerAddress(*proxyAddr) - - token, err := proxy.Config.Token() - require.NoError(t, err) - - cfg.SetToken(token) - cfg.SSH.Enabled = false - cfg.Auth.Enabled = false - cfg.Proxy.Enabled = false - cfg.Databases.Enabled = true - cfg.Databases.Databases = dbs - cfg.Log = utils.NewLoggerForTests() - - db = runTeleport(t, cfg) - - // Wait for all databases to register to avoid races. 
- waitForDatabases(t, auth, dbs) - return db -} - func waitForDatabases(t *testing.T, auth *service.TeleportProcess, dbs []servicecfg.Database) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -678,3 +773,184 @@ func TestFormatDatabaseConnectArgs(t *testing.T) { }) } } + +func testFilterActiveDatabases(t *testing.T) { + t.Parallel() + // setup some databases and "active" routes to test filtering + db1, route1 := makeDBConfigAndRoute("foobar", map[string]string{"env": "dev", "svc": "fooer"}) + db1AWS, route1AWS := makeDBConfigAndRoute("foobar-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) + db1Azure, route1Azure := makeDBConfigAndRoute("foobar-westus-11111", map[string]string{"env": "prod", "region": "westus"}) + db2, route2 := makeDBConfigAndRoute("bazqux", map[string]string{"env": "dev", "svc": "bazzer"}) + db2AWS, route2AWS := makeDBConfigAndRoute("bazqux-us-west-1-123456789012", map[string]string{"env": "prod", "region": "us-west-1"}) + db3, route3 := makeDBConfigAndRoute("some-unique-name", map[string]string{"env": "dev"}) + routes := []tlsca.RouteToDatabase{route1, route1AWS, route1Azure, route2, route2AWS, route3} + s := newTestSuite(t, + withRootConfigFunc(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + db1, db1AWS, db1Azure, db2, db2AWS, db3, + } + }), + ) + + // Log into Teleport cluster. 
+ tmpHomePath, _ := mustLogin(t, s) + + tests := []struct { + name, + dbName, + labels, + query string + wantAPICall bool + wantRoutes []tlsca.RouteToDatabase + }{ + { + name: "by exact name", + dbName: route1.ServiceName, + wantAPICall: false, + wantRoutes: []tlsca.RouteToDatabase{route1}, + }, + { + name: "by name prefix", + dbName: "foo", + wantAPICall: false, + wantRoutes: []tlsca.RouteToDatabase{route1, route1AWS, route1Azure}, + }, + { + name: "by labels", + labels: "env=dev", + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1, route2, route3}, + }, + { + name: "by query", + query: `labels.env == "dev"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1, route2, route3}, + }, + { + name: "by name prefix and labels", + dbName: "foo", + labels: "env=prod", + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1AWS, route1Azure}, + }, + { + name: "by name prefix and query", + dbName: "foo", + query: `labels.region == "us-west-1"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1AWS}, + }, + { + name: "by labels and query", + labels: "env=dev", + query: `hasPrefix(name, "some-uniq")`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route3}, + }, + { + name: "by name prefix and labels and query", + dbName: "foo", + labels: "env=prod", + query: `labels.region == "westus"`, + wantAPICall: true, + wantRoutes: []tlsca.RouteToDatabase{route1Azure}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + cf := &CLIConf{ + Context: ctx, + HomePath: tmpHomePath, + DatabaseService: tt.dbName, + Labels: tt.labels, + PredicateExpression: tt.query, + } + tc, err := makeClient(cf) + require.NoError(t, err) + routes, dbs, err := filterActiveDatabases(ctx, tc, routes) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tt.wantRoutes, routes)) + if tt.wantAPICall { + require.Equal(t, 
len(routes), len(dbs), + "returned routes should have corresponding types.Databases") + return + } + require.Zero(t, len(dbs), "unexpected API call to ListDatabases") + }) + } +} + +func TestResourceSelectorsFormatting(t *testing.T) { + tests := []struct { + testName string + selectors resourceSelectors + want string + }{ + { + testName: "no selectors", + selectors: resourceSelectors{ + kind: "database", + }, + want: "database", + }, + { + testName: "by name", + selectors: resourceSelectors{ + kind: "database", + name: "foo", + }, + want: `database "foo"`, + }, + { + testName: "by labels", + selectors: resourceSelectors{ + kind: "database", + labels: "env=dev,region=us-west-1", + }, + want: `database with labels "env=dev,region=us-west-1"`, + }, + { + testName: "by predicate", + selectors: resourceSelectors{ + kind: "database", + query: `labels["env"]=="dev" && labels.region == "us-west-1"`, + }, + want: `database with query (labels["env"]=="dev" && labels.region == "us-west-1")`, + }, + { + testName: "by name and labels and predicate", + selectors: resourceSelectors{ + kind: "app", + name: "foo", + labels: "env=dev,region=us-west-1", + query: `labels["env"]=="dev" && labels.region == "us-west-1"`, + }, + want: `app "foo" with labels "env=dev,region=us-west-1" with query (labels["env"]=="dev" && labels.region == "us-west-1")`, + }, + } + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + require.Equal(t, tt.want, fmt.Sprintf("%v", tt.selectors)) + }) + } +} + +// makeDBConfigAndRoute is a helper func that makes a db config and +// corresponding cert encoded route to that db - protocol etc not important. 
+func makeDBConfigAndRoute(name string, staticLabels map[string]string) (servicecfg.Database, tlsca.RouteToDatabase) { + db := servicecfg.Database{ + Name: name, + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + StaticLabels: staticLabels, + } + route := tlsca.RouteToDatabase{ServiceName: name} + return db, route +} diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index 6f5932beca8c0..0dfa2e11eb50b 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -968,7 +968,7 @@ func formatKubeLabels(cluster types.KubeCluster) string { func (c *kubeLSCommand) run(cf *CLIConf) error { cf.SearchKeywords = c.searchKeywords - cf.UserHost = c.labels + cf.Labels = c.labels cf.PredicateExpression = c.predicateExpr cf.SiteName = c.siteName diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 3ea77e9399ee8..a5fe2a676cfb0 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -271,6 +271,11 @@ type CLIConf struct { // PredicateExpression defines boolean conditions that will be matched against the resource. PredicateExpression string + // Labels is used to hold labels passed via --labels=k1=v2,k2=v2,,, flag for resource filtering. + // explicitly passed --labels overrides user@labels positional arg form. + // NOTE: no command currently supports both, try to keep it that way. + Labels string + // NoRemoteExec will not execute a remote command after connecting to a host, // will block instead. Useful when port forwarding. Equivalent of -N for OpenSSH. NoRemoteExec bool @@ -335,6 +340,16 @@ type CLIConf struct { // mockHeadlessLogin used in tests to override Headless login handler in teleport client. mockHeadlessLogin client.SSHLoginFunc + // overrideMySQLOptionFilePath overrides the MySQL option file path to use. + // Useful in parallel tests so they don't all use the default path in the + // user home dir. + overrideMySQLOptionFilePath string + + // overridePostgresServiceFilePath overrides the Postgres service file path. 
+ // Useful in parallel tests so they don't all use the default path in the + // user home dir. + overridePostgresServiceFilePath string + // HomePath is where tsh stores profiles HomePath string @@ -722,7 +737,7 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { lsApps.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) lsApps.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) lsApps.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) - lsApps.Arg("labels", labelHelp).StringVar(&cf.UserHost) + lsApps.Arg("labels", labelHelp).StringVar(&cf.Labels) lsApps.Flag("all", "List apps from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) appLogin := apps.Command("login", "Retrieve short-lived certificate for an app.") appLogin.Arg("app", "App name to retrieve credentials for. Can be obtained from `tsh apps ls` output.").Required().StringVar(&cf.AppName) @@ -757,7 +772,8 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { proxySSH.Arg("[user@]host", "Remote hostname and the login to use").Required().StringVar(&cf.UserHost) proxySSH.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) proxyDB := proxy.Command("db", "Start local TLS proxy for database connections when using Teleport in single-port mode.") - proxyDB.Arg("db", "The name of the database to start local proxy for").Required().StringVar(&cf.DatabaseService) + // don't require positional argument, user can select with --labels/--query alone. + proxyDB.Arg("db", "The name of the database to start local proxy for").StringVar(&cf.DatabaseService) proxyDB.Flag("port", "Specifies the source port used by proxy db listener").Short('p').StringVar(&cf.LocalProxyPort) // --cert-file and --key-file are deprecated in favor of --tunnel flag. 
proxyDB.Flag("cert-file", "Certificate file for proxy client TLS configuration").Hidden().StringVar(&cf.LocalProxyCertFile) @@ -766,6 +782,8 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { proxyDB.Flag("db-user", "Optional database user to log in as.").StringVar(&cf.DatabaseUser) proxyDB.Flag("db-name", "Optional database name to log in to.").StringVar(&cf.DatabaseName) proxyDB.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) + proxyDB.Flag("labels", labelHelp).StringVar(&cf.Labels) + proxyDB.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) proxyApp := proxy.Command("app", "Start local TLS proxy for app connection when using Teleport in single-port mode.") proxyApp.Arg("app", "The name of the application to start local proxy for").Required().StringVar(&cf.AppName) @@ -800,20 +818,29 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { dbList.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbList.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) dbList.Flag("all", "List databases from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) - dbList.Arg("labels", labelHelp).StringVar(&cf.UserHost) + dbList.Arg("labels", labelHelp).StringVar(&cf.Labels) dbLogin := db.Command("login", "Retrieve credentials for a database.") - dbLogin.Arg("db", "Database to retrieve credentials for. Can be obtained from 'tsh db ls' output.").Required().StringVar(&cf.DatabaseService) + // don't require positional argument, user can select with --labels/--query alone. + dbLogin.Arg("db", "Database to retrieve credentials for. 
Can be obtained from 'tsh db ls' output.").StringVar(&cf.DatabaseService) + dbLogin.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbLogin.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbLogin.Flag("db-user", "Optional database user to configure as default.").StringVar(&cf.DatabaseUser) dbLogin.Flag("db-name", "Optional database name to configure as default.").StringVar(&cf.DatabaseName) dbLogout := db.Command("logout", "Remove database credentials.") dbLogout.Arg("db", "Database to remove credentials for.").StringVar(&cf.DatabaseService) + dbLogout.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbLogout.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbEnv := db.Command("env", "Print environment variables for the configured database.") - dbEnv.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) dbEnv.Arg("db", "Print environment for the specified database").StringVar(&cf.DatabaseService) + dbEnv.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) + dbEnv.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbEnv.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // --db flag is deprecated in favor of positional argument for consistency with other commands. dbEnv.Flag("db", "Print environment for the specified database.").Hidden().StringVar(&cf.DatabaseService) dbConfig := db.Command("config", "Print database connection information. Useful when configuring GUI clients.") dbConfig.Arg("db", "Print information for the specified database.").StringVar(&cf.DatabaseService) + dbConfig.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbConfig.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // --db flag is deprecated in favor of positional argument for consistency with other commands. 
dbConfig.Flag("db", "Print information for the specified database.").Hidden().StringVar(&cf.DatabaseService) dbConfig.Flag("format", fmt.Sprintf("Print format: %q to print in table format (default), %q to print connect command, %q or %q to print in JSON or YAML.", @@ -822,6 +849,8 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { dbConnect.Arg("db", "Database service name to connect to.").StringVar(&cf.DatabaseService) dbConnect.Flag("db-user", "Optional database user to log in as.").StringVar(&cf.DatabaseUser) dbConnect.Flag("db-name", "Optional database name to log in to.").StringVar(&cf.DatabaseName) + dbConnect.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) // join join := app.Command("join", "Join the active SSH or Kubernetes session.") @@ -853,7 +882,7 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { ls.Flag("format", defaults.FormatFlagDescription( teleport.Text, teleport.JSON, teleport.YAML, teleport.Names, )).Short('f').Default(teleport.Text).EnumVar(&cf.Format, teleport.Text, teleport.JSON, teleport.YAML, teleport.Names) - ls.Arg("labels", labelHelp).StringVar(&cf.UserHost) + ls.Arg("labels", labelHelp).StringVar(&cf.Labels) ls.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) ls.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) ls.Flag("all", "List nodes from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll) @@ -968,7 +997,7 @@ func Run(ctx context.Context, args []string, opts ...cliOption) error { ).Required().EnumVar(&cf.ResourceKind, types.RequestableResourceKinds...) 
reqSearch.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) reqSearch.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) - reqSearch.Flag("labels", labelHelp).StringVar(&cf.UserHost) + reqSearch.Flag("labels", labelHelp).StringVar(&cf.Labels) reqSearch.Flag("kube-cluster", "Kubernetes Cluster to search for Pods").StringVar(&cf.KubernetesCluster) reqSearch.Flag("kube-namespace", "Kubernetes Namespace to search for Pods").Default(corev1.NamespaceDefault).StringVar(&cf.kubeNamespace) reqSearch.Flag("all-kube-namespaces", "Search Pods in every namespace").BoolVar(&cf.kubeAllNamespaces) @@ -2758,12 +2787,25 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec return fmt.Sprintf("%v, except: %v", dbUsers.Allowed, dbUsers.Denied) } +func getDiscoveredName(r types.ResourceWithLabels) (string, bool) { + name, ok := r.GetAllLabels()[types.DiscoveredNameLabel] + return name, ok +} + func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, active []tlsca.RouteToDatabase, accessChecker services.AccessChecker, verbose bool) []string { name := database.GetName() + printName := name + if d, ok := getDiscoveredName(database); ok && !verbose && d != name { + printName = d + } var connect string for _, a := range active { if a.ServiceName == name { - name = formatActiveDB(a) + a.ServiceName = printName + // format the db name with the print name + printName = formatActiveDB(a) + // then revert it for connect string + a.ServiceName = name switch a.Protocol { case defaults.ProtocolDynamoDB: // DynamoDB does not support "tsh db connect", so print the proxy command instead. 
@@ -2771,6 +2813,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, default: connect = formatDatabaseConnectCommand(clusterFlag, a) } + break } } @@ -2781,7 +2824,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, if verbose { row = append(row, - name, + printName, database.GetDescription(), database.GetProtocol(), database.GetType(), @@ -2792,7 +2835,7 @@ func getDatabaseRow(proxy, cluster, clusterFlag string, database types.Database, ) } else { row = append(row, - name, + printName, database.GetDescription(), formatUsersForDB(database, accessChecker), formatDatabaseLabels(database), @@ -2849,8 +2892,9 @@ func printDatabasesWithClusters(clusterFlag string, dbListings []databaseListing func formatDatabaseLabels(database types.Database) string { labels := database.GetAllLabels() - // Hide the origin label unless printing verbose table. + // Hide the origin and discovered-name labels unless printing verbose table. delete(labels, types.OriginLabel) + delete(labels, types.DiscoveredNameLabel) return sortedLabels(labels) } @@ -3369,6 +3413,15 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err } } } + + // explicitly passed --labels overrides user@labels positional arg form. + if cf.Labels != "" { + labels, err = client.ParseLabelSpec(cf.Labels) + if err != nil { + return nil, trace.Wrap(err) + } + } + fPorts, err := client.ParsePortForwardSpec(cf.LocalForwardPorts) if err != nil { return nil, trace.Wrap(err) @@ -3618,6 +3671,10 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err c.MockSSOLogin = cf.mockSSOLogin c.MockHeadlessLogin = cf.mockHeadlessLogin + // pass along MySQL/Postgres path overrides (only used in tests). 
+ c.OverrideMySQLOptionFilePath = cf.overrideMySQLOptionFilePath + c.OverridePostgresServiceFilePath = cf.overridePostgresServiceFilePath + // Set tsh home directory c.HomePath = cf.HomePath diff --git a/tool/tsh/tsh_helper_test.go b/tool/tsh/tsh_helper_test.go index b42717e0b61fd..6c21cf7392657 100644 --- a/tool/tsh/tsh_helper_test.go +++ b/tool/tsh/tsh_helper_test.go @@ -91,12 +91,13 @@ func (s *suite) setupRootCluster(t *testing.T, options testSuiteOptions) { cfg.Proxy.DisableWebInterface = true cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ StaticTokens: []types.ProvisionTokenV1{{ - Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleNode, types.RoleTrustedCluster}, + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp}, Expires: time.Now().Add(time.Minute), Token: staticToken, }}, }) require.NoError(t, err) + cfg.SetToken(staticToken) user, err := user.Current() require.NoError(t, err) @@ -134,7 +135,6 @@ func (s *suite) setupRootCluster(t *testing.T, options testSuiteOptions) { } s.root = runTeleport(t, cfg) - t.Cleanup(func() { require.NoError(t, s.root.Close()) }) } func (s *suite) setupLeafCluster(t *testing.T, options testSuiteOptions) { @@ -182,6 +182,15 @@ func (s *suite) setupLeafCluster(t *testing.T, options testSuiteOptions) { require.NoError(t, err) cfg.Proxy.DisableWebInterface = true + cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ + StaticTokens: []types.ProvisionTokenV1{{ + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp}, + Expires: time.Now().Add(time.Minute), + Token: staticToken, + }}, + }) + require.NoError(t, err) + cfg.SetToken(staticToken) sshLoginRole, err := types.NewRole("ssh-login", types.RoleSpecV6{ Allow: types.RoleConditions{ Logins: []string{user.Username}, diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go 
index d2de83c406058..53cbb817a7c63 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -3367,6 +3367,20 @@ func setHomePath(path string) cliOption { } } +func setOverrideMySQLConfigPath(path string) cliOption { + return func(cf *CLIConf) error { + cf.overrideMySQLOptionFilePath = path + return nil + } +} + +func setOverridePostgresConfigPath(path string) cliOption { + return func(cf *CLIConf) error { + cf.overridePostgresServiceFilePath = path + return nil + } +} + func setKubeConfigPath(path string) cliOption { return func(cf *CLIConf) error { cf.kubeConfigPath = path