Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tool/tsh/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ func onDatabaseConnect(cf *CLIConf) error {
// is active in profile and no labels or predicate query are given.
// Otherwise, the ListDatabases endpoint is called.
func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient) (*databaseInfo, error) {
haveSelectors := len(tc.Labels) > 0 || tc.PredicateExpression != ""
haveSelectors := tc.DatabaseService != "" || len(tc.Labels) > 0 || tc.PredicateExpression != ""
if !haveSelectors {
// if selectors are given, we might incur an extra ListDatabases API
// call here to match against an active database.
Expand Down
171 changes: 105 additions & 66 deletions tool/tsh/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
Expand All @@ -58,7 +59,7 @@ func TestTshDB(t *testing.T) {
t.Run("Login", testDatabaseLogin)
t.Run("List", testListDatabase)
t.Run("FilterActiveDatabases", testFilterActiveDatabases)
t.Run("GetDatabase", testGetDatabase)
t.Run("DatabaseInfo", testDatabaseInfo)
}

// testDatabaseLogin tests "tsh db login" command and verifies "tsh db
Expand Down Expand Up @@ -931,7 +932,7 @@ func testFilterActiveDatabases(t *testing.T) {
}
}

func testGetDatabase(t *testing.T) {
func testDatabaseInfo(t *testing.T) {
t.Parallel()
alice, err := types.NewUser("alice@example.com")
require.NoError(t, err)
Expand All @@ -948,6 +949,13 @@ func testGetDatabase(t *testing.T) {
StaticLabels: map[string]string{
"env": "local",
},
}, {
Name: "postgres-2",
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
StaticLabels: map[string]string{
"env": "local",
},
}, {
Name: "mysql",
Protocol: defaults.ProtocolMySQL,
Expand Down Expand Up @@ -996,71 +1004,102 @@ func testGetDatabase(t *testing.T) {
tmpHomePath, _ := mustLogin(t, s)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, db := range databases {
require.NotEmpty(t, db.Name)
require.NotEmpty(t, db.Protocol)
route := tlsca.RouteToDatabase{
ServiceName: db.Name,
Protocol: db.Protocol,
Username: defaultDBUser,
Database: defaultDBName,
}
t.Run(route.ServiceName, func(t *testing.T) {
t.Run("with active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, &route)
require.NoError(t, err)
require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
if route.Protocol == defaults.ProtocolDynamoDB {
// v13 specific. We remove the dynamodb schema name from the route since it's not supported.
require.Equal(t, route.ServiceName, dbInfo.ServiceName)
require.Equal(t, route.Protocol, dbInfo.Protocol)
require.Equal(t, route.Username, dbInfo.Username)
} else {
require.Equal(t, route, dbInfo.RouteToDatabase)
}
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "database should have been fetched and cached")
})
t.Run("without active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
DatabaseService: route.ServiceName,
DatabaseUser: route.Username,
DatabaseName: route.Database,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, nil)
require.NoError(t, err)
require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
if route.Protocol == defaults.ProtocolDynamoDB {
// v13 specific. We remove the dynamodb schema name from the route since it's not supported.
require.Equal(t, route.ServiceName, dbInfo.ServiceName)
require.Equal(t, route.Protocol, dbInfo.Protocol)
require.Equal(t, route.Username, dbInfo.Username)
} else {
require.Equal(t, route, dbInfo.RouteToDatabase)
}
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "cached database should be the same")
t.Run("newDatabaseInfo", func(t *testing.T) {
for _, db := range databases {
require.NotEmpty(t, db.Name)
require.NotEmpty(t, db.Protocol)
route := tlsca.RouteToDatabase{
ServiceName: db.Name,
Protocol: db.Protocol,
Username: defaultDBUser,
Database: defaultDBName,
}
t.Run(route.ServiceName, func(t *testing.T) {
t.Run("with active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, &route)
require.NoError(t, err)
require.Nil(t, dbInfo.database, "with an active cert the database should not have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
if route.Protocol == defaults.ProtocolDynamoDB {
// v13 specific. We remove the dynamodb schema name from the route since it's not supported.
require.Equal(t, route.ServiceName, dbInfo.ServiceName)
require.Equal(t, route.Protocol, dbInfo.Protocol)
require.Equal(t, route.Username, dbInfo.Username)
} else {
require.Equal(t, route, dbInfo.RouteToDatabase)
}
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "database should have been fetched and cached")
})
t.Run("without active db cert", func(t *testing.T) {
cf := &CLIConf{
Context: ctx,
TracingProvider: tracing.NoopProvider(),
HomePath: tmpHomePath,
tracer: tracing.NoopTracer(teleport.ComponentTSH),
DatabaseService: route.ServiceName,
DatabaseUser: route.Username,
DatabaseName: route.Database,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := newDatabaseInfo(cf, tc, nil)
require.NoError(t, err)
require.NotNil(t, dbInfo.database, "without an active cert the database should have been fetched")
db, err := dbInfo.GetDatabase(cf, tc)
require.NoError(t, err)
if route.Protocol == defaults.ProtocolDynamoDB {
// v13 specific. We remove the dynamodb schema name from the route since it's not supported.
require.Equal(t, route.ServiceName, dbInfo.ServiceName)
require.Equal(t, route.Protocol, dbInfo.Protocol)
require.Equal(t, route.Username, dbInfo.Username)
} else {
require.Equal(t, route, dbInfo.RouteToDatabase)
}
require.Equal(t, route.ServiceName, db.GetName())
require.Equal(t, route.Protocol, db.GetProtocol())
require.Equal(t, dbInfo.database, db, "cached database should be the same")
})
})
})
}
}
})
t.Run("getDatabaseInfo", func(t *testing.T) {
// login to "postgres-2" db.
err = Run(ctx, []string{"db", "login", "postgres-2"}, setHomePath(tmpHomePath))
require.NoError(t, err)
cf := &CLIConf{
Context: ctx,
HomePath: tmpHomePath,
// select the other db, "postgres", which was not logged into.
DatabaseService: "postgres",
// v13 specific: set the db name/username because it won't be
// set by default until v14+.
DatabaseUser: defaultDBUser,
DatabaseName: defaultDBName,
}
tc, err := makeClient(cf)
require.NoError(t, err)
dbInfo, err := getDatabaseInfo(cf, tc)
require.NoError(t, err)
require.NotNil(t, dbInfo)
// verify that the active login route for "postgres-2" was not used
// instead of fetching info for the "postgres" db.
require.Equal(t, "postgres", dbInfo.ServiceName)
require.Equal(t, defaults.ProtocolPostgres, dbInfo.Protocol)
require.Equal(t, defaultDBUser, dbInfo.Username)
require.Equal(t, defaultDBName, dbInfo.Database)
require.NotNil(t, dbInfo.database)
})
}

func TestResourceSelectorsFormatting(t *testing.T) {
Expand Down