Skip to content

Commit

Permalink
*: implement auth plugin support in the extension framework (#53494)
Browse files Browse the repository at this point in the history
close #53181
  • Loading branch information
yzhan1 authored Jul 3, 2024
1 parent c71eece commit 3860ba5
Show file tree
Hide file tree
Showing 16 changed files with 1,210 additions and 17 deletions.
2 changes: 2 additions & 0 deletions pkg/executor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ go_library(
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/context",
"//pkg/extension",
"//pkg/infoschema",
"//pkg/infoschema/context",
"//pkg/keyspace",
Expand Down Expand Up @@ -388,6 +389,7 @@ go_test(
"//pkg/executor/sortexec",
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/extension",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/meta",
Expand Down
9 changes: 5 additions & 4 deletions pkg/executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,15 @@ func (e *GrantExec) Next(ctx context.Context, _ *chunk.Chunk) error {
// It is required for compatibility with 5.7 but removed from 8.0
// since it results in a massive security issue:
// spelling errors will create users with no passwords.
pwd, ok := user.EncodedPassword()
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
authPlugin := mysql.AuthNativePassword
if user.AuthOpt != nil && user.AuthOpt.AuthPlugin != "" {
authPlugin = user.AuthOpt.AuthPlugin
}
authPluginImpl, _ := e.Ctx().GetExtensions().GetAuthPlugin(authPlugin)
pwd, ok := encodePassword(user, authPluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
_, err := internalSession.GetSQLExecutor().ExecuteInternal(internalCtx,
`INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`,
mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, authPlugin)
Expand Down
38 changes: 29 additions & 9 deletions pkg/executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/pingcap/tidb/pkg/executor/internal/querywatch"
executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
Expand Down Expand Up @@ -1164,16 +1165,21 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
return err
}
}
pwd, ok := spec.EncodedPassword()

if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
var pluginImpl *extension.AuthPlugin

switch authPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthTiDBAuthToken, mysql.AuthLDAPSimple, mysql.AuthLDAPSASL:
default:
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
found := false
// If the plugin is not a registered extension auth plugin, return error
if pluginImpl, found = e.Ctx().GetExtensions().GetAuthPlugin(authPlugin); !found {
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
}

pwd, ok := encodePassword(spec, pluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}

recordTokenIssuer := tokenIssuer
Expand Down Expand Up @@ -1607,6 +1613,10 @@ func checkPasswordReusePolicy(ctx context.Context, sqlExecutor sqlexec.SQLExecut
// and the Password Reuse Policy does not take effect.
return nil
}
// Skip password reuse checks for extension auth plugins
if _, ok := sctx.GetExtensions().GetAuthPlugin(authPlugin); ok {
return nil
}
// read password reuse info from mysql.user and global variables.
passwdReuseInfo, err := getUserPasswordLimit(ctx, sqlExecutor, userDetail.user, userDetail.host, userDetail.pLI)
if err != nil {
Expand Down Expand Up @@ -1787,6 +1797,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
if spec.AuthOpt.AuthPlugin == "" {
spec.AuthOpt.AuthPlugin = currentAuthPlugin
}
var authPluginImpl *extension.AuthPlugin
switch spec.AuthOpt.AuthPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthLDAPSimple, mysql.AuthLDAPSASL, "":
authTokenOptionHandler = noNeedAuthTokenOptions
Expand All @@ -1795,7 +1806,10 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
authTokenOptionHandler = RequireAuthTokenOptions
}
default:
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
found := false
if authPluginImpl, found = e.Ctx().GetExtensions().GetAuthPlugin(spec.AuthOpt.AuthPlugin); !found {
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
}
// changing the auth method prunes history.
if spec.AuthOpt.AuthPlugin != currentAuthPlugin {
Expand All @@ -1813,7 +1827,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
return err
}
}
pwd, ok := spec.EncodedPassword()
pwd, ok := encodePassword(spec, authPluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
Expand Down Expand Up @@ -2480,7 +2494,13 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error
e.Ctx().GetSessionVars().StmtCtx.AppendNote(exeerrors.ErrSetPasswordAuthPlugin.FastGenByArgs(u, h))
pwd = ""
default:
pwd = auth.EncodePassword(s.Password)
if pluginImpl, ok := e.Ctx().GetExtensions().GetAuthPlugin(authplugin); ok {
if pwd, ok = pluginImpl.GenerateAuthString(s.Password); !ok {
return exeerrors.ErrPasswordFormat.GenWithStackByArgs()
}
} else {
pwd = auth.EncodePassword(s.Password)
}
}

// for Support Password Reuse Policy.
Expand Down
22 changes: 22 additions & 0 deletions pkg/executor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ package executor

import (
"strings"

"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/ast"
)

var (
Expand Down Expand Up @@ -97,3 +100,22 @@ func (b *batchRetrieverHelper) nextBatch(retrieveRange func(start, end int) erro
}
return nil
}

// encodePassword encodes the password for the user. It invokes the auth plugin if it is available.
func encodePassword(u *ast.UserSpec, authPlugin *extension.AuthPlugin) (string, bool) {
if u.AuthOpt == nil {
return "", true
}
// If the extension auth plugin is available, use it to encode the password.
if authPlugin != nil {
if u.AuthOpt.ByAuthString {
return authPlugin.GenerateAuthString(u.AuthOpt.AuthString)
}
// If we receive a hash string, validate it first.
if authPlugin.ValidateAuthString(u.AuthOpt.HashString) {
return u.AuthOpt.HashString, true
}
return "", false
}
return u.EncodedPassword()
}
44 changes: 44 additions & 0 deletions pkg/executor/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/planner/core"
"github.com/pingcap/tidb/pkg/types"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -133,3 +136,44 @@ func TestEqualDatumsAsBinary(t *testing.T) {
require.Equal(t, tt.same, res)
}
}

func TestEncodePasswordWithPlugin(t *testing.T) {
hashString := "*3D56A309CD04FA2EEF181462E59011F075C89548"
u := &ast.UserSpec{
User: &auth.UserIdentity{
Username: "test",
},
AuthOpt: &ast.AuthOption{
ByAuthString: false,
AuthString: "xxx",
HashString: hashString,
},
}

p := &extension.AuthPlugin{
ValidateAuthString: func(s string) bool {
return false
},
GenerateAuthString: func(s string) (string, bool) {
if s == "xxx" {
return "xxxxxxx", true
}
return "", false
},
}

u.AuthOpt.ByAuthString = false
_, ok := encodePassword(u, p)
require.False(t, ok)

u.AuthOpt.AuthString = "xxx"
u.AuthOpt.ByAuthString = true
pwd, ok := encodePassword(u, p)
require.True(t, ok)
require.Equal(t, "xxxxxxx", pwd)

u.AuthOpt = nil
pwd, ok = encodePassword(u, p)
require.True(t, ok)
require.Equal(t, "", pwd)
}
6 changes: 5 additions & 1 deletion pkg/extension/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "extension",
srcs = [
"auth.go",
"extensions.go",
"function.go",
"manifest.go",
Expand All @@ -17,6 +18,7 @@ go_library(
"//pkg/parser/ast",
"//pkg/parser/auth",
"//pkg/parser/mysql",
"//pkg/privilege/conn",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/types",
Expand All @@ -31,6 +33,7 @@ go_test(
name = "extension_test",
timeout = "short",
srcs = [
"auth_test.go",
"bootstrap_test.go",
"event_listener_test.go",
"function_test.go",
Expand All @@ -39,7 +42,7 @@ go_test(
],
embed = [":extension"],
flaky = True,
shard_count = 15,
shard_count = 21,
deps = [
"//pkg/expression",
"//pkg/parser/ast",
Expand All @@ -59,6 +62,7 @@ go_test(
"//pkg/util/mock",
"//pkg/util/sem",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//mock",
"@com_github_stretchr_testify//require",
"@org_uber_go_goleak//:goleak",
],
Expand Down
143 changes: 143 additions & 0 deletions pkg/extension/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package extension

import (
"crypto/tls"
"slices"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/privilege/conn"
)

// AuthPlugin contains attributes needed for an authentication plugin.
type AuthPlugin struct {
// Name is the name of the auth plugin. It will be registered as a system variable in TiDB which can be used inside the `CREATE USER ... IDENTIFIED WITH 'plugin_name'` statement.
Name string

// RequiredClientSidePlugin is the name of the client-side plugin required by the server-side plugin. It will be used to check if the client has the required plugin installed and require the client to use it if installed.
// The user can require default MySQL plugins such as 'caching_sha2_password' or 'mysql_native_password'.
// If this is empty then `AuthPlugin.Name` is used as the required client-side plugin.
RequiredClientSidePlugin string

// AuthenticateUser is called when a client connects to the server as a user and the server authenticates the user.
// If an error is returned, the login attempt fails, otherwise it succeeds.
// request: The request context for the authentication plugin to authenticate a user
AuthenticateUser func(request AuthenticateRequest) error

// GenerateAuthString is a function for user to implement customized ways to encode the password (e.g. hash/salt/clear-text). The returned string will be stored as the encoded password in the mysql.user table.
// If the input password is considered as invalid, this should return an error.
// pwd: User's input password in CREATE/ALTER USER statements in clear-text
GenerateAuthString func(pwd string) (string, bool)

// ValidateAuthString checks if the password hash stored in the mysql.user table or passed in from `IDENTIFIED AS` is valid.
// This is called when retrieving an existing user to make sure the password stored is valid and not modified and make sure user is passing a valid password hash in `IDENTIFIED AS`.
// pwdHash: hash of the password stored in the internal user table
ValidateAuthString func(pwdHash string) bool

// VerifyPrivilege is called for each user queries, and serves as an extra check for privileges for the user.
// It will only be executed if the user has already been granted the privilege in SQL layer.
// Returns true if user has the requested privilege.
// request: The request context for the authorization plugin to authorize a user's static privilege
VerifyPrivilege func(request VerifyStaticPrivRequest) bool

// VerifyDynamicPrivilege is called for each user queries, and serves as an extra check for dynamic privileges for the user.
// It will only be executed if the user has already been granted the dynamic privilege in SQL layer.
// Returns true if user has the requested privilege.
// request: The request context for the authorization plugin to authorize a user's dynamic privilege
VerifyDynamicPrivilege func(request VerifyDynamicPrivRequest) bool
}

// AuthenticateRequest contains the context for the authentication plugin to authenticate a user.
type AuthenticateRequest struct {
// User The username in the connect attempt
User string
// StoredAuthString The user's auth string stored in mysql.user table
StoredAuthString string
// InputAuthString The user's auth string passed in from the connection attempt in bytes
InputAuthString []byte
// Salt Randomly generated salt for the current connection
Salt []byte
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// AuthConn Interface for the plugin to communicate with the client
AuthConn conn.AuthConn
}

// VerifyStaticPrivRequest contains the context for the plugin to authorize a user's static privilege.
type VerifyStaticPrivRequest struct {
// User The username in the connect attempt
User string
// Host The host that the user is connecting from
Host string
// DB The database to check for privilege
DB string
// Table The table to check for privilege
Table string
// Column The column to check for privilege (currently just a placeholder in TiDB as column-level privilege is not supported by TiDB yet)
Column string
// StaticPriv The privilege type of the SQL statement that will be executed
StaticPriv mysql.PrivilegeType
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// ActiveRoles List of active MySQL roles for the current user
ActiveRoles []*auth.RoleIdentity
}

// VerifyDynamicPrivRequest contains the context for the plugin to authorize a user's dynamic privilege.
type VerifyDynamicPrivRequest struct {
// User The username in the connect attempt
User string
// Host The host that the user is connecting from
Host string
// DynamicPriv the dynamic privilege required by the user's SQL statement
DynamicPriv string
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// ActiveRoles List of active MySQL roles for the current user
ActiveRoles []*auth.RoleIdentity
// WithGrant Whether the statement to be executed is granting the user privilege for executing GRANT statements
WithGrant bool
}

// validateAuthPlugin validates the auth plugin functions and attributes.
func validateAuthPlugin(m *Manifest) error {
pluginNames := make(map[string]bool)
// Validate required functions for the auth plugins
for _, p := range m.authPlugins {
if p.Name == "" {
return errors.Errorf("auth plugin name cannot be empty for %s", p.Name)
}
if pluginNames[p.Name] {
return errors.Errorf("auth plugin name %s has already been registered", p.Name)
}
pluginNames[p.Name] = true
if slices.Contains(mysql.DefaultAuthPlugins, p.Name) {
return errors.Errorf("auth plugin name %s is a reserved name for default auth plugins", p.Name)
}
if p.AuthenticateUser == nil {
return errors.Errorf("auth plugin AuthenticateUser function cannot be nil for %s", p.Name)
}
if p.GenerateAuthString == nil {
return errors.Errorf("auth plugin GenerateAuthString function cannot be nil for %s", p.Name)
}
if p.ValidateAuthString == nil {
return errors.Errorf("auth plugin ValidateAuthString function cannot be nil for %s", p.Name)
}
}
return nil
}
Loading

0 comments on commit 3860ba5

Please sign in to comment.