diff --git a/tool/tsh/common/access_request.go b/tool/tsh/common/access_request.go index 8eb5be3709a8b..0e6f7e61cece5 100644 --- a/tool/tsh/common/access_request.go +++ b/tool/tsh/common/access_request.go @@ -469,7 +469,13 @@ func onRequestSearch(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - tableColumns = []string{"Name", "Hostname", "Labels", "Resource ID"} + + switch cf.ResourceKind { + case types.KindDatabase: + tableColumns = []string{"Database Name", "Labels", "Resource ID"} + default: + tableColumns = []string{"Name", "Hostname", "Labels", "Resource ID"} + } } var rows [][]string @@ -512,11 +518,21 @@ func onRequestSearch(cf *CLIConf) error { if r, ok := resource.(interface{ GetHostname() string }); ok { hostName = r.GetHostname() } - row = []string{ - common.FormatResourceName(resource, cf.Verbose), - hostName, - common.FormatLabels(resource.GetAllLabels(), cf.Verbose), - resourceID, + + switch cf.ResourceKind { + case types.KindDatabase: + row = []string{ + common.FormatResourceName(resource, cf.Verbose), + common.FormatLabels(resource.GetAllLabels(), cf.Verbose), + resourceID, + } + default: + row = []string{ + common.FormatResourceName(resource, cf.Verbose), + hostName, + common.FormatLabels(resource.GetAllLabels(), cf.Verbose), + resourceID, + } } } rows = append(rows, row) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 8bc93ea29f129..2e55723bc15a5 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -831,7 +831,38 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.Rout } db, err := getDatabaseByNameOrDiscoveredName(cf, tc, routes) - if err != nil { + switch { + // If the database cannot be found, try again with UseSearchAsRoles. If + // the database is then found with UseSearchAsRoles, make an access request + // for it and elevate the user with the request ID upon approval. + // + // Note that the access request must be made before the database connection + // is made to avoid mangling the request with the database client tools. + // Thus the flow for auto database access request is different from SSH. + // + // Performance considerations: + // - For common scenarios where UseSearchAsRoles is not desired, it would + // be rare that cf.DatabaseName would be not found in the first API call + // so there won't be a second call usually. + // - accessChecker.GetAllowedSearchAsRoles can be checked to avoid the + // second API call but creating the access checker requires more calls. + // - The db commands do provide "--disable-access-request" to bypass the + // second call. If needed, we can add it to `tsh login` and profile yaml + // in the future. + case shouldRetryGetDatabaseUsingSearchAsRoles(cf, tc, err): + orgErr := err + if db, err = getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf, tc); err != nil { + return nil, trace.Wrap(orgErr) // Returns the original not found error. + } + if err := makeDatabaseAccessRequestAndWaitForApproval(cf, tc, db); err != nil { + return nil, trace.Wrap(err) + } + + // Reset routes. Once access requeset is approved, user certs are + // reissued with client.CertCacheDrop. + routes = nil + + case err != nil: return nil, trace.Wrap(err) } @@ -853,6 +884,56 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.Rout return info, nil } +var dbCommandsWithAccessRequestSupport = []string{ + "db login", + "proxy db", + "db connect", +} + +func shouldRetryGetDatabaseUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient, getDatabaseError error) bool { + // If already using SearchAsRoles, nothing to retry. + if tc.UseSearchAsRoles { + return false + } + // Only retry when the database cannot be found. + if !trace.IsNotFound(getDatabaseError) { + return false + } + // Check if auto access request is disabled. + if cf.disableAccessRequest { + return false + } + // Check if the `tsh` command supports auto access request. + return slices.Contains(dbCommandsWithAccessRequestSupport, cf.command) +} + +func makeAccessRequestForDatabase(tc *client.TeleportClient, db types.Database) (types.AccessRequest, error) { + requestResourceIDs := []types.ResourceID{{ + ClusterName: tc.SiteName, + Kind: types.KindDatabase, + Name: db.GetName(), + }} + + req, err := services.NewAccessRequestWithResources(tc.Username, nil /* roles */, requestResourceIDs) + return req, trace.Wrap(err) +} + +func makeDatabaseAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + req, err := makeAccessRequestForDatabase(tc, db) + if err != nil { + return trace.Wrap(err) + } + + fmt.Fprintf(cf.Stdout(), "You do not currently have access to %q, attempting to request access.\n\n", db.GetName()) + if err := setAccessRequestReason(cf, req); err != nil { + return trace.Wrap(err) + } + if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil { + return trace.Wrap(err) + } + return nil +} + func requestedDatabaseRoles(cf *CLIConf) []string { if cf.DatabaseRoles == "" { return nil @@ -1043,6 +1124,15 @@ func getDatabaseByNameOrDiscoveredName(cf *CLIConf, tc *client.TeleportClient, a return chooseOneDatabase(cf, databases) } +func getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { + tc.UseSearchAsRoles = true + defer func() { + tc.UseSearchAsRoles = false + }() + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, nil) + return db, trace.Wrap(err) +} + func filterActiveDatabases(routes []tlsca.RouteToDatabase, databases types.Databases) types.Databases { databasesByName := databases.ToMap() var out types.Databases @@ -1068,6 +1158,7 @@ func listDatabasesWithPredicate(ctx context.Context, tc *client.TeleportClient, ResourceType: types.KindDatabaseServer, PredicateExpression: predicate, Labels: tc.Labels, + UseSearchAsRoles: tc.UseSearchAsRoles, }) return trace.Wrap(err) }) diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index cbe13d26c232a..7c33fe2145aa4 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -43,16 +43,44 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + dbcommon "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/teleport/testenv" ) +func registerFakeEnterpriseDBEngines(t *testing.T) { + newFakeEngine := func(dbcommon.EngineConfig) dbcommon.Engine { + type fakeDBEngine struct { + dbcommon.Engine + } + return fakeDBEngine{} + } + dbcommon.RegisterEngine(newFakeEngine, defaults.ProtocolOracle) + t.Cleanup(func() { + dbcommon.RegisterEngine(nil, defaults.ProtocolOracle) + }) +} + func TestTshDB(t *testing.T) { + // Register missing Enterprise database engines. Otherwise db.CheckEngines + // will fail. The fake engine registered are not functional. But other + // Enterprise features like Access Request can still be tested. + registerFakeEnterpriseDBEngines(t) + modules.SetTestModules(t, + &modules.TestModules{ + TestBuildType: modules.BuildEnterprise, + TestFeatures: modules.Features{ + DB: true, + }, + }, + ) + // this speeds up test suite setup substantially, which is where // tests spend the majority of their time, especially when leaf // clusters are setup. @@ -81,16 +109,40 @@ func testDatabaseLogin(t *testing.T) { }) require.NoError(t, err) + devAccessRole, err := types.NewRole("dev-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{"env": []string{"dev"}}, + DatabaseNames: []string{"default"}, + DatabaseUsers: []string{"admin"}, + }, + }) + require.NoError(t, err) + + accessRequestorRole, err := types.NewRole("access-requestor", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + SearchAsRoles: []string{"access"}, + }, + }, + }) + require.NoError(t, err) + alice, err := types.NewUser("alice@example.com") require.NoError(t, err) // to use default --db-user and --db-name selection, make a user with just // one of each allowed. alice.SetDatabaseUsers([]string{"admin"}) alice.SetDatabaseNames([]string{"default"}) - alice.SetRoles([]string{"access", "autouser"}) + alice.SetRoles([]string{"dev-access", "autouser", "access-requestor"}) s := newTestSuite(t, withRootConfigFunc(func(cfg *servicecfg.Config) { - cfg.Auth.BootstrapResources = append(cfg.Auth.BootstrapResources, autoUserRole, alice) + cfg.Auth.BootstrapResources = append( + cfg.Auth.BootstrapResources, + autoUserRole, + devAccessRole, + accessRequestorRole, + alice, + ) cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) // separate MySQL port with TLS routing. // set the public address to be sure even on v2+, tsh clients will see the separate port. @@ -106,7 +158,7 @@ func testDatabaseLogin(t *testing.T) { StaticLabels: map[string]string{ types.DiscoveredNameLabel: "postgres", "region": "us-west-1", - "env": "prod", + "env": "dev", }, AWS: servicecfg.DatabaseAWS{ AccountID: "123456789012", @@ -116,9 +168,15 @@ func testDatabaseLogin(t *testing.T) { }, }, }, { - Name: "mysql", - Protocol: defaults.ProtocolMySQL, - URI: "localhost:3306", + Name: "mysql", + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + StaticLabels: map[string]string{"env": "dev"}, + }, { + Name: "mysql-prod", + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + StaticLabels: map[string]string{"env": "prod"}, }, { Name: "mysql-autouser", Protocol: defaults.ProtocolMySQL, @@ -130,25 +188,30 @@ func testDatabaseLogin(t *testing.T) { Name: "teleport-admin", }, }, { - Name: "cassandra", - Protocol: defaults.ProtocolCassandra, - URI: "localhost:9042", + Name: "cassandra", + Protocol: defaults.ProtocolCassandra, + URI: "localhost:9042", + StaticLabels: map[string]string{"env": "dev"}, }, { - Name: "snowflake", - Protocol: defaults.ProtocolSnowflake, - URI: "localhost.snowflakecomputing.com", + Name: "snowflake", + Protocol: defaults.ProtocolSnowflake, + URI: "localhost.snowflakecomputing.com", + StaticLabels: map[string]string{"env": "dev"}, }, { - Name: "mongo", - Protocol: defaults.ProtocolMongoDB, - URI: "localhost:27017", + Name: "mongo", + Protocol: defaults.ProtocolMongoDB, + URI: "localhost:27017", + StaticLabels: map[string]string{"env": "dev"}, }, { - Name: "mssql", - Protocol: defaults.ProtocolSQLServer, - URI: "localhost:1433", + Name: "mssql", + Protocol: defaults.ProtocolSQLServer, + URI: "localhost:1433", + StaticLabels: map[string]string{"env": "dev"}, }, { - Name: "dynamodb", - Protocol: defaults.ProtocolDynamoDB, - URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. + Name: "dynamodb", + Protocol: defaults.ProtocolDynamoDB, + URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. + StaticLabels: map[string]string{"env": "dev"}, AWS: servicecfg.DatabaseAWS{ AccountID: "123456789012", ExternalID: "123123123", @@ -173,11 +236,13 @@ func testDatabaseLogin(t *testing.T) { // extraLoginOptions is a list of extra options used for login like // `--db-roles`. extraLoginOptions []string + setAccessRequestState types.RequestState expectActiveRoute tlsca.RouteToDatabase expectCertsLen int expectKeysLen int expectErrForConfigCmd bool expectErrForEnvCmd bool + expectLoginErrorIs func(error) bool }{ { name: "mongo", @@ -291,7 +356,7 @@ func testDatabaseLogin(t *testing.T) { { name: "by query", databaseName: "postgres-rds-us-west-1-123456789012", - dbSelectors: []string{"--query", `labels.env=="prod" && labels.region == "us-west-1"`}, + dbSelectors: []string{"--query", `labels.env=="dev" && labels.region == "us-west-1"`}, expectActiveRoute: tlsca.RouteToDatabase{ ServiceName: "postgres-rds-us-west-1-123456789012", Protocol: "postgres", @@ -312,6 +377,30 @@ func testDatabaseLogin(t *testing.T) { }, expectCertsLen: 1, }, + { + name: "database not found", + databaseName: "db-not-found", + expectLoginErrorIs: trace.IsNotFound, + }, + { + name: "access request approved", + databaseName: "mysql-prod", + setAccessRequestState: types.RequestState_APPROVED, + expectActiveRoute: tlsca.RouteToDatabase{ + ServiceName: "mysql-prod", + Protocol: "mysql", + Username: "admin", + }, + expectCertsLen: 1, + expectErrForConfigCmd: false, + expectErrForEnvCmd: false, + }, + { + name: "access request denied", + databaseName: "mysql-prod", + setAccessRequestState: types.RequestState_DENIED, + expectLoginErrorIs: trace.IsAccessDenied, + }, } // Note: keystore currently races when multiple tsh clients work in the @@ -346,7 +435,19 @@ func testDatabaseLogin(t *testing.T) { "db", "login", }, selectors...) args = append(args, test.extraLoginOptions...) + + // Access request setup. + if test.setAccessRequestState != types.RequestState_NONE { + args = append(args, "--request-reason", test.name) + go updateAccessRequestForDB(t, s, test.name, test.databaseName, test.setAccessRequestState) + } + err := Run(context.Background(), args, cliOpts...) + if test.expectLoginErrorIs != nil { + require.Error(t, err) + require.True(t, test.expectLoginErrorIs(err)) + return + } require.NoError(t, err) // Fetch the active profile. @@ -408,6 +509,37 @@ func testDatabaseLogin(t *testing.T) { } } +func updateAccessRequestForDB(t *testing.T, s *suite, wantRequestReason, wantDBName string, updateState types.RequestState) { + var accessRequestID string + require.Eventually(t, func() bool { + filter := types.AccessRequestFilter{State: types.RequestState_PENDING} + accessRequests, err := s.root.GetAuthServer().GetAccessRequests(context.Background(), filter) + if err != nil { + return false + } + + for _, accessRequest := range accessRequests { + if accessRequest.GetRequestReason() != wantRequestReason { + continue + } + for _, resourceID := range accessRequest.GetRequestedResourceIDs() { + if resourceID.Kind == types.KindDatabase && + resourceID.Name == wantDBName { + accessRequestID = accessRequest.GetName() + return true + } + } + } + return false + }, 10*time.Second, 500*time.Millisecond, "waiting for access request") + + err := s.root.GetAuthServer().SetAccessRequestState(context.Background(), types.AccessRequestUpdate{ + RequestID: accessRequestID, + State: updateState, + }) + require.NoError(t, err) +} + func TestLocalProxyRequirement(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1758,3 +1890,77 @@ func testDatabaseSelection(t *testing.T) { } }) } + +func Test_shouldRetryGetDatabaseUsingSearchAsRoles(t *testing.T) { + tests := []struct { + name string + cf *CLIConf + tc *client.TeleportClient + inputError error + checkOutput require.BoolAssertionFunc + }{ + { + name: "tsh db connect", + cf: &CLIConf{ + command: "db connect", + }, + tc: &client.TeleportClient{}, + inputError: trace.NotFound("not found"), + checkOutput: require.True, + }, + { + name: "tsh db login", + cf: &CLIConf{ + command: "db connect", + }, + tc: &client.TeleportClient{}, + inputError: trace.NotFound("not found"), + checkOutput: require.True, + }, + { + name: "tsh proxy db", + cf: &CLIConf{ + command: "db connect", + }, + tc: &client.TeleportClient{}, + inputError: trace.NotFound("not found"), + checkOutput: require.True, + }, + { + name: "not NotFound error", + cf: &CLIConf{ + command: "db connect", + }, + tc: &client.TeleportClient{}, + inputError: trace.ConnectionProblem(fmt.Errorf("timed out"), "timed out"), + checkOutput: require.False, + }, + { + name: "not supported command", + cf: &CLIConf{ + command: "db env", + }, + tc: &client.TeleportClient{}, + inputError: trace.NotFound("not found"), + checkOutput: require.False, + }, + { + name: "access request disabled", + cf: &CLIConf{ + command: "db connect", + disableAccessRequest: true, + }, + tc: &client.TeleportClient{}, + inputError: trace.NotFound("not found"), + checkOutput: require.False, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + test.checkOutput(t, shouldRetryGetDatabaseUsingSearchAsRoles(test.cf, test.tc, test.inputError)) + }) + } +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 024149d15a5e9..e0f18aa0e7da7 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -841,6 +841,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { proxyDB.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) proxyDB.Flag("labels", labelHelp).StringVar(&cf.Labels) proxyDB.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + proxyDB.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + proxyDB.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) proxyApp := proxy.Command("app", "Start local TLS proxy for app connection when using Teleport in single-port mode.") proxyApp.Arg("app", "The name of the application to start local proxy for").Required().StringVar(&cf.AppName) @@ -884,6 +886,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbLogin.Flag("db-user", "Database user to configure as default.").Short('u').StringVar(&cf.DatabaseUser) dbLogin.Flag("db-name", "Database name to configure as default.").Short('n').StringVar(&cf.DatabaseName) dbLogin.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles) + dbLogin.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + dbLogin.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) dbLogout := db.Command("logout", "Remove database credentials.") dbLogout.Arg("db", "Database to remove credentials for.").StringVar(&cf.DatabaseService) dbLogout.Flag("labels", labelHelp).StringVar(&cf.Labels) @@ -910,6 +914,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles) dbConnect.Flag("labels", labelHelp).StringVar(&cf.Labels) dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) // join join := app.Command("join", "Join the active SSH or Kubernetes session.") @@ -3267,12 +3273,25 @@ func retryWithAccessRequest( log.Debugf("Not attempting to automatically request access, reason: %v", err) return trace.Wrap(origErr) } - cf.RequestID = req.GetName() // Print and log the original AccessDenied error. fmt.Fprintln(os.Stderr, utils.UserMessageFromError(origErr)) fmt.Fprintf(os.Stdout, "You do not currently have access to %q, attempting to request access.\n\n", resource) + if err := setAccessRequestReason(cf, req); err != nil { + return trace.Wrap(err) + } + if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil { + return trace.Wrap(err) + } + + // Retry now that request has been approved and certs updated. + // Clear the original exit status. + tc.ExitStatus = 0 + return trace.Wrap(fn()) +} + +func setAccessRequestReason(cf *CLIConf, req types.AccessRequest) (err error) { requestReason := cf.RequestReason if requestReason == "" { // Prompt for a request reason. @@ -3282,7 +3301,11 @@ func retryWithAccessRequest( } } req.SetRequestReason(requestReason) + return nil +} +func sendAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) (err error) { + cf.RequestID = req.GetName() fmt.Fprint(os.Stdout, "Creating request...\n") // Always create access request against the root cluster. if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { @@ -3313,11 +3336,7 @@ func retryWithAccessRequest( if err := onRequestResolution(cf, tc, resolvedReq); err != nil { return trace.Wrap(err) } - - // Retry now that request has been approved and certs updated. - // Clear the original exit status. - tc.ExitStatus = 0 - return trace.Wrap(fn()) + return nil } func onSSHLatency(cf *CLIConf) error { @@ -4583,6 +4602,9 @@ func onRequestResolution(cf *CLIConf, tc *client.TeleportClient, req types.Acces if reason := req.GetResolveReason(); reason != "" { msg = fmt.Sprintf("%s, reason=%q", msg, reason) } + if req.GetState().IsDenied() { + return trace.AccessDenied(msg) + } return trace.Errorf(msg) }