diff --git a/lib/client/db/dbcmd/dbcmd.go b/lib/client/db/dbcmd/dbcmd.go index 63954d3ff0136..da7d46b339960 100644 --- a/lib/client/db/dbcmd/dbcmd.go +++ b/lib/client/db/dbcmd/dbcmd.go @@ -116,7 +116,7 @@ type CLICommandBuilder struct { } func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, - db *tlsca.RouteToDatabase, rootClusterName string, opts ...ConnectCommandFunc, + db tlsca.RouteToDatabase, rootClusterName string, opts ...ConnectCommandFunc, ) *CLICommandBuilder { var options connectionCommandOpts for _, opt := range opts { @@ -124,7 +124,7 @@ func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, } // In TLS routing mode a local proxy is started on demand so connect to it. - host, port := tc.DatabaseProxyHostPort(*db) + host, port := tc.DatabaseProxyHostPort(db) if options.localProxyPort != 0 && options.localProxyHost != "" { host = options.localProxyHost port = options.localProxyPort @@ -141,7 +141,7 @@ func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, return &CLICommandBuilder{ tc: tc, profile: profile, - db: db, + db: &db, host: host, port: port, options: options, diff --git a/lib/client/db/dbcmd/dbcmd_test.go b/lib/client/db/dbcmd/dbcmd_test.go index 8d7f2c019247f..ab872a7e63108 100644 --- a/lib/client/db/dbcmd/dbcmd_test.go +++ b/lib/client/db/dbcmd/dbcmd_test.go @@ -598,7 +598,7 @@ func TestCLICommandBuilderGetConnectCommand(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - database := &tlsca.RouteToDatabase{ + database := tlsca.RouteToDatabase{ Protocol: tt.dbProtocol, Database: tt.databaseName, Username: "myUser", @@ -761,7 +761,7 @@ func TestCLICommandBuilderGetConnectCommandAlternatives(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - database := &tlsca.RouteToDatabase{ + database := tlsca.RouteToDatabase{ Protocol: tt.dbProtocol, Database: tt.databaseName, Username: "myUser", @@ -848,7 +848,7 @@ func TestConvertCommandError(t *testing.T) { t.Run(tt.desc, func(t *testing.T) { t.Parallel() - database := &tlsca.RouteToDatabase{ + database := tlsca.RouteToDatabase{ Protocol: tt.dbProtocol, Database: "DBName", Username: "myUser", diff --git a/lib/teleterm/clusters/dbcmd_cli_command_provider.go b/lib/teleterm/clusters/dbcmd_cli_command_provider.go index 63d9b29fdddba..c2a4a92b6f1fb 100644 --- a/lib/teleterm/clusters/dbcmd_cli_command_provider.go +++ b/lib/teleterm/clusters/dbcmd_cli_command_provider.go @@ -55,7 +55,7 @@ func (d DbcmdCLICommandProvider) GetCommand(gateway *gateway.Gateway) (*exec.Cmd Database: gateway.TargetSubresourceName(), } - cmd, err := dbcmd.NewCmdBuilder(cluster.clusterClient, &cluster.status, &routeToDb, + cmd, err := dbcmd.NewCmdBuilder(cluster.clusterClient, &cluster.status, routeToDb, // TODO(ravicious): Pass the root cluster name here. cluster.Name returns leaf name for leaf // clusters. // diff --git a/tool/tsh/db.go b/tool/tsh/db.go index fc482c38ed270..33a706b2c100c 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -253,77 +253,45 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := getDatabase(cf, tc, cf.DatabaseService) - if err != nil { - return trace.Wrap(err) - } - route := tlsca.RouteToDatabase{ + dbInfo, err := newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ ServiceName: cf.DatabaseService, - Protocol: database.GetProtocol(), Username: cf.DatabaseUser, Database: cf.DatabaseName, + }) + if err != nil { + return trace.Wrap(err) } - if err := databaseLogin(cf, tc, route); err != nil { + database, err := dbInfo.GetDatabase(cf, tc) + if err != nil { + return trace.Wrap(err) + } + + if err := databaseLogin(cf, tc, dbInfo); err != nil { return trace.Wrap(err) } // Print after-login message. templateData := map[string]string{ - "name": route.ServiceName, + "name": dbInfo.ServiceName, } // DynamoDB does not support a connect command, so don't try to print one. if database.GetProtocol() != defaults.ProtocolDynamoDB { - templateData["connectCommand"] = utils.Color(utils.Yellow, formatDatabaseConnectCommand(cf.SiteName, route)) + templateData["connectCommand"] = utils.Color(utils.Yellow, formatDatabaseConnectCommand(cf.SiteName, dbInfo.RouteToDatabase)) } - requires := getDBLocalProxyRequirement(tc, &route) + requires := getDBLocalProxyRequirement(tc, dbInfo.RouteToDatabase) if requires.localProxy { - templateData["proxyCommand"] = utils.Color(utils.Yellow, formatDatabaseProxyCommand(cf.SiteName, route)) + templateData["proxyCommand"] = utils.Color(utils.Yellow, formatDatabaseProxyCommand(cf.SiteName, dbInfo.RouteToDatabase)) } else { - templateData["configCommand"] = utils.Color(utils.Yellow, formatDatabaseConfigCommand(cf.SiteName, route)) + templateData["configCommand"] = utils.Color(utils.Yellow, formatDatabaseConfigCommand(cf.SiteName, dbInfo.RouteToDatabase)) } return trace.Wrap(dbConnectTemplate.Execute(cf.Stdout(), templateData)) } -// checkAndSetDBRouteDefaults checks the database route and sets defaults for certificate generation. -func checkAndSetDBRouteDefaults(r *tlsca.RouteToDatabase) error { - // When generating certificate for MongoDB access, database username must - // be encoded into it. This is required to be able to tell which database - // user to authenticate the connection as Elasticsearch needs database username too. - if r.Username == "" { - switch r.Protocol { - case defaults.ProtocolMongoDB, defaults.ProtocolElasticsearch, defaults.ProtocolOracle, defaults.ProtocolOpenSearch: - return trace.BadParameter("please provide the database user name using the --db-user flag") - case defaults.ProtocolRedis: - // Default to "default" in the same way as Redis does. We need the username to check access on our side. - // ref: https://redis.io/commands/auth - r.Username = defaults.DefaultRedisUsername - } - } - if r.Database != "" { - switch r.Protocol { - case defaults.ProtocolDynamoDB: - log.Warnf("Database %v protocol %v does not support --db-name flag, ignoring --db-name=%v", - r.ServiceName, defaults.ReadableDatabaseProtocol(r.Protocol), r.Database) - r.Database = "" - } - } else { - switch r.Protocol { - // Always require db-name for Oracle Protocol. - case defaults.ProtocolOracle: - return trace.BadParameter("please provide the database name using the --db-name flag") - } - } - return nil -} - -func databaseLogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase) error { - log.Debugf("Fetching database access certificate for %s on cluster %v.", route, tc.SiteName) - if err := checkAndSetDBRouteDefaults(&route); err != nil { - return trace.Wrap(err) - } +func databaseLogin(cf *CLIConf, tc *client.TeleportClient, dbInfo *databaseInfo) error { + log.Debugf("Fetching database access certificate for %s on cluster %v.", dbInfo.RouteToDatabase, tc.SiteName) profile, err := tc.ProfileStatus() if err != nil { @@ -340,10 +308,10 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDa key, err = tc.IssueUserCertsWithMFA(cf.Context, client.ReissueParams{ RouteToCluster: tc.SiteName, RouteToDatabase: proto.RouteToDatabase{ - ServiceName: route.ServiceName, - Protocol: route.Protocol, - Username: route.Username, - Database: route.Database, + ServiceName: dbInfo.ServiceName, + Protocol: dbInfo.Protocol, + Username: dbInfo.Username, + Database: dbInfo.Database, }, AccessRequests: profile.ActiveRequests.AccessRequests, }, nil /*applyOpts*/) @@ -356,11 +324,11 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDa } } - if route.Protocol == defaults.ProtocolOracle { + if dbInfo.Protocol == defaults.ProtocolOracle { if err := generateDBLocalProxyCert(key, profile); err != nil { return trace.Wrap(err) } - err = oracle.GenerateClientConfiguration(key, route, profile) + err = oracle.GenerateClientConfiguration(key, dbInfo.RouteToDatabase, profile) if err != nil { return trace.Wrap(err) } @@ -372,7 +340,7 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDa return trace.Wrap(err) } // Update the database-specific connection profile file. - err = dbprofile.Add(cf.Context, tc, route, *profile) + err = dbprofile.Add(cf.Context, tc, dbInfo.RouteToDatabase, *profile) return trace.Wrap(err) } @@ -454,11 +422,11 @@ func onDatabaseEnv(cf *CLIConf) error { } if !dbprofile.IsSupported(*database) { - return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, database)) + return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, *database)) } - requires := getDBLocalProxyRequirement(tc, database) + requires := getDBLocalProxyRequirement(tc, *database) if requires.localProxy { - return trace.BadParameter(formatDbCmdUnsupported(cf, database, requires.localProxyReasons...)) + return trace.BadParameter(formatDbCmdUnsupported(cf, *database, requires.localProxyReasons...)) } env, err := dbprofile.Env(tc, *database) @@ -511,12 +479,12 @@ func onDatabaseConfig(cf *CLIConf) error { return trace.Wrap(err) } - requires := getDBLocalProxyRequirement(tc, database) + requires := getDBLocalProxyRequirement(tc, *database) // "tsh db config" prints out instructions for native clients to connect to // the remote proxy directly. Return errors here when direct connection // does NOT work (e.g. when ALPN local proxy is required). if requires.localProxy { - msg := formatDbCmdUnsupported(cf, database, requires.localProxyReasons...) + msg := formatDbCmdUnsupported(cf, *database, requires.localProxyReasons...) return trace.BadParameter(msg) } @@ -529,7 +497,7 @@ func onDatabaseConfig(cf *CLIConf) error { format := strings.ToLower(cf.Format) switch format { case dbFormatCommand: - cmd, err := dbcmd.NewCmdBuilder(tc, profile, database, rootCluster, + cmd, err := dbcmd.NewCmdBuilder(tc, profile, *database, rootCluster, dbcmd.WithPrintFormat(), dbcmd.WithLogger(log), ).GetConnectCommand() @@ -593,7 +561,7 @@ func serializeDatabaseConfig(configInfo *dbConfigInfo, format string) (string, e // command. func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, - route *tlsca.RouteToDatabase, db types.Database, rootClusterName string, + dbInfo *databaseInfo, rootClusterName string, requires *dbLocalProxyRequirement, ) ([]dbcmd.ConnectCommandFunc, error) { if !requires.localProxy { @@ -605,7 +573,7 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, log.Debugf("Starting local proxy because: %v", strings.Join(requires.localProxyReasons, ", ")) } - listener, err := createLocalProxyListener("localhost:0", route, profile) + listener, err := createLocalProxyListener("localhost:0", dbInfo.RouteToDatabase, profile) if err != nil { return nil, trace.Wrap(err) } @@ -614,8 +582,7 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, cf: cf, tc: tc, profile: profile, - route: *route, - database: db, + dbInfo: dbInfo, autoReissueCerts: requires.tunnel, tunnel: requires.tunnel, }) @@ -655,11 +622,10 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, // localProxyConfig is an argument pack used in prepareLocalProxyOptions(). type localProxyConfig struct { - cf *CLIConf - tc *client.TeleportClient - profile *client.ProfileStatus - route tlsca.RouteToDatabase - database types.Database + cf *CLIConf + tc *client.TeleportClient + profile *client.ProfileStatus + dbInfo *databaseInfo // autoReissueCerts indicates whether a cert auto reissuer should be used // for the local proxy to keep certificates valid. // - when `tsh db connect` needs to tunnel it will set this field. @@ -669,7 +635,7 @@ type localProxyConfig struct { tunnel bool } -func createLocalProxyListener(addr string, route *tlsca.RouteToDatabase, profile *client.ProfileStatus) (net.Listener, error) { +func createLocalProxyListener(addr string, route tlsca.RouteToDatabase, profile *client.ProfileStatus) (net.Listener, error) { if route.Protocol == defaults.ProtocolOracle { localCert, err := tls.LoadX509KeyPair( profile.DatabaseLocalCAPath(), @@ -690,22 +656,18 @@ func createLocalProxyListener(addr string, route *tlsca.RouteToDatabase, profile // prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig. func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConfigOpt, error) { - if err := checkAndSetDBRouteDefaults(&arg.route); err != nil { - return nil, trace.Wrap(err) - } - opts := []alpnproxy.LocalProxyConfigOpt{ - alpnproxy.WithDatabaseProtocol(arg.route.Protocol), + alpnproxy.WithDatabaseProtocol(arg.dbInfo.Protocol), alpnproxy.WithClusterCAsIfConnUpgrade(arg.cf.Context, arg.tc.RootClusterCACertPool), } - if !arg.tunnel && arg.route.Protocol == defaults.ProtocolPostgres { + if !arg.tunnel && arg.dbInfo.Protocol == defaults.ProtocolPostgres { opts = append(opts, alpnproxy.WithCheckCertsNeeded()) } // load certs if local proxy needs to be able to tunnel. // certs are needed for non-tunnel postgres cancel requests. - if arg.tunnel || arg.route.Protocol == defaults.ProtocolPostgres { + if arg.tunnel || arg.dbInfo.Protocol == defaults.ProtocolPostgres { certs, err := getDBLocalProxyCerts(arg) if err != nil { return nil, trace.Wrap(err) @@ -714,20 +676,16 @@ func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConf } if arg.autoReissueCerts { - opts = append(opts, alpnproxy.WithMiddleware(client.NewDBCertChecker(arg.tc, arg.route, nil))) + opts = append(opts, alpnproxy.WithMiddleware(client.NewDBCertChecker(arg.tc, arg.dbInfo.RouteToDatabase, nil))) } // To set correct MySQL server version DB proxy needs additional protocol. - if !arg.tunnel && arg.route.Protocol == defaults.ProtocolMySQL { - if arg.database == nil { - var err error - arg.database, err = getDatabase(arg.cf, arg.tc, arg.route.ServiceName) - if err != nil { - return nil, trace.Wrap(err) - } + if !arg.tunnel && arg.dbInfo.Protocol == defaults.ProtocolMySQL { + db, err := arg.dbInfo.GetDatabase(arg.cf, arg.tc) + if err != nil { + return nil, trace.Wrap(err) } - - opts = append(opts, alpnproxy.WithMySQLVersionProto(arg.database)) + opts = append(opts, alpnproxy.WithMySQLVersionProto(db)) } return opts, nil } @@ -740,7 +698,7 @@ func getDBLocalProxyCerts(arg *localProxyConfig) ([]tls.Certificate, error) { return getUserSpecifiedLocalProxyCerts(arg) } // if neither --cert-file nor --key-file are specified, load db cert from client store. - cert, err := loadDBCertificate(arg.tc, arg.route.ServiceName) + cert, err := loadDBCertificate(arg.tc, arg.dbInfo.ServiceName) if err != nil { if arg.autoReissueCerts { // If using a reissuer, just return nil certs and let the reissuer @@ -779,16 +737,16 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - route, database, err := getDatabaseInfo(cf, tc) + dbInfo, err := getDatabaseInfo(cf, tc) if err != nil { return trace.Wrap(err) } - if route.Protocol == defaults.ProtocolDynamoDB { - return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, route)) + if dbInfo.Protocol == defaults.ProtocolDynamoDB { + return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, dbInfo.RouteToDatabase)) } - requires := getDBLocalProxyRequirement(tc, route, withConnectRequirements(cf.Context, tc, route)) - if err := maybeDatabaseLogin(cf, tc, profile, route, requires); err != nil { + requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase) + if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil { return trace.Wrap(err) } @@ -801,17 +759,17 @@ func onDatabaseConnect(cf *CLIConf) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - opts, err := maybeStartLocalProxy(ctx, cf, tc, profile, route, database, rootClusterName, requires) + opts, err := maybeStartLocalProxy(ctx, cf, tc, profile, dbInfo, rootClusterName, requires) if err != nil { return trace.Wrap(err) } opts = append(opts, dbcmd.WithLogger(log)) - if opts, err = maybeAddDBUserPassword(database, opts); err != nil { + if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil { return trace.Wrap(err) } - bb := dbcmd.NewCmdBuilder(tc, profile, route, rootClusterName, opts...) + bb := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootClusterName, opts...) cmd, err := bb.GetConnectCommand() if err != nil { return trace.Wrap(err) @@ -834,59 +792,48 @@ func onDatabaseConnect(cf *CLIConf) error { return nil } -// getDatabaseInfo fetches information about the database from tsh profile is DB is active in profile. Otherwise, -// the ListDatabases endpoint is called. -func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, types.Database, error) { - database, err := pickActiveDatabase(cf) - if err == nil { - switch database.Protocol { - case defaults.ProtocolCassandra: - // Cassandra CLI connection require database resource to determine - // if the target database is AWS hosted in order to skip the password prompt. - default: - return database, nil, nil - } - } - if err != nil && !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) - } - - dbService := cf.DatabaseService - username := cf.DatabaseUser - databaseName := cf.DatabaseName - if database != nil { - if dbService == "" { - dbService = database.ServiceName - } - if username == "" { - username = database.Username - } - if databaseName == "" { - databaseName = database.Database - } - } - - db, err := getDatabase(cf, tc, dbService) - if err != nil { - return nil, nil, trace.Wrap(err) +// getDatabaseInfo fetches information about the database from tsh profile if DB +// is active in profile. Otherwise, the ListDatabases endpoint is called. +func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) { + if route, err := pickActiveDatabase(cf); err == nil { + return newDatabaseInfo(cf, tc, *route) + } else if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) } + return newDatabaseInfo(cf, tc, tlsca.RouteToDatabase{ + ServiceName: cf.DatabaseService, + Username: cf.DatabaseUser, + Database: cf.DatabaseName, + }) +} - return &tlsca.RouteToDatabase{ - ServiceName: db.GetName(), - Protocol: db.GetProtocol(), - Username: username, - Database: databaseName, - }, db, nil +// databaseInfo wraps a RouteToDatabase and the corresponding database. +// Its purpose is to prevent repeated fetches of the same database, by lazily +// fetching and caching the database for use as needed. +type databaseInfo struct { + tlsca.RouteToDatabase + // database corresponds to the db route and may be nil, so use GetDatabase + // instead of accessing it directly. + database types.Database + mu sync.Mutex } -func getDatabase(cf *CLIConf, tc *client.TeleportClient, dbName string) (types.Database, error) { +// GetDatabase returns the cached database or fetches it using the db route and +// caches the result. +func (d *databaseInfo) GetDatabase(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { + d.mu.Lock() + defer d.mu.Unlock() + if d.database != nil { + return d.database, nil + } var databases []types.Database + // holding mutex across the api call to avoid multiple redundant api calls. err := client.RetryWithRelogin(cf.Context, tc, func() error { var err error databases, err = tc.ListDatabases(cf.Context, &proto.ListResourcesRequest{ Namespace: tc.Namespace, ResourceType: types.KindDatabaseServer, - PredicateExpression: fmt.Sprintf(`name == "%s"`, dbName), + PredicateExpression: fmt.Sprintf(`name == "%s"`, d.ServiceName), }) return trace.Wrap(err) }) @@ -895,12 +842,61 @@ func getDatabase(cf *CLIConf, tc *client.TeleportClient, dbName string) (types.D } if len(databases) == 0 { return nil, trace.NotFound( - "database %q not found, use '%v' to see registered databases", dbName, formatDatabaseListCommand(cf.SiteName)) + "database %q not found, use '%v' to see registered databases", + d.ServiceName, formatDatabaseListCommand(cf.SiteName)) + } + d.database = databases[0] + return d.database, nil +} + +// newDatabaseInfo makes a new databaseInfo from the given route to the db. +// It checks the route and sets defaults as needed for protocol, db user, or db +// name. If the remote database is needed for setting a default, it is retrieved +// by calling ListDatabases API and cached. +func newDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase) (*databaseInfo, error) { + dbInfo := databaseInfo{RouteToDatabase: route} + if dbInfo.ServiceName == "" { + return nil, trace.BadParameter("missing database service name") + } + if dbInfo.Protocol == "" { + db, err := dbInfo.GetDatabase(cf, tc) + if err != nil { + return nil, trace.Wrap(err) + } + dbInfo.Protocol = db.GetProtocol() + } + if dbInfo.Username == "" { + switch dbInfo.Protocol { + // When generating certificate for MongoDB access, database username must + // be encoded into it. This is required to be able to tell which database + // user to authenticate the connection as Elasticsearch needs database username too. + case defaults.ProtocolMongoDB, defaults.ProtocolElasticsearch, defaults.ProtocolOracle, defaults.ProtocolOpenSearch: + return nil, trace.BadParameter("please provide the database user name using the --db-user flag") + case defaults.ProtocolRedis: + // Default to "default" in the same way as Redis does. We need the username to check access on our side. + // ref: https://redis.io/commands/auth + log.Debugf("Defaulting to Redis username %q as database username.", defaults.DefaultRedisUsername) + dbInfo.Username = defaults.DefaultRedisUsername + } + } + if dbInfo.Database != "" { + switch dbInfo.Protocol { + case defaults.ProtocolDynamoDB: + log.Warnf("Database %v protocol %v does not support --db-name flag, ignoring --db-name=%v", + dbInfo.ServiceName, defaults.ReadableDatabaseProtocol(dbInfo.Protocol), dbInfo.Database) + dbInfo.Database = "" + } + } else { + switch dbInfo.Protocol { + // Always require db-name for Oracle Protocol. + case defaults.ProtocolOracle: + return nil, trace.BadParameter("please provide the database name using the --db-name flag") + } } - return databases[0], nil + return &dbInfo, nil } -func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, route *tlsca.RouteToDatabase, profile *client.ProfileStatus, requires *dbLocalProxyRequirement) (bool, error) { +func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, route tlsca.RouteToDatabase, profile *client.ProfileStatus, requires *dbLocalProxyRequirement) (bool, error) { if (requires.localProxy && requires.tunnel) || isLocalProxyTunnelRequested(cf) { switch route.Protocol { case defaults.ProtocolOracle: @@ -946,14 +942,14 @@ func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, route *tlsca.Ro // maybeDatabaseLogin checks if cert is still valid. If not valid, trigger db login logic. // returns a true/false indicating whether database login was triggered. -func maybeDatabaseLogin(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, route *tlsca.RouteToDatabase, requires *dbLocalProxyRequirement) error { - reloginNeeded, err := needDatabaseRelogin(cf, tc, route, profile, requires) +func maybeDatabaseLogin(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, dbInfo *databaseInfo, requires *dbLocalProxyRequirement) error { + reloginNeeded, err := needDatabaseRelogin(cf, tc, dbInfo.RouteToDatabase, profile, requires) if err != nil { return trace.Wrap(err) } if reloginNeeded { - return trace.Wrap(databaseLogin(cf, tc, *route)) + return trace.Wrap(databaseLogin(cf, tc, dbInfo)) } return nil } @@ -990,7 +986,7 @@ func dbInfoHasChanged(cf *CLIConf, certPath string) (bool, error) { // isMFADatabaseAccessRequired calls the IsMFARequired endpoint in order to get from user roles if access to the database // requires MFA. -func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, database *tlsca.RouteToDatabase) (bool, error) { +func isMFADatabaseAccessRequired(ctx context.Context, tc *client.TeleportClient, database tlsca.RouteToDatabase) (bool, error) { proxy, err := tc.ConnectToProxy(ctx) if err != nil { return false, trace.Wrap(err) @@ -1146,13 +1142,9 @@ func (r *dbLocalProxyRequirement) addLocalProxyWithTunnel(reasons ...string) { r.tunnelReasons = append(r.tunnelReasons, reasons...) } -// requireOpt is an optional requirement function used when getting requirements, -// that allows the caller to add further requirements. -type requireOpt func(r *dbLocalProxyRequirement) - // getDBLocalProxyRequirement determines what local proxy settings are required // for a given database. -func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToDatabase, opts ...requireOpt) *dbLocalProxyRequirement { +func getDBLocalProxyRequirement(tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement { var out dbLocalProxyRequirement switch tc.PrivateKeyPolicy { case keys.PrivateKeyPolicyHardwareKey, keys.PrivateKeyPolicyHardwareKeyTouch: @@ -1162,7 +1154,7 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToD // When Proxy is behind a load balancer and the database requires the web // port, a local proxy must be used so the TLS routing request can be // upgraded, regardless whether Proxy is in single or separate port mode. - if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabaseUseWebProxyHostPort(*route) { + if tc.TLSRoutingConnUpgradeRequired && tc.DoesDatabaseUseWebProxyHostPort(route) { out.addLocalProxy("Teleport Proxy is behind a load balancer") } @@ -1179,49 +1171,44 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route *tlsca.RouteToD // When TLS routing is enabled and MySQL is listening on the web port, // a local proxy is required to connect. With a separate port, MySQL // does not require a local proxy even if TLS routing is enabled. - if tc.TLSRoutingEnabled && tc.DoesDatabaseUseWebProxyHostPort(*route) { + if tc.TLSRoutingEnabled && tc.DoesDatabaseUseWebProxyHostPort(route) { out.addLocalProxy(fmt.Sprintf("%v and %v", formatDBProtocolReason(route.Protocol), formatTLSRoutingReason(tc.SiteName))) } } - - for _, opt := range opts { - opt(&out) - } return &out } -// withConnectRequirements is requirement option fn that adds requirements specific to "tsh db connect". -func withConnectRequirements(ctx context.Context, tc *client.TeleportClient, route *tlsca.RouteToDatabase) requireOpt { - return func(r *dbLocalProxyRequirement) { - if !r.localProxy && tc.TLSRoutingEnabled { - r.addLocalProxy(formatTLSRoutingReason(tc.SiteName)) - } - switch route.Protocol { - case defaults.ProtocolElasticsearch, defaults.ProtocolOpenSearch: - // ElasticSearch and OpenSearch access can work without a local proxy tunnel, - // but not via `tsh db connect`. - // (elasticsearch-sql-cli and opensearchsql cannot be configured to use specific certs). - r.addLocalProxyWithTunnel(formatDBProtocolReason(route.Protocol)) - } - if r.localProxy && r.tunnel { - // don't check if MFA is required, because a local proxy tunnel is - // already required. this avoids an extra API call. - return - } - // Call API and check if a user needs to use MFA to connect to the database. - mfaRequired, err := isMFADatabaseAccessRequired(ctx, tc, route) - if err != nil { - log.WithError(err).Debugf("error getting MFA requirement for database %v", - route.ServiceName) - } else if mfaRequired { - // When MFA is required, we should require a local proxy tunnel, - // because the local proxy tunnel can hold database MFA certs in-memory - // without a restricted 1-minute TTL. This is better for user experience. - r.addLocalProxyWithTunnel("MFA is required to connect to the database") - } +func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement { + r := getDBLocalProxyRequirement(tc, route) + if !r.localProxy && tc.TLSRoutingEnabled { + r.addLocalProxy(formatTLSRoutingReason(tc.SiteName)) } + switch route.Protocol { + case defaults.ProtocolElasticsearch, defaults.ProtocolOpenSearch: + // ElasticSearch and OpenSearch access can work without a local proxy tunnel, + // but not via `tsh db connect`. + // (elasticsearch-sql-cli and opensearchsql cannot be configured to use specific certs). + r.addLocalProxyWithTunnel(formatDBProtocolReason(route.Protocol)) + } + if r.localProxy && r.tunnel { + // don't check if MFA is required, because a local proxy tunnel is + // already required. this avoids an extra API call. + return r + } + // Call API and check if a user needs to use MFA to connect to the database. + mfaRequired, err := isMFADatabaseAccessRequired(ctx, tc, route) + if err != nil { + log.WithError(err).Debugf("error getting MFA requirement for database %v", + route.ServiceName) + } else if mfaRequired { + // When MFA is required, we should require a local proxy tunnel, + // because the local proxy tunnel can hold database MFA certs in-memory + // without a restricted 1-minute TTL. This is better for user experience. + r.addLocalProxyWithTunnel("MFA is required to connect to the database") + } + return r } // formatKeyPolicyReason is a helper func that formats a private key policy "reason". @@ -1249,7 +1236,7 @@ func formatTLSRoutingReason(siteName string) string { // formatDbCmdUnsupported is a helper func that formats a generic unsupported DB error message. // The "reasons" arguments, if given, should specify condition for which this DB subcommand // is not supported, e.g. "TLS routing is enabled" or "using a local proxy without the --tunnel flag". -func formatDbCmdUnsupported(cf *CLIConf, route *tlsca.RouteToDatabase, reasons ...string) string { +func formatDbCmdUnsupported(cf *CLIConf, route tlsca.RouteToDatabase, reasons ...string) string { templateData := map[string]any{ "command": cf.CommandWithBinary(), "alternatives": getDbCmdAlternatives(cf.SiteName, route), @@ -1262,23 +1249,23 @@ func formatDbCmdUnsupported(cf *CLIConf, route *tlsca.RouteToDatabase, reasons . } // formatDbCmdUnsupportedDBProtocol is a helper func that formats an unsupported DB protocol error message. -func formatDbCmdUnsupportedDBProtocol(cf *CLIConf, route *tlsca.RouteToDatabase) string { +func formatDbCmdUnsupportedDBProtocol(cf *CLIConf, route tlsca.RouteToDatabase) string { reason := formatDBProtocolReason(route.Protocol) return formatDbCmdUnsupported(cf, route, reason) } // getDbCmdAlternatives is a helper func that returns alternative tsh commands for connecting to a database. -func getDbCmdAlternatives(clusterFlag string, route *tlsca.RouteToDatabase) []string { +func getDbCmdAlternatives(clusterFlag string, route tlsca.RouteToDatabase) []string { var alts []string switch route.Protocol { case defaults.ProtocolDynamoDB: // DynamoDB only works with a local proxy tunnel and there is no "shell-like" cli, so `tsh db connect` doesn't make sense. default: // prefer displaying the connect command as the first suggested command alternative. - alts = append(alts, formatDatabaseConnectCommand(clusterFlag, *route)) + alts = append(alts, formatDatabaseConnectCommand(clusterFlag, route)) } // all db protocols support this command. - alts = append(alts, formatDatabaseProxyCommand(clusterFlag, *route)) + alts = append(alts, formatDatabaseProxyCommand(clusterFlag, route)) return alts } diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 364b105849894..2a770bf60967a 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -318,13 +318,13 @@ func TestLocalProxyRequirement(t *testing.T) { if tt.setupTC != nil { tt.setupTC(tc) } - route := &tlsca.RouteToDatabase{ + route := tlsca.RouteToDatabase{ ServiceName: "foo-db", Protocol: "postgres", Username: "alice", Database: "postgres", } - requires := getDBLocalProxyRequirement(tc, route, withConnectRequirements(ctx, tc, route)) + requires := getDBConnectLocalProxyRequirement(ctx, tc, route) require.Equal(t, tt.wantLocalProxy, requires.localProxy) require.Equal(t, tt.wantTunnel, requires.tunnel) if requires.tunnel { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index d7b9774d305a1..d09da48a88caa 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -370,7 +370,7 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - route, db, err := getDatabaseInfo(cf, tc) + dbInfo, err := getDatabaseInfo(cf, tc) if err != nil { return trace.Wrap(err) } @@ -380,14 +380,14 @@ func onProxyCommandDB(cf *CLIConf) error { // 2. check if db login is required. // These steps are not needed with `--tunnel`, because the local proxy tunnel // will manage database certificates itself and reissue them as needed. - requires := getDBLocalProxyRequirement(tc, route) + requires := getDBLocalProxyRequirement(tc, dbInfo.RouteToDatabase) if requires.tunnel && !isLocalProxyTunnelRequested(cf) { // Some scenarios require a local proxy tunnel, e.g.: // - Snowflake, DynamoDB protocol // - Hardware-backed private key policy - return trace.BadParameter(formatDbCmdUnsupported(cf, route, requires.tunnelReasons...)) + return trace.BadParameter(formatDbCmdUnsupported(cf, dbInfo.RouteToDatabase, requires.tunnelReasons...)) } - if err := maybeDatabaseLogin(cf, tc, profile, route, requires); err != nil { + if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil { return trace.Wrap(err) } @@ -403,7 +403,7 @@ func onProxyCommandDB(cf *CLIConf) error { addr = fmt.Sprintf("127.0.0.1:%s", cf.LocalProxyPort) } - listener, err := createLocalProxyListener(addr, route, profile) + listener, err := createLocalProxyListener(addr, dbInfo.RouteToDatabase, profile) if err != nil { return trace.Wrap(err) } @@ -419,8 +419,7 @@ func onProxyCommandDB(cf *CLIConf) error { cf: cf, tc: tc, profile: profile, - route: *route, - database: db, + dbInfo: dbInfo, autoReissueCerts: cf.LocalProxyTunnel, // only auto-reissue certs for --tunnel flag. tunnel: tunnel, }) @@ -449,11 +448,11 @@ func onProxyCommandDB(cf *CLIConf) error { dbcmd.WithPrintFormat(), dbcmd.WithTolerateMissingCLIClient(), } - if opts, err = maybeAddDBUserPassword(db, opts); err != nil { + if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil { return trace.Wrap(err) } - commands, err := dbcmd.NewCmdBuilder(tc, profile, route, rootCluster, + commands, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootCluster, opts..., ).GetConnectCommandAlternatives() if err != nil { @@ -462,14 +461,14 @@ func onProxyCommandDB(cf *CLIConf) error { // shared template arguments templateArgs := map[string]any{ - "database": route.ServiceName, - "type": defaults.ReadableDatabaseProtocol(route.Protocol), + "database": dbInfo.ServiceName, + "type": defaults.ReadableDatabaseProtocol(dbInfo.Protocol), "cluster": tc.SiteName, "address": listener.Addr().String(), "randomPort": randomPort, } - tmpl := chooseProxyCommandTemplate(templateArgs, commands, route.Protocol) + tmpl := chooseProxyCommandTemplate(templateArgs, commands, dbInfo.Protocol) err = tmpl.Execute(os.Stdout, templateArgs) if err != nil { return trace.Wrap(err) @@ -477,10 +476,10 @@ func onProxyCommandDB(cf *CLIConf) error { } else { err = dbProxyTpl.Execute(os.Stdout, map[string]any{ - "database": route.ServiceName, + "database": dbInfo.ServiceName, "address": listener.Addr().String(), "ca": profile.CACertPathForCluster(rootCluster), - "cert": profile.DatabaseCertPathForCluster(cf.SiteName, route.ServiceName), + "cert": profile.DatabaseCertPathForCluster(cf.SiteName, dbInfo.ServiceName), "key": profile.KeyPath(), "randomPort": randomPort, }) @@ -496,16 +495,22 @@ func onProxyCommandDB(cf *CLIConf) error { return nil } -func maybeAddDBUserPassword(db types.Database, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { - if db != nil && db.GetProtocol() == defaults.ProtocolCassandra && db.IsAWSHosted() { - // Cassandra client always prompt for password, so we need to provide it - // Provide an auto generated random password to skip the prompt in case of - // connection to AWS hosted cassandra. - password, err := utils.CryptoRandomHex(16) +func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { + if dbInfo.Protocol == defaults.ProtocolCassandra { + db, err := dbInfo.GetDatabase(cf, tc) if err != nil { return nil, trace.Wrap(err) } - return append(opts, dbcmd.WithPassword(password)), nil + if db.IsAWSHosted() { + // Cassandra client always prompt for password, so we need to provide it + // Provide an auto generated random password to skip the prompt in case of + // connection to AWS hosted cassandra. + password, err := utils.CryptoRandomHex(16) + if err != nil { + return nil, trace.Wrap(err) + } + return append(opts, dbcmd.WithPassword(password)), nil + } } return opts, nil }