diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index fe5b2529682..34fcd7c884b 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -213,11 +213,12 @@ func (a *AuthServerStatic) ValidateHash(salt []byte, user string, authResponse [ if matchSourceHost(remoteAddr, entry.SourceHost) && isPass { return &StaticUserData{entry.UserData}, nil } - } - computedAuthResponse := ScramblePassword(salt, []byte(entry.Password)) - // Validate the password. - if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Compare(authResponse, computedAuthResponse) == 0 { - return &StaticUserData{entry.UserData}, nil + } else { + computedAuthResponse := ScramblePassword(salt, []byte(entry.Password)) + // Validate the password. + if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Compare(authResponse, computedAuthResponse) == 0 { + return &StaticUserData{entry.UserData}, nil + } } } return &StaticUserData{""}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index d44b2e7eb82..a1fd2809d11 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -123,3 +123,72 @@ func hupTest(t *testing.T, tmpFile *os.File, oldStr, newStr string) { t.Fatalf("%s's Password should be '%s'", newStr, newStr) } } + +func TestStaticPasswords(t *testing.T) { + jsonConfig := ` +{ + "user01": [{ "Password": "user01" }], + "user02": [{ + "MysqlNativePassword": "*B3AD996B12F211BEA47A7C666CC136FB26DC96AF" + }], + "user03": [{ + "MysqlNativePassword": "*211E0153B172BAED4352D5E4628BD76731AF83E7", + "Password": "invalid" + }], + "user04": [ + { "MysqlNativePassword": "*668425423DB5193AF921380129F465A6425216D0" }, + { "Password": "password2" } + ] +}` + + tests := []struct { + user string + password string + success bool + }{ + {"user01", "user01", true}, + {"user01", "password", false}, + {"user01", "", false}, + {"user02", "user02", true}, + {"user02", "password", false}, + {"user02", "", false}, + {"user03", "user03", true}, + {"user03", "password", false}, + {"user03", "invalid", false}, + {"user03", "", false}, + {"user04", "password1", true}, + {"user04", "password2", true}, + {"user04", "", false}, + {"userXX", "", false}, + {"userXX", "", false}, + {"", "", false}, + {"", "password", false}, + } + + auth := NewAuthServerStatic() + auth.loadConfigFromParams("", jsonConfig) + ip := net.ParseIP("127.0.0.1") + addr := &net.IPAddr{IP: ip, Zone: ""} + + for _, c := range tests { + t.Run(fmt.Sprintf("%s-%s", c.user, c.password), func(t *testing.T) { + salt, err := NewSalt() + if err != nil { + t.Fatalf("error generating salt: %v", err) + } + + scrambled := ScramblePassword(salt, []byte(c.password)) + _, err = auth.ValidateHash(salt, c.user, scrambled, addr) + + if c.success { + if err != nil { + t.Fatalf("authentication should have succeeded: %v", err) + } + } else { + if err == nil { + t.Fatalf("authentication should have failed") + } + } + }) + } +}