diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 16dfa86cc699b..fc482c38ed270 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -779,7 +779,7 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - route, database, err := getDatabaseInfo(cf, tc, cf.DatabaseService) + route, database, err := getDatabaseInfo(cf, tc) if err != nil { return trace.Wrap(err) } @@ -836,7 +836,7 @@ func onDatabaseConnect(cf *CLIConf) error { // getDatabaseInfo fetches information about the database from tsh profile is DB is active in profile. Otherwise, // the ListDatabases endpoint is called. -func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, dbName string) (*tlsca.RouteToDatabase, types.Database, error) { +func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToDatabase, types.Database, error) { database, err := pickActiveDatabase(cf) if err == nil { switch database.Protocol { @@ -850,14 +850,14 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, dbName string) (*tl if err != nil && !trace.IsNotFound(err) { return nil, nil, trace.Wrap(err) } - db, err := getDatabase(cf, tc, dbName) - if err != nil { - return nil, nil, trace.Wrap(err) - } + dbService := cf.DatabaseService username := cf.DatabaseUser databaseName := cf.DatabaseName if database != nil { + if dbService == "" { + dbService = database.ServiceName + } if username == "" { username = database.Username } @@ -866,6 +866,11 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, dbName string) (*tl } } + db, err := getDatabase(cf, tc, dbService) + if err != nil { + return nil, nil, trace.Wrap(err) + } + return &tlsca.RouteToDatabase{ ServiceName: db.GetName(), Protocol: db.GetProtocol(), diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index d85674b4b4377..d7b9774d305a1 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -370,7 +370,7 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - route, db, err := getDatabaseInfo(cf, tc, cf.DatabaseService) + route, db, err := getDatabaseInfo(cf, tc) if err != nil { return trace.Wrap(err) }