Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ func onDatabaseConnect(cf *CLIConf) error {
// is active in profile and no labels or predicate query are given.
// Otherwise, the ListDatabases endpoint is called.
func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) {
haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != ""
haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != ""
Comment thread
GavinFrazar marked this conversation as resolved.
if !haveSelectors {
// if selectors are given, we might incur an extra ListDatabases API
// call here to match against an active database.
Expand Down
140 changes: 86 additions & 54 deletions tool/tsh/common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestTshDB(t *testing.T) {
t.Run("Login", testDatabaseLogin)
t.Run("List", testListDatabase)
t.Run("FilterActiveDatabases", testFilterActiveDatabases)
t.Run("GetDatabase", testGetDatabase)
t.Run("DatabaseInfo", testDatabaseInfo)
}

// testDatabaseLogin tests "tsh db login" command and verifies "tsh db
Expand Down Expand Up @@ -1059,7 +1059,7 @@ func testFilterActiveDatabases(t *testing.T) {
}
}

func testGetDatabase(t *testing.T) {
func testDatabaseInfo(t *testing.T) {
t.Parallel()
alice, err := types.NewUser("alice@example.com")
require.NoError(t, err)
Expand All @@ -1076,6 +1076,13 @@ func testGetDatabase(t *testing.T) {
StaticLabels: map[string]string{
"env": "local",
},
}, {
Name: "postgres-2",
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
StaticLabels: map[string]string{
"env": "local",
},
}, {
Name: "mysql",
Protocol: defaults.ProtocolMySQL,
Expand Down Expand Up @@ -1124,59 +1131,84 @@ func testGetDatabase(t *testing.T) {
tmpHomePath, _ := mustLogin(t, s)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, db := range databases {
require.NotEmpty(t, db.Name)
require.NotEmpty(t, db.Protocol)
route := tlsca.RouteToDatabase{
ServiceName: db.Name,
Protocol: db.Protocol,
Username: defaultDBUser,
Database: defaultDBName,
}
t.Run(route.ServiceName, func(t *testing.T) {
t.Run("with active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, &route)
require.NoError(t, err)
require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
require.Equal(t, route, dbInfo.RouteToDatabase)
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "database should have been fetched and cached")
})
t.Run("without active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
DatabaseService: route.ServiceName,
DatabaseUser: route.Username,
DatabaseName: route.Database,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, nil)
require.NoError(t, err)
require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
require.Equal(t, route, dbInfo.RouteToDatabase)
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "cached database should be the same")
t.Run("newDatabaseInfo", func(t *testing.T) {
for _, db := range databases {
require.NotEmpty(t, db.Name)
require.NotEmpty(t, db.Protocol)
route := tlsca.RouteToDatabase{
ServiceName: db.Name,
Protocol: db.Protocol,
Username: defaultDBUser,
Database: defaultDBName,
}
t.Run(route.ServiceName, func(t *testing.T) {
t.Run("with active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, &route)
require.NoError(t, err)
require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
require.Equal(t, route, dbInfo.RouteToDatabase)
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "database should have been fetched and cached")
})
t.Run("without active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
DatabaseService: route.ServiceName,
DatabaseUser: route.Username,
DatabaseName: route.Database,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, nil)
require.NoError(t, err)
require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
require.Equal(t, route, dbInfo.RouteToDatabase)
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "cached database should be the same")
})
})
})
}
}
})
t.Run("getDatabaseInfo", func(t *testing.T) {
// login to "postgres-2" db.
err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath))
require.NoError(t, err)
cf := &CLIConf{
Context: ctx,
HomePath: tmpHomePath,
// select the other db, "postgres", which was not logged into.
DatabaseService: "postgres",
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := getDatabaseInfo(cf, tc)
require.NoError(t, err)
require.NotNil(t, dbInfo)
// verify that the active login route for "postgres-2" was not used
// instead of fetching info for the "postgres" db.
require.Equal(t, "postgres", dbInfo.ServiceName)
require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol)
require.Equal(t, defaultDBUser, dbInfo.Username)
require.Equal(t, defaultDBName, dbInfo.Database)
require.NotNil(t, dbInfo.database)
})
}

func TestResourceSelectorsFormatting(t *testing.T) {
Expand Down