Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1310,49 +1310,53 @@ func pickActiveDatabase(cf *CLIConf, tc *client.TeleportClient) (*tlsca.RouteToD
// filtered out - this is to avoid requiring additional selectors
// when a user gives an exact database name.
func filterActiveDatabases(ctx context.Context, tc *client.TeleportClient, activeRoutes []tlsca.RouteToDatabase) ([]tlsca.RouteToDatabase, types.Databases, error) {
prefix := tc.DatabaseService
if prefix == "" && len(activeRoutes) == 1 {
prefix = activeRoutes[0].ServiceName
if len(activeRoutes) == 0 {
// nothing to filter
return nil, nil, nil
}

haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != ""
var selectedRoutes []tlsca.RouteToDatabase
for _, db := range activeRoutes {
if db.ServiceName == prefix && !haveSelectors {
// short-circuit to select the exact match when we don't have
// label or predicate selectors.
return []tlsca.RouteToDatabase{db}, nil, nil
prefix := tc.DatabaseService
if len(tc.Labels) == 0 && tc.PredicateExpression == "" {
if prefix == "" && len(activeRoutes) == 1 {
return activeRoutes, nil, nil
}
if strings.HasPrefix(db.ServiceName, prefix) {
selectedRoutes = append(selectedRoutes, db)
// when we have a name but don't have label or predicate query, look for
// a route that matches the name exactly to maybe avoid calling
// ListDatabases API below.
for _, route := range activeRoutes {
if route.ServiceName == prefix {
return []tlsca.RouteToDatabase{route}, nil, nil
}
}
}
if len(selectedRoutes) == 0 || !haveSelectors {
// nothing to filter further, avoid making API call.
return selectedRoutes, nil, nil
}

// make a ListDatabases API call and match on full database name.
// make a ListDatabases API call filtered by prefix name
databases, err := listDatabasesByPrefix(ctx, tc, prefix)
if err != nil {
return nil, nil, trace.Wrap(err)
}
selectedRoutes = nil
databasesByName := databases.ToMap()

// when a database matches the prefix fully, look for a
// corresponding active route.
if db, ok := databasesByName[prefix]; ok {
for _, route := range activeRoutes {
if route.ServiceName == db.GetName() {
return []tlsca.RouteToDatabase{route}, types.Databases{db}, nil
}
}
// no active route, but return the fetched databases if the caller is
// interested.
return nil, databases, nil
}

// otherwise, just filter routes to those that match the names of the
// databases.
var selectedRoutes []tlsca.RouteToDatabase
var activeDBs types.Databases
for _, route := range activeRoutes {
for _, db := range databases {
if db.GetName() == route.ServiceName {
if db.GetName() == prefix {
// when label/query selectors are used and multiple
// databases come back, but one of them matches the prefix
// exactly, short-circuit to return just that db.
// We can't do that before calling the API because the
// labels/query might not actually match the active db.
return []tlsca.RouteToDatabase{route}, types.Databases{db}, nil
}
selectedRoutes = append(selectedRoutes, route)
activeDBs = append(activeDBs, db)
}
if db, ok := databasesByName[route.ServiceName]; ok {
selectedRoutes = append(selectedRoutes, route)
activeDBs = append(activeDBs, db)
}
}
return selectedRoutes, activeDBs, nil
Expand Down
49 changes: 42 additions & 7 deletions tool/tsh/common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,15 +961,38 @@ func testFilterActiveDatabases(t *testing.T) {
dbNamePrefix,
labels,
query string
wantAPICall bool
wantRoutes []tlsca.RouteToDatabase
wantAPICall bool
overrideActiveRoutes []tlsca.RouteToDatabase
overrideAPIDatabasesCheckFn func(t *testing.T, databases types.Databases)
wantRoutes []tlsca.RouteToDatabase
}{
{
name: "by exact name that is a prefix of others",
dbNamePrefix: fooRoute1.ServiceName,
wantAPICall: false,
wantRoutes: []tlsca.RouteToDatabase{fooRoute1},
},
{
name: "by exact name of inactive route that is a prefix of active routes",
dbNamePrefix: fooRoute1.ServiceName,
overrideActiveRoutes: []tlsca.RouteToDatabase{
fooRoute2, fooRoute3,
barRoute1, barRoute2,
bazRoute1, bazRoute2,
},
wantAPICall: true,
overrideAPIDatabasesCheckFn: func(t *testing.T, databases types.Databases) {
t.Helper()
require.NotNil(t, databases)
databasesByName := databases.ToMap()
require.Contains(t, databasesByName, fooRoute1.ServiceName)
require.Contains(t, databasesByName, fooRoute2.ServiceName)
require.Contains(t, databasesByName, fooRoute3.ServiceName)
},
// the inactive route got filtered out, but active routes shouldn't
// have been matched by prefix either.
wantRoutes: nil,
},
{
name: "by exact name that is not a prefix of others",
dbNamePrefix: fooRoute2.ServiceName,
Expand All @@ -986,7 +1009,7 @@ func testFilterActiveDatabases(t *testing.T) {
{
name: "by name prefix",
dbNamePrefix: "ba",
wantAPICall: false,
wantAPICall: true,
wantRoutes: []tlsca.RouteToDatabase{barRoute1, barRoute2, bazRoute1, bazRoute2},
},
{
Expand Down Expand Up @@ -1046,12 +1069,24 @@ func testFilterActiveDatabases(t *testing.T) {
}
tc, err := makeClient(cf)
require.NoError(t, err)
routes, dbs, err := filterActiveDatabases(ctx, tc, routes)
activeRoutes := routes
if tt.overrideActiveRoutes != nil {
activeRoutes = tt.overrideActiveRoutes
}
gotRoutes, dbs, err := filterActiveDatabases(ctx, tc, activeRoutes)
require.NoError(t, err)
require.Empty(t, cmp.Diff(tt.wantRoutes, routes))
require.Empty(t, cmp.Diff(tt.wantRoutes, gotRoutes))
if tt.wantAPICall {
require.Equal(t, len(routes), len(dbs),
"returned routes should have corresponding types.Databases")
if tt.overrideAPIDatabasesCheckFn != nil {
tt.overrideAPIDatabasesCheckFn(t, dbs)
} else {
require.Equal(t, len(tt.wantRoutes), len(dbs),
"returned routes should have corresponding types.Databases")
for i := range tt.wantRoutes {
require.Equal(t, gotRoutes[i].ServiceName, dbs[i].GetName(),
"route %v does not match corresponding types.Database", i)
}
}
return
}
require.Zero(t, len(dbs), "unexpected API call to ListDatabases")
Expand Down