Skip to content

Commit

Permalink
change arg type for WithCustomAuthPlugins
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhan1 committed Jun 26, 2024
1 parent be73f44 commit 2ca83b7
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 47 deletions.
10 changes: 5 additions & 5 deletions pkg/extension/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ type VerifyDynamicPrivRequest struct {
func validateAuthPlugin(m *Manifest) error {
pluginNames := make(map[string]bool)
// Validate required functions for the auth plugins
for pluginName, p := range m.authPlugins {
for _, p := range m.authPlugins {
if p.Name == "" {
return errors.Errorf("auth plugin name cannot be empty for %s", pluginName)
return errors.Errorf("auth plugin name cannot be empty for %s", p.Name)
}
if pluginNames[p.Name] {
return errors.Errorf("auth plugin name %s has already been registered", p.Name)
Expand All @@ -116,13 +116,13 @@ func validateAuthPlugin(m *Manifest) error {
return errors.Errorf("auth plugin name %s is a reserved name for default auth plugins", p.Name)
}
if p.AuthenticateUser == nil {
return errors.Errorf("auth plugin AuthenticateUser function cannot be nil for %s", pluginName)
return errors.Errorf("auth plugin AuthenticateUser function cannot be nil for %s", p.Name)
}
if p.GenerateAuthString == nil {
return errors.Errorf("auth plugin GenerateAuthString function cannot be nil for %s", pluginName)
return errors.Errorf("auth plugin GenerateAuthString function cannot be nil for %s", p.Name)
}
if p.ValidateAuthString == nil {
return errors.Errorf("auth plugin ValidateAuthString function cannot be nil for %s", pluginName)
return errors.Errorf("auth plugin ValidateAuthString function cannot be nil for %s", p.Name)
}
}
return nil
Expand Down
36 changes: 15 additions & 21 deletions pkg/extension/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/pingcap/tidb/pkg/testkit"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

type MockAuthPlugin struct {
Expand Down Expand Up @@ -106,7 +105,6 @@ func TestAuthPlugin(t *testing.T) {
defer extension.Reset()
extension.Reset()

authChecks := map[string]*extension.AuthPlugin{}
p := new(MockAuthPlugin)
p.On("Name").Return("authentication_test_plugin")

Expand Down Expand Up @@ -161,15 +159,15 @@ func TestAuthPlugin(t *testing.T) {
})
p.On("VerifyDynamicPrivilege", sysVarAdminMatcher).Return(false)

authChecks[p.Name()] = &extension.AuthPlugin{
authChecks := []*extension.AuthPlugin{{
Name: p.Name(),
AuthenticateUser: p.AuthenticateUser,
ValidateAuthString: p.ValidateAuthString,
GenerateAuthString: p.GenerateAuthString,
VerifyPrivilege: p.VerifyPrivilege,
VerifyDynamicPrivilege: p.VerifyDynamicPrivilege,
RequiredClientSidePlugin: mysql.AuthNativePassword,
}
}}

require.NoError(t, extension.Register(
"extension_authentication_plugin",
Expand All @@ -180,7 +178,7 @@ func TestAuthPlugin(t *testing.T) {
Name: "extension_authentication_plugin",
Value: mysql.AuthNativePassword,
Type: variable.TypeEnum,
PossibleValues: maps.Keys(authChecks),
PossibleValues: []string{p.Name()},
},
}),
extension.WithBootstrap(func(_ extension.BootstrapContext) error {
Expand Down Expand Up @@ -289,7 +287,6 @@ func TestAuthPluginSwitchPlugins(t *testing.T) {
defer extension.Reset()
extension.Reset()

authChecks := map[string]*extension.AuthPlugin{}
p := new(MockAuthPlugin)
p.On("Name").Return("authentication_test_plugin")
authnMatcher1 := mock.MatchedBy(func(ctx extension.AuthenticateRequest) bool {
Expand Down Expand Up @@ -317,15 +314,15 @@ func TestAuthPluginSwitchPlugins(t *testing.T) {
})
p.On("VerifyPrivilege", insertMatcher).Return(false)

authChecks[p.Name()] = &extension.AuthPlugin{
authChecks := []*extension.AuthPlugin{{
Name: p.Name(),
AuthenticateUser: p.AuthenticateUser,
ValidateAuthString: p.ValidateAuthString,
GenerateAuthString: p.GenerateAuthString,
VerifyPrivilege: p.VerifyPrivilege,
VerifyDynamicPrivilege: p.VerifyDynamicPrivilege,
RequiredClientSidePlugin: mysql.AuthNativePassword,
}
}}

require.NoError(t, extension.Register(
"extension_authentication_plugin",
Expand All @@ -336,7 +333,7 @@ func TestAuthPluginSwitchPlugins(t *testing.T) {
Name: "extension_authentication_plugin",
Value: mysql.AuthNativePassword,
Type: variable.TypeEnum,
PossibleValues: maps.Keys(authChecks),
PossibleValues: []string{p.Name()},
},
}),
extension.WithBootstrap(func(_ extension.BootstrapContext) error {
Expand Down Expand Up @@ -423,19 +420,18 @@ func TestCreateUserWhenGrant(t *testing.T) {
defer extension.Reset()
extension.Reset()

authChecks := map[string]*extension.AuthPlugin{}
p := new(MockAuthPlugin)
p.On("Name").Return("authentication_test_plugin")
p.On("ValidateAuthString", mock.Anything).Return(true)
p.On("GenerateAuthString", "xxx").Return("encodedpassword", true)

authChecks[p.Name()] = &extension.AuthPlugin{
authChecks := []*extension.AuthPlugin{{
Name: p.Name(),
AuthenticateUser: p.AuthenticateUser,
ValidateAuthString: p.ValidateAuthString,
GenerateAuthString: p.GenerateAuthString,
RequiredClientSidePlugin: mysql.AuthNativePassword,
}
}}

require.NoError(t, extension.Register(
"extension_authentication_plugin",
Expand All @@ -446,7 +442,7 @@ func TestCreateUserWhenGrant(t *testing.T) {
Name: "extension_authentication_plugin",
Value: mysql.AuthNativePassword,
Type: variable.TypeEnum,
PossibleValues: maps.Keys(authChecks),
PossibleValues: []string{p.Name()},
},
}),
extension.WithBootstrap(func(_ extension.BootstrapContext) error {
Expand Down Expand Up @@ -482,7 +478,6 @@ func TestCreateViewWithPluginUser(t *testing.T) {
defer extension.Reset()
extension.Reset()

authChecks := map[string]*extension.AuthPlugin{}
p := new(MockAuthPlugin)
p.On("Name").Return("authentication_test_plugin")
authnMatcher1 := mock.MatchedBy(func(ctx extension.AuthenticateRequest) bool {
Expand All @@ -502,15 +497,15 @@ func TestCreateViewWithPluginUser(t *testing.T) {
})
p.On("VerifyPrivilege", createViewMatcher).Return(true)

authChecks[p.Name()] = &extension.AuthPlugin{
authChecks := []*extension.AuthPlugin{{
Name: p.Name(),
AuthenticateUser: p.AuthenticateUser,
ValidateAuthString: p.ValidateAuthString,
GenerateAuthString: p.GenerateAuthString,
VerifyPrivilege: p.VerifyPrivilege,
VerifyDynamicPrivilege: p.VerifyDynamicPrivilege,
RequiredClientSidePlugin: mysql.AuthNativePassword,
}
}}

require.NoError(t, extension.Register(
"extension_authentication_plugin",
Expand All @@ -521,7 +516,7 @@ func TestCreateViewWithPluginUser(t *testing.T) {
Name: "extension_authentication_plugin",
Value: mysql.AuthNativePassword,
Type: variable.TypeEnum,
PossibleValues: maps.Keys(authChecks),
PossibleValues: []string{p.Name()},
},
}),
))
Expand Down Expand Up @@ -587,7 +582,6 @@ func TestPluginUserModification(t *testing.T) {
defer extension.Reset()
extension.Reset()

authChecks := map[string]*extension.AuthPlugin{}
p := new(MockAuthPlugin)
p.On("Name").Return("authentication_test_plugin")
authnMatcher1 := mock.MatchedBy(func(ctx extension.AuthenticateRequest) bool {
Expand All @@ -600,15 +594,15 @@ func TestPluginUserModification(t *testing.T) {
p.On("VerifyPrivilege", mock.Anything).Return(true)
p.On("VerifyDynamicPrivilege", mock.Anything).Return(true)

authChecks[p.Name()] = &extension.AuthPlugin{
authChecks := []*extension.AuthPlugin{{
Name: p.Name(),
AuthenticateUser: p.AuthenticateUser,
ValidateAuthString: p.ValidateAuthString,
GenerateAuthString: p.GenerateAuthString,
VerifyPrivilege: p.VerifyPrivilege,
VerifyDynamicPrivilege: p.VerifyDynamicPrivilege,
RequiredClientSidePlugin: mysql.AuthNativePassword,
}
}}

require.NoError(t, extension.Register(
"extension_authentication_plugin",
Expand All @@ -619,7 +613,7 @@ func TestPluginUserModification(t *testing.T) {
Name: "extension_authentication_plugin",
Value: mysql.AuthNativePassword,
Type: variable.TypeEnum,
PossibleValues: maps.Keys(authChecks),
PossibleValues: []string{p.Name()},
},
}),
))
Expand Down
4 changes: 2 additions & 2 deletions pkg/extension/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func (es *Extensions) GetAuthPlugins() map[string]*AuthPlugin {
authPlugins := make(map[string]*AuthPlugin)
for _, m := range es.manifests {
if m.authPlugins != nil {
for pluginName, p := range m.authPlugins {
authPlugins[pluginName] = p
for _, p := range m.authPlugins {
authPlugins[p.Name] = p
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/extension/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func WithCustomAccessCheck(fn AccessCheckFunc) Option {
}

// WithCustomAuthPlugins specifies the custom authentication plugins available for the system.
func WithCustomAuthPlugins(authPlugins map[string]*AuthPlugin) Option {
func WithCustomAuthPlugins(authPlugins []*AuthPlugin) Option {
return func(m *Manifest) {
m.authPlugins = authPlugins
}
Expand Down Expand Up @@ -125,7 +125,7 @@ type Manifest struct {
bootstrap func(BootstrapContext) error
funcs []*FunctionDef
accessCheckFunc AccessCheckFunc
authPlugins map[string]*AuthPlugin
authPlugins []*AuthPlugin
sessionHandlerFactory func() *SessionHandler
close func()
}
Expand Down
30 changes: 15 additions & 15 deletions pkg/extension/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,15 @@ func TestAuthPluginValidation(t *testing.T) {
defer extension.Reset()
extension.Reset()

require.NoError(t, extension.Register("test", extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {Name: ""},
require.NoError(t, extension.Register("test", extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{Name: ""},
})))
require.ErrorContains(t, extension.Setup(), "auth plugin name cannot be empty")

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "plugin1",
ValidateAuthString: func(pwdHash string) bool {
return false
Expand All @@ -375,8 +375,8 @@ func TestAuthPluginValidation(t *testing.T) {

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "plugin1",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand All @@ -391,8 +391,8 @@ func TestAuthPluginValidation(t *testing.T) {

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "plugin1",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand All @@ -407,8 +407,8 @@ func TestAuthPluginValidation(t *testing.T) {

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "plugin1",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand All @@ -420,7 +420,7 @@ func TestAuthPluginValidation(t *testing.T) {
return true
},
},
"plugin2": {
{
Name: "plugin1",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand All @@ -438,8 +438,8 @@ func TestAuthPluginValidation(t *testing.T) {

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "mysql_native_password",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand All @@ -457,8 +457,8 @@ func TestAuthPluginValidation(t *testing.T) {

extension.Reset()
require.NoError(t, extension.Register("test",
extension.WithCustomAuthPlugins(map[string]*extension.AuthPlugin{
"plugin1": {
extension.WithCustomAuthPlugins([]*extension.AuthPlugin{
{
Name: "plugin1",
AuthenticateUser: func(ctx extension.AuthenticateRequest) error {
return nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/extension/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func newSessionExtensions(es *Extensions) *SessionExtensions {
}
if m.authPlugins != nil {
connExtensions.authPlugins = make(map[string]*AuthPlugin)
for k, v := range m.authPlugins {
connExtensions.authPlugins[k] = v
for _, p := range m.authPlugins {
connExtensions.authPlugins[p.Name] = p
}
}
}
Expand Down

0 comments on commit 2ca83b7

Please sign in to comment.