From 700f9f71e524e7817fadaab9e2a0648d09d6b49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Smoli=C5=84ski?= Date: Wed, 6 Oct 2021 22:59:35 +0200 Subject: [PATCH] Add support for MFA for DB access (#8270) --- lib/auth/auth.go | 8 ++ lib/auth/auth_with_roles_test.go | 171 +++++++++++++++++++++++++++++ lib/client/api.go | 5 + lib/client/client.go | 9 +- lib/srv/db/common/role/role.go | 47 ++++++++ lib/srv/db/mongodb/engine.go | 14 ++- lib/srv/db/mysql/engine.go | 19 ++-- lib/srv/db/postgres/engine.go | 12 ++- tool/tsh/db.go | 177 ++++++++++++++++++++++--------- tool/tsh/db_test.go | 39 ------- tool/tsh/tsh.go | 6 -- 11 files changed, 385 insertions(+), 122 deletions(-) create mode 100644 lib/srv/db/common/role/role.go delete mode 100644 tool/tsh/db_test.go diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 2f15d1f45583f..9783cca88eebb 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -71,6 +71,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" @@ -3115,9 +3116,16 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck if db == nil { return nil, trace.Wrap(notFoundErr) } + + dbRoleMatchers := role.DatabaseRoleMatchers( + db.GetProtocol(), + t.Database.Username, + t.Database.GetDatabase(), + ) noMFAAccessErr = checker.CheckAccess( db, services.AccessMFAParams{}, + dbRoleMatchers..., ) default: diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index aadbc9ca3b02e..73194da9fa5ac 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" libdefaults "github.com/gravitational/teleport/lib/defaults" @@ -955,6 +956,176 @@ func TestReplaceRemoteLocksRBAC(t *testing.T) { } } +// TestIsMFARequiredMFADB tests isMFARequest logic per database protocol where different role matchers are used. +func TestIsMFARequiredMFADB(t *testing.T) { + const ( + databaseName = "test-database" + userName = "test-username" + ) + + type modifyRoleFunc func(role types.Role) + tests := []struct { + name string + userRoleRequireMFA bool + checkMFA require.BoolAssertionFunc + modifyRoleFunc modifyRoleFunc + dbProtocol string + req *proto.IsMFARequiredRequest + }{ + { + name: "RequireSessionMFA enabled MySQL protocol doesn't match database name", + dbProtocol: libdefaults.ProtocolMySQL, + req: &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Database{ + Database: &proto.RouteToDatabase{ + ServiceName: databaseName, + Protocol: libdefaults.ProtocolMySQL, + Username: userName, + Database: "example", + }, + }, + }, + modifyRoleFunc: func(role types.Role) { + roleOpt := role.GetOptions() + roleOpt.RequireSessionMFA = true + role.SetOptions(roleOpt) + + role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) + role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) + role.SetDatabaseNames(types.Allow, nil) + }, + checkMFA: require.True, + }, + { + name: "RequireSessionMFA disabled", + dbProtocol: libdefaults.ProtocolMySQL, + req: &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Database{ + Database: &proto.RouteToDatabase{ + ServiceName: databaseName, + Protocol: libdefaults.ProtocolMySQL, + Username: userName, + Database: "example", + }, + }, + }, + modifyRoleFunc: func(role types.Role) { + roleOpt := role.GetOptions() + roleOpt.RequireSessionMFA = false + role.SetOptions(roleOpt) + + role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) + role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) + role.SetDatabaseNames(types.Allow, nil) + }, + checkMFA: require.False, + }, + { + name: "RequireSessionMFA enabled Postgres protocol database name doesn't match", + dbProtocol: libdefaults.ProtocolPostgres, + req: &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Database{ + Database: &proto.RouteToDatabase{ + ServiceName: databaseName, + Protocol: libdefaults.ProtocolPostgres, + Username: userName, + Database: "example", + }, + }, + }, + modifyRoleFunc: func(role types.Role) { + roleOpt := role.GetOptions() + roleOpt.RequireSessionMFA = true + role.SetOptions(roleOpt) + + role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) + role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) + role.SetDatabaseNames(types.Allow, nil) + }, + checkMFA: require.False, + }, + { + name: "RequireSessionMFA enabled Postgres protocol database name matches", + dbProtocol: libdefaults.ProtocolPostgres, + req: &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Database{ + Database: &proto.RouteToDatabase{ + ServiceName: databaseName, + Protocol: libdefaults.ProtocolPostgres, + Username: userName, + Database: "example", + }, + }, + }, + modifyRoleFunc: func(role types.Role) { + roleOpt := role.GetOptions() + roleOpt.RequireSessionMFA = true + role.SetOptions(roleOpt) + + role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) + role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) + role.SetDatabaseNames(types.Allow, []string{"example"}) + }, + checkMFA: require.True, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + srv := newTestTLSServer(t) + + // Enable MFA support. + authPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOptional, + U2F: &types.U2F{ + AppID: "teleport", + Facets: []string{"teleport"}, + }, + }) + require.NoError(t, err) + err = srv.Auth().SetAuthPreference(ctx, authPref) + require.NoError(t, err) + + database, err := types.NewDatabaseServerV3( + types.Metadata{ + Name: databaseName, + Labels: map[string]string{ + "env": "dev", + }, + }, + types.DatabaseServerSpecV3{ + Protocol: tc.dbProtocol, + URI: "example.com", + Hostname: "host", + HostID: "hostID", + }, + ) + require.NoError(t, err) + + _, err = srv.Auth().UpsertDatabaseServer(ctx, database) + require.NoError(t, err) + + user, role, err := CreateUserAndRole(srv.Auth(), userName, []string{"test-role"}) + require.NoError(t, err) + + if tc.modifyRoleFunc != nil { + tc.modifyRoleFunc(role) + } + err = srv.Auth().UpsertRole(ctx, role) + require.NoError(t, err) + + cl, err := srv.NewClient(TestUser(user.GetName())) + require.NoError(t, err) + + resp, err := cl.IsMFARequired(ctx, tc.req) + require.NoError(t, err) + tc.checkMFA(t, resp.GetRequired()) + }) + } +} + // TestKindClusterConfig verifies that types.KindClusterConfig can be used // as an alternative privilege to provide access to cluster configuration // resources. diff --git a/lib/client/api.go b/lib/client/api.go index 32208d7554db8..ad1ef44433ea7 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -600,6 +600,11 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error) if err != nil { return nil, trace.Wrap(err) } + // If the cert expiration time is less than 5s consider cert as expired and don't add + // it to the user profile as an active database. + if time.Until(cert.NotAfter) < 5*time.Second { + continue + } if tlsID.RouteToDatabase.ServiceName != "" { databases = append(databases, tlsID.RouteToDatabase) } diff --git a/lib/client/client.go b/lib/client/client.go index ed73c5171219d..8ec06b21230ca 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -164,14 +164,7 @@ func (p ReissueParams) usage() proto.UserCertsRequest_CertUsage { case p.RouteToDatabase.ServiceName != "": // Database means a request for a TLS certificate for access to a // specific database, as specified by RouteToDatabase. - - // DELETE IN 7.0 - // Database certs have to be requested with CertUsage All because - // pre-7.0 servers do not accept usage-restricted certificates. - // - // In 7.0 clients, we can expect the server to be 7.0+ and set this to - // proto.UserCertsRequest_Database again. - return proto.UserCertsRequest_All + return proto.UserCertsRequest_Database case p.RouteToApp.Name != "": // App means a request for a TLS certificate for access to a specific // web app, as specified by RouteToApp. diff --git a/lib/srv/db/common/role/role.go b/lib/srv/db/common/role/role.go new file mode 100644 index 0000000000000..8a0df428375c1 --- /dev/null +++ b/lib/srv/db/common/role/role.go @@ -0,0 +1,47 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package role + +import ( + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" +) + +// DatabaseRoleMatchers returns role matchers based on the database protocol. +func DatabaseRoleMatchers(dbProtocol string, user, database string) services.RoleMatchers { + switch dbProtocol { + case defaults.ProtocolMySQL: + // In MySQL, unlike Postgres, "database" and "schema" are the same thing + // and there's no good way to prevent users from performing cross-database + // queries once they're connected, apart from granting proper privileges + // in MySQL itself. + // + // As such, checking db_names for MySQL is quite pointless so we only + // check db_users. In future, if we implement some sort of access controls + // on queries, we might be able to restrict db_names as well e.g. by + // detecting full-qualified table names like db.table, until then the + // proper way is to use MySQL grants system. + return services.RoleMatchers{ + &services.DatabaseUserMatcher{User: user}, + } + default: + return services.RoleMatchers{ + &services.DatabaseUserMatcher{User: user}, + &services.DatabaseNameMatcher{Name: database}, + } + } +} diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 42b630d595363..4269e742e40f7 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -21,8 +21,10 @@ import ( "net" "strings" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/srv/db/mongodb/protocol" "github.com/gravitational/teleport/lib/utils" @@ -188,10 +190,16 @@ func (e *Engine) authorizeClientMessage(sessionCtx *common.Session, message prot e.Log.Warnf("No database info in message: %v.", message) return nil } - err := sessionCtx.Checker.CheckAccess(sessionCtx.Database, + dbRoleMatchers := role.DatabaseRoleMatchers( + defaults.ProtocolMongoDB, + sessionCtx.DatabaseUser, + database, + ) + err := sessionCtx.Checker.CheckAccess( + sessionCtx.Database, services.AccessMFAParams{Verified: true}, - &services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser}, - &services.DatabaseNameMatcher{Name: database}) + dbRoleMatchers..., + ) e.Audit.OnQuery(e.Context, sessionCtx, common.Query{ Database: msg.GetDatabase(), // Commands may consist of multiple bson documents. diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 8ae36579dc434..65b5460275a24 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/srv/db/mysql/protocol" "github.com/gravitational/teleport/lib/utils" @@ -128,20 +129,16 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er Verified: sessionCtx.Identity.MFAVerified != "", AlwaysRequired: ap.GetRequireSessionMFA(), } - // In MySQL, unlike Postgres, "database" and "schema" are the same thing - // and there's no good way to prevent users from performing cross-database - // queries once they're connected, apart from granting proper privileges - // in MySQL itself. - // - // As such, checking db_names for MySQL is quite pointless so we only - // check db_users. In future, if we implement some sort of access controls - // on queries, we might be able to restrict db_names as well e.g. by - // detecting full-qualified table names like db.table, until then the - // proper way is to use MySQL grants system. + dbRoleMatchers := role.DatabaseRoleMatchers( + defaults.ProtocolMySQL, + sessionCtx.DatabaseUser, + sessionCtx.DatabaseName, + ) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, mfaParams, - &services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser}) + dbRoleMatchers..., + ) if err != nil { e.Audit.OnSessionStart(e.Context, sessionCtx, err) return trace.Wrap(err) diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index e060cc93e334b..6980e443dae16 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -23,8 +23,10 @@ import ( "net" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" @@ -179,11 +181,17 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er Verified: sessionCtx.Identity.MFAVerified != "", AlwaysRequired: ap.GetRequireSessionMFA(), } + + dbRoleMatchers := role.DatabaseRoleMatchers( + defaults.ProtocolPostgres, + sessionCtx.DatabaseUser, + sessionCtx.DatabaseName, + ) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, mfaParams, - &services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser}, - &services.DatabaseNameMatcher{Name: sessionCtx.DatabaseName}) + dbRoleMatchers..., + ) if err != nil { e.Audit.OnSessionStart(e.Context, sessionCtx, err) return trace.Wrap(err) diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 5d920dc98e444..dc51da90a5780 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -52,11 +52,6 @@ func onListDatabases(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - // Refresh the creds in case user was logged into any databases. - err = fetchDatabaseCreds(cf, tc) - if err != nil { - return trace.Wrap(err) - } // Retrieve profile to be able to show which databases user is logged into. profile, err := client.StatusCurrent(cf.HomePath, cf.Proxy) if err != nil { @@ -75,26 +70,13 @@ func onDatabaseLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - var databases []types.Database - err = client.RetryWithRelogin(cf.Context, tc, func() error { - allDatabases, err := tc.ListDatabases(cf.Context) - for _, database := range allDatabases { - if database.GetName() == cf.DatabaseService { - databases = append(databases, database) - } - } - return trace.Wrap(err) - }) + database, err := getDatabase(cf, tc, cf.DatabaseService) if err != nil { return trace.Wrap(err) } - if len(databases) == 0 { - return trace.NotFound( - "database %q not found, use 'tsh db ls' to see registered databases", cf.DatabaseService) - } err = databaseLogin(cf, tc, tlsca.RouteToDatabase{ ServiceName: cf.DatabaseService, - Protocol: databases[0].GetProtocol(), + Protocol: database.GetProtocol(), Username: cf.DatabaseUser, Database: cf.DatabaseName, }, false) @@ -116,19 +98,27 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatab if err != nil { return trace.Wrap(err) } - err = tc.ReissueUserCerts(cf.Context, client.CertCacheKeep, client.ReissueParams{ - RouteToCluster: tc.SiteName, - RouteToDatabase: proto.RouteToDatabase{ - ServiceName: db.ServiceName, - Protocol: db.Protocol, - Username: db.Username, - Database: db.Database, - }, - AccessRequests: profile.ActiveRequests.AccessRequests, - }) - if err != nil { + + var key *client.Key + if err = client.RetryWithRelogin(cf.Context, tc, func() error { + key, err = tc.IssueUserCertsWithMFA(cf.Context, client.ReissueParams{ + RouteToCluster: tc.SiteName, + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: db.ServiceName, + Protocol: db.Protocol, + Username: db.Username, + Database: db.Database, + }, + AccessRequests: profile.ActiveRequests.AccessRequests, + }) + return trace.Wrap(err) + }); err != nil { + return trace.Wrap(err) + } + if _, err = tc.LocalAgent().AddKey(key); err != nil { return trace.Wrap(err) } + // Refresh the profile. profile, err = client.StatusCurrent(cf.HomePath, cf.Proxy) if err != nil { @@ -146,27 +136,6 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatab return nil } -// fetchDatabaseCreds is called as a part of tsh login to refresh database -// access certificates for databases the current profile is logged into. -func fetchDatabaseCreds(cf *CLIConf, tc *client.TeleportClient) error { - profile, err := client.StatusCurrent(cf.HomePath, cf.Proxy) - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } - if trace.IsNotFound(err) { - return nil // No currently logged in profiles. - } - for _, db := range profile.Databases { - if err := databaseLogin(cf, tc, db, true); err != nil { - log.WithError(err).Errorf("Failed to fetch database access certificate for %s.", db) - if err := databaseLogout(tc, db); err != nil { - log.WithError(err).Errorf("Failed to log out of database %s.", db) - } - } - } - return nil -} - // onDatabaseLogout implements "tsh db logout" command. func onDatabaseLogout(cf *CLIConf) error { tc, err := makeClient(cf, false) @@ -321,10 +290,20 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - database, err := pickActiveDatabase(cf) + database, err := getDatabaseInfo(cf, tc, cf.DatabaseService) + if err != nil { + return trace.Wrap(err) + } + // Check is cert is still valid or DB connection requires MFA. If yes trigger db login logic. + relogin, err := needRelogin(cf, tc, database, profile) if err != nil { return trace.Wrap(err) } + if relogin { + if err := databaseLogin(cf, tc, *database, true); err != nil { + return trace.Wrap(err) + } + } var opts []ConnectCommandFunc if tc.ALPNSNIListenerEnabled { lp, err := startLocalALPNSNIProxy(cf, tc, database.Protocol) @@ -355,6 +334,98 @@ func onDatabaseConnect(cf *CLIConf) error { return nil } +// getDatabaseInfo fetches information about the database from tsh profile is DB is active in profile. Otherwise, +// the ListDatabases endpoint is called. +func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, dbName string) (*tlsca.RouteToDatabase, error) { + database, err := pickActiveDatabase(cf) + if err == nil { + return database, nil + } + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + db, err := getDatabase(cf, tc, dbName) + if err != nil { + return nil, trace.Wrap(err) + } + return &tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + Username: cf.Username, + Database: cf.DatabaseName, + }, nil +} + +func getDatabase(cf *CLIConf, tc *client.TeleportClient, dbName string) (types.Database, error) { + var databases []types.Database + err := client.RetryWithRelogin(cf.Context, tc, func() error { + allDatabases, err := tc.ListDatabases(cf.Context) + for _, database := range allDatabases { + if database.GetName() == dbName { + databases = append(databases, database) + } + } + return trace.Wrap(err) + }) + if err != nil { + return nil, trace.Wrap(err) + } + if len(databases) == 0 { + return nil, trace.NotFound( + "database %q not found, use 'tsh db ls' to see registered databases", dbName) + } + return databases[0], nil +} + +func needRelogin(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase, profile *client.ProfileStatus) (bool, error) { + found := false + for _, v := range profile.Databases { + if v.ServiceName == database.ServiceName { + found = true + } + } + // database not found in active list of databases. + if !found { + return true, nil + } + // Call API and check is a user needs to use MFA to connect to the database. + mfaRequired, err := isMFADatabaseAccessRequired(cf, tc, database) + if err != nil { + return false, trace.Wrap(err) + } + return mfaRequired, nil +} + +// isMFADatabaseAccessRequired calls the IsMFARequired endpoint in order to get from user roles if access to the database +// requires MFA. +func isMFADatabaseAccessRequired(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase) (bool, error) { + proxy, err := tc.ConnectToProxy(cf.Context) + if err != nil { + return false, trace.Wrap(err) + } + cluster, err := proxy.ConnectToCluster(cf.Context, tc.SiteName, true) + if err != nil { + return false, trace.Wrap(err) + } + defer cluster.Close() + + dbParam := proto.RouteToDatabase{ + ServiceName: database.ServiceName, + Protocol: database.Protocol, + Username: cf.Username, + Database: database.Database, + } + mfaResp, err := cluster.IsMFARequired(cf.Context, &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Database{ + Database: &dbParam, + }, + }) + if err != nil { + return false, trace.Wrap(err) + } + return mfaResp.GetRequired(), nil +} + // pickActiveDatabase returns the database the current profile is logged into. // // If logged into multiple databases, returns an error unless one specified diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go deleted file mode 100644 index 88780669224d7..0000000000000 --- a/tool/tsh/db_test.go +++ /dev/null @@ -1,39 +0,0 @@ -/* -Copyright 2021 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "testing" - - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" -) - -// TestFetchDatabaseCreds makes sure fetching database credentials does not -// trigger an error when there's no logged in profile. -func TestFetchDatabaseCreds(t *testing.T) { - var cf CLIConf - cf.UserHost = "localhost" - // Randomize proxy name to make sure there's no profile entry. - cf.Proxy = uuid.New() - - tc, err := makeClient(&cf, true) - require.NoError(t, err) - - err = fetchDatabaseCreds(&cf, tc) - require.NoError(t, err) -} diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index b44a64c66cd77..887bf60cef802 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -934,12 +934,6 @@ func onLogin(cf *CLIConf) error { webProxyHost, _ := tc.WebProxyHostPort() cf.Proxy = webProxyHost - // If the profile is already logged into any database services, - // refresh the creds. - if err := fetchDatabaseCreds(cf, tc); err != nil { - return trace.Wrap(err) - } - // Print status to show information of the logged in user. return trace.Wrap(onStatus(cf)) }