diff --git a/lib/srv/db/cassandra/engine.go b/lib/srv/db/cassandra/engine.go index f3b8de60bd3cf..e5c8583720863 100644 --- a/lib/srv/db/cassandra/engine.go +++ b/lib/srv/db/cassandra/engine.go @@ -264,11 +264,11 @@ func (e *Engine) authorizeConnection(ctx context.Context) error { } state := e.sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + }) err = e.sessionCtx.Checker.CheckAccess( e.sessionCtx.Database, state, diff --git a/lib/srv/db/clickhouse/engine.go b/lib/srv/db/clickhouse/engine.go index 9e39290c34dd3..456699254f6f0 100644 --- a/lib/srv/db/clickhouse/engine.go +++ b/lib/srv/db/clickhouse/engine.go @@ -88,11 +88,11 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er } state := sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - sessionCtx.Database, - sessionCtx.DatabaseUser, - sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + DatabaseName: sessionCtx.DatabaseName, + }) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, state, diff --git a/lib/srv/db/common/role/role.go b/lib/srv/db/common/role/role.go index 013cb641381ab..2a552ced30700 100644 --- a/lib/srv/db/common/role/role.go +++ b/lib/srv/db/common/role/role.go @@ -32,35 +32,25 @@ type RoleMatchersConfig struct { DatabaseName string // AutoCreateUser is whether database user will be auto-created. AutoCreateUser bool + // DisableDatabaseNameMatcher skips DatabaseNameMatcher even if the protocol requires it. + DisableDatabaseNameMatcher bool } // GetDatabaseRoleMatchers returns database role matchers for the provided config. func GetDatabaseRoleMatchers(conf RoleMatchersConfig) (matchers services.RoleMatchers) { // For automatic user provisioning, don't check against database users as // users will be connecting as their own Teleport username. - if conf.Database.SupportsAutoUsers() && conf.AutoCreateUser { - if m := databaseNameMatcher(conf.Database.GetProtocol(), conf.DatabaseName); m != nil { - matchers = append(matchers, m) - } - return matchers - } - return DatabaseRoleMatchers(conf.Database, conf.DatabaseUser, conf.DatabaseName) -} - -// DatabaseRoleMatchers returns role matchers based on the database. -// -// DEPRECATED: Prefer to use GetDatabaseRoleMatchers above which supports -// automatic user provisioning and has more flexible config. -func DatabaseRoleMatchers(db types.Database, user, database string) services.RoleMatchers { - roleMatchers := services.RoleMatchers{ - services.NewDatabaseUserMatcher(db, user), + disableDatabaseUserMatcher := conf.Database.SupportsAutoUsers() && conf.AutoCreateUser + if !disableDatabaseUserMatcher { + matchers = append(matchers, services.NewDatabaseUserMatcher(conf.Database, conf.DatabaseUser)) } - if matcher := databaseNameMatcher(db.GetProtocol(), database); matcher != nil { - roleMatchers = append(roleMatchers, matcher) + if !conf.DisableDatabaseNameMatcher { + if matcher := databaseNameMatcher(conf.Database.GetProtocol(), conf.DatabaseName); matcher != nil { + matchers = append(matchers, matcher) + } } - - return roleMatchers + return } // RequireDatabaseUserMatcher returns true if databases with provided protocol diff --git a/lib/srv/db/common/role/role_test.go b/lib/srv/db/common/role/role_test.go new file mode 100644 index 0000000000000..7bc58d5e9abb4 --- /dev/null +++ b/lib/srv/db/common/role/role_test.go @@ -0,0 +1,109 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package role + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" +) + +func TestGetDatabaseRoleMatchers(t *testing.T) { + postgresDatabase, err := types.NewDatabaseV3(types.Metadata{ + Name: "postgres", + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + AdminUser: &types.DatabaseAdminUser{ + Name: "teleport-admin", + }, + }) + require.NoError(t, err) + + mysqlDatabase, err := types.NewDatabaseV3(types.Metadata{ + Name: "mysql", + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + }) + require.NoError(t, err) + + require.NoError(t, err) + tests := []struct { + name string + inputConfig RoleMatchersConfig + expectRoleMatchers services.RoleMatchers + }{ + { + name: "database name matcher required", + inputConfig: RoleMatchersConfig{ + Database: postgresDatabase, + DatabaseUser: "alice", + DatabaseName: "db1", + }, + expectRoleMatchers: services.RoleMatchers{ + services.NewDatabaseUserMatcher(postgresDatabase, "alice"), + &services.DatabaseNameMatcher{Name: "db1"}, + }, + }, + { + name: "database name matcher not required", + inputConfig: RoleMatchersConfig{ + Database: mysqlDatabase, + DatabaseUser: "alice", + DatabaseName: "db1", + }, + expectRoleMatchers: services.RoleMatchers{ + services.NewDatabaseUserMatcher(postgresDatabase, "alice"), + }, + }, + { + name: "AutoCreateUser", + inputConfig: RoleMatchersConfig{ + Database: postgresDatabase, + DatabaseUser: "alice", + DatabaseName: "db1", + AutoCreateUser: true, + }, + expectRoleMatchers: services.RoleMatchers{ + &services.DatabaseNameMatcher{Name: "db1"}, + }, + }, + { + name: "DisableDatabaseNameMatcher", + inputConfig: RoleMatchersConfig{ + Database: postgresDatabase, + DatabaseUser: "alice", + DatabaseName: "db1", + DisableDatabaseNameMatcher: true, + }, + expectRoleMatchers: services.RoleMatchers{ + services.NewDatabaseUserMatcher(postgresDatabase, "alice"), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.EqualValues(t, test.expectRoleMatchers, GetDatabaseRoleMatchers(test.inputConfig)) + }) + } +} diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 4a72ca92daa73..470eea2dc36a1 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -299,11 +299,11 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er } state := sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - sessionCtx.Database, - sessionCtx.DatabaseUser, - sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + DatabaseName: sessionCtx.DatabaseName, + }) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, state, diff --git a/lib/srv/db/elasticsearch/engine.go b/lib/srv/db/elasticsearch/engine.go index a711189ce94a6..3a94b02781c12 100644 --- a/lib/srv/db/elasticsearch/engine.go +++ b/lib/srv/db/elasticsearch/engine.go @@ -260,11 +260,11 @@ func (e *Engine) authorizeConnection(ctx context.Context) error { } state := e.sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + }) err = e.sessionCtx.Checker.CheckAccess( e.sessionCtx.Database, state, diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 9d5c5bd7bd868..02075319ab8c1 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -209,13 +209,18 @@ func (e *Engine) authorizeConnection(ctx context.Context, sessionCtx *common.Ses } state := sessionCtx.GetAccessState(authPref) - // Only the username is checked upon initial connection. MongoDB sends - // database name with each protocol message (for query, update, etc.) - // so it is checked when we receive a message from client. + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + // Only the username is checked upon initial connection. MongoDB sends + // database name with each protocol message (for query, update, etc.) so it + // is checked when we receive a message from client. + DisableDatabaseNameMatcher: true, + }) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, state, - services.NewDatabaseUserMatcher(sessionCtx.Database, sessionCtx.DatabaseUser), + dbRoleMatchers..., ) if err != nil { e.Audit.OnSessionStart(e.Context, sessionCtx, err) @@ -260,13 +265,17 @@ func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol case "authenticate", "saslStart", "saslContinue", "logout": return trace.AccessDenied("access denied") } + // Otherwise authorize the command against allowed databases. - return sessionCtx.Checker.CheckAccess(sessionCtx.Database, + return sessionCtx.Checker.CheckAccess( + sessionCtx.Database, services.AccessState{MFAVerified: true}, - role.DatabaseRoleMatchers( - sessionCtx.Database, - sessionCtx.DatabaseUser, - database)...) + role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + DatabaseName: database, + })..., + ) } func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) { diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index 993e79666fe7a..34c18a5514870 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -359,11 +359,11 @@ func (e *Engine) checkAccess(ctx context.Context) error { } state := e.sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + }) err = e.sessionCtx.Checker.CheckAccess( e.sessionCtx.Database, state, diff --git a/lib/srv/db/redis/client.go b/lib/srv/db/redis/client.go index 91460a197d935..e89a99a2d6375 100644 --- a/lib/srv/db/redis/client.go +++ b/lib/srv/db/redis/client.go @@ -161,11 +161,11 @@ func fetchCredentialsOnConnect(closeCtx context.Context, sessionCtx *common.Sess return func(ctx context.Context, conn *redis.Conn) error { err := sessionCtx.Checker.CheckAccess(sessionCtx.Database, services.AccessState{MFAVerified: true}, - role.DatabaseRoleMatchers( - sessionCtx.Database, - sessionCtx.DatabaseUser, - sessionCtx.DatabaseName, - )...) + role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + DatabaseName: sessionCtx.DatabaseName, + })...) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/redis/cmds.go b/lib/srv/db/redis/cmds.go index c2ef323f5d30e..d94421534de32 100644 --- a/lib/srv/db/redis/cmds.go +++ b/lib/srv/db/redis/cmds.go @@ -175,11 +175,11 @@ func (e *Engine) processAuth(ctx context.Context, cmd *redis.Cmd) error { err := e.sessionCtx.Checker.CheckAccess(e.sessionCtx.Database, services.AccessState{MFAVerified: true}, - role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - )...) + role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + })...) if err != nil { return trace.Wrap(err) } @@ -222,11 +222,11 @@ func (e *Engine) processAuth(ctx context.Context, cmd *redis.Cmd) error { err := e.sessionCtx.Checker.CheckAccess(e.sessionCtx.Database, services.AccessState{MFAVerified: true}, - role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - )...) + role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + })...) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/redis/engine.go b/lib/srv/db/redis/engine.go index c40ee08cccf1d..3e44df26cb893 100644 --- a/lib/srv/db/redis/engine.go +++ b/lib/srv/db/redis/engine.go @@ -99,11 +99,11 @@ func (e *Engine) authorizeConnection(ctx context.Context) error { } state := e.sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + }) err = e.sessionCtx.Checker.CheckAccess( e.sessionCtx.Database, state, diff --git a/lib/srv/db/snowflake/engine.go b/lib/srv/db/snowflake/engine.go index b695b2f6471ad..6a0d5fed78d74 100644 --- a/lib/srv/db/snowflake/engine.go +++ b/lib/srv/db/snowflake/engine.go @@ -358,11 +358,11 @@ func (e *Engine) authorizeConnection(ctx context.Context) error { return trace.Wrap(err) } state := e.sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - e.sessionCtx.Database, - e.sessionCtx.DatabaseUser, - e.sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: e.sessionCtx.Database, + DatabaseUser: e.sessionCtx.DatabaseUser, + DatabaseName: e.sessionCtx.DatabaseName, + }) err = e.sessionCtx.Checker.CheckAccess( e.sessionCtx.Database, state,