diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index 70298a9c3f0..1acd04764df 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -36,14 +36,21 @@ var opts = totp.ValidateOpts{ Algorithm: otp.AlgorithmSHA1, } +func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string { + if useCompatibility { + return common.GetApiUrl(c.Request) + "/api/auth/" + method + } else { + return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method + } +} + func SSOLoginRedirect(c *gin.Context) { method := c.Query("method") - usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) + useCompatibility := setting.GetBool(conf.SSOCompatibilityMode) enabled := setting.GetBool(conf.SSOLoginEnabled) clientId := setting.GetStr(conf.SSOClientId) platform := setting.GetStr(conf.SSOLoginPlatform) - var r_url string - var redirect_uri string + var rUrl string if !enabled { common.ErrorStrResp(c, "Single sign-on is not enabled", 403) return @@ -53,37 +60,33 @@ func SSOLoginRedirect(c *gin.Context) { common.ErrorStrResp(c, "no method provided", 400) return } - if usecompatibility { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method - } else { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method - } + redirectUri := ssoRedirectUri(c, useCompatibility, method) urlValues.Add("response_type", "code") - urlValues.Add("redirect_uri", redirect_uri) + urlValues.Add("redirect_uri", redirectUri) urlValues.Add("client_id", clientId) switch platform { case "Github": - r_url = "https://github.com/login/oauth/authorize?" + rUrl = "https://github.com/login/oauth/authorize?" urlValues.Add("scope", "read:user") case "Microsoft": - r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" + rUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" urlValues.Add("scope", "user.read") urlValues.Add("response_mode", "query") case "Google": - r_url = "https://accounts.google.com/o/oauth2/v2/auth?" + rUrl = "https://accounts.google.com/o/oauth2/v2/auth?" urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") case "Dingtalk": - r_url = "https://login.dingtalk.com/oauth2/auth?" + rUrl = "https://login.dingtalk.com/oauth2/auth?" urlValues.Add("scope", "openid") urlValues.Add("prompt", "consent") urlValues.Add("response_type", "code") case "Casdoor": endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") - r_url = endpoint + "/login/oauth/authorize?" + rUrl = endpoint + "/login/oauth/authorize?" urlValues.Add("scope", "profile") urlValues.Add("state", endpoint) case "OIDC": - oauth2Config, err := GetOIDCClient(c) + oauth2Config, err := GetOIDCClient(c, useCompatibility, redirectUri, method) if err != nil { common.ErrorStrResp(c, err.Error(), 400) return @@ -100,22 +103,14 @@ func SSOLoginRedirect(c *gin.Context) { common.ErrorStrResp(c, "invalid platform", 400) return } - c.Redirect(302, r_url+urlValues.Encode()) + c.Redirect(302, rUrl+urlValues.Encode()) } var ssoClient = resty.New().SetRetryCount(3) -func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { - var redirect_uri string - usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) - argument := c.Query("method") - if usecompatibility { - argument = path.Base(c.Request.URL.Path) - } - if usecompatibility { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument - } else { - redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument +func GetOIDCClient(c *gin.Context, useCompatibility bool, redirectUri, method string) (*oauth2.Config, error) { + if redirectUri == "" { + redirectUri = ssoRedirectUri(c, useCompatibility, method) } endpoint := setting.GetStr(conf.SSOEndpointName) provider, err := oidc.NewProvider(c, endpoint) @@ -127,7 +122,7 @@ func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { return &oauth2.Config{ ClientID: clientId, ClientSecret: clientSecret, - RedirectURL: redirect_uri, + RedirectURL: redirectUri, // Discovery returns the OAuth2 endpoints. Endpoint: provider.Endpoint(), @@ -181,9 +176,9 @@ func parseJWT(p string) ([]byte, error) { func OIDCLoginCallback(c *gin.Context) { useCompatibility := setting.GetBool(conf.SSOCompatibilityMode) - argument := c.Query("method") + method := c.Query("method") if useCompatibility { - argument = path.Base(c.Request.URL.Path) + method = path.Base(c.Request.URL.Path) } clientId := setting.GetStr(conf.SSOClientId) endpoint := setting.GetStr(conf.SSOEndpointName) @@ -192,7 +187,7 @@ func OIDCLoginCallback(c *gin.Context) { common.ErrorResp(c, err, 400) return } - oauth2Config, err := GetOIDCClient(c) + oauth2Config, err := GetOIDCClient(c, useCompatibility, "", method) if err != nil { common.ErrorResp(c, err, 400) return @@ -236,7 +231,7 @@ func OIDCLoginCallback(c *gin.Context) { common.ErrorStrResp(c, "cannot get username from OIDC provider", 400) return } - if argument == "get_sso_id" { + if method == "get_sso_id" { if useCompatibility { c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) return @@ -252,7 +247,7 @@ func OIDCLoginCallback(c *gin.Context) { c.Data(200, "text/html; charset=utf-8", []byte(html)) return } - if argument == "sso_get_token" { + if method == "sso_get_token" { user, err := db.GetUserBySSOID(userID) if err != nil { user, err = autoRegister(userID, userID, err)