From 603681fbe6bb2620daaddc49a10019e1c5a25d71 Mon Sep 17 00:00:00 2001 From: itsHenry <2671230065@qq.com> Date: Thu, 2 Mar 2023 17:55:33 +0800 Subject: [PATCH] feat: rebuild Single sign-on system (#3649 close #3571) * rebuild single sign on system * perf: use cache * fix: codefactor check --------- Co-authored-by: Andy Hsu --- internal/bootstrap/data/setting.go | 9 +- internal/conf/const.go | 85 +++++++------- internal/db/user.go | 6 +- internal/model/setting.go | 2 +- internal/model/user.go | 2 +- server/handles/auth.go | 2 +- server/handles/githublogin.go | 101 ---------------- server/handles/ssologin.go | 181 +++++++++++++++++++++++++++++ server/router.go | 4 +- 9 files changed, 237 insertions(+), 155 deletions(-) delete mode 100644 server/handles/githublogin.go create mode 100644 server/handles/ssologin.go diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 1d0bf8b4e04..aed360e34b8 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -150,10 +150,11 @@ func InitialSettings() []model.SettingItem { {Key: conf.MaxIndexDepth, Value: "20", Type: conf.TypeNumber, Group: model.INDEX, Flag: model.PRIVATE, Help: `max depth of index`}, {Key: conf.IndexProgress, Value: "{}", Type: conf.TypeText, Group: model.SINGLE, Flag: model.PRIVATE}, - // GitHub settings - {Key: conf.GithubClientId, Value: "", Type: conf.TypeString, Group: model.GITHUB, Flag: model.PRIVATE}, - {Key: conf.GithubClientSecrets, Value: "", Type: conf.TypeString, Group: model.GITHUB, Flag: model.PRIVATE}, - {Key: conf.GithubLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GITHUB, Flag: model.PUBLIC}, + // SSO settings + {Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, + {Key: conf.SSOLoginplatform, Type: conf.TypeSelect, Options: "Github,Microsoft,Google,Dingtalk", Group: model.SSO, Flag: model.PUBLIC}, // qbittorrent settings {Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/conf/const.go b/internal/conf/const.go index 74fcddcb6ed..53fef61b1f0 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -1,33 +1,33 @@ package conf const ( - TypeString = "string" - TypeSelect = "select" - TypeBool = "bool" - TypeText = "text" - TypeNumber = "number" + TypeString = "string" + TypeSelect = "select" + TypeBool = "bool" + TypeText = "text" + TypeNumber = "number" ) const ( - // site - VERSION = "version" - SiteTitle = "site_title" - Announcement = "announcement" - AllowIndexed = "allow_indexed" + // site + VERSION = "version" + SiteTitle = "site_title" + Announcement = "announcement" + AllowIndexed = "allow_indexed" - Logo = "logo" - Favicon = "favicon" - MainColor = "main_color" + Logo = "logo" + Favicon = "favicon" + MainColor = "main_color" - // preview - TextTypes = "text_types" - AudioTypes = "audio_types" - VideoTypes = "video_types" - ImageTypes = "image_types" - ProxyTypes = "proxy_types" - ProxyIgnoreHeaders = "proxy_ignore_headers" - AudioAutoplay = "audio_autoplay" - VideoAutoplay = "video_autoplay" + // preview + TextTypes = "text_types" + AudioTypes = "audio_types" + VideoTypes = "video_types" + ImageTypes = "image_types" + ProxyTypes = "proxy_types" + ProxyIgnoreHeaders = "proxy_ignore_headers" + AudioAutoplay = "audio_autoplay" + VideoAutoplay = "video_autoplay" // global HideFiles = "hide_files" @@ -39,36 +39,37 @@ const ( OcrApi = "ocr_api" FilenameCharMapping = "filename_char_mapping" ForwardDirectLinkParams = "forward_direct_link_params" - + // index SearchIndex = "search_index" AutoUpdateIndex = "auto_update_index" IgnorePaths = "ignore_paths" MaxIndexDepth = "max_index_depth" - // aria2 - Aria2Uri = "aria2_uri" - Aria2Secret = "aria2_secret" + // aria2 + Aria2Uri = "aria2_uri" + Aria2Secret = "aria2_secret" - // single - Token = "token" - IndexProgress = "index_progress" + // single + Token = "token" + IndexProgress = "index_progress" - //Github - GithubClientId = "github_client_id" - GithubClientSecrets = "github_client_secrets" - GithubLoginEnabled = "github_login_enabled" + //SSO + SSOClientId = "sso_client_id" + SSOClientSecret = "sso_client_secret" + SSOLoginEnabled = "sso_login_enabled" + SSOLoginplatform = "sso_login_platform" - // qbittorrent - QbittorrentUrl = "qbittorrent_url" + // qbittorrent + QbittorrentUrl = "qbittorrent_url" ) const ( - UNKNOWN = iota - FOLDER - //OFFICE - VIDEO - AUDIO - TEXT - IMAGE + UNKNOWN = iota + FOLDER + //OFFICE + VIDEO + AUDIO + TEXT + IMAGE ) diff --git a/internal/db/user.go b/internal/db/user.go index 99f556eccf1..497f0905719 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -21,10 +21,10 @@ func GetUserByName(username string) (*model.User, error) { return &user, nil } -func GetUserByGithubID(githubID int) (*model.User, error) { - user := model.User{GithubID: githubID} +func GetUserBySSOID(ssoID string) (*model.User, error) { + user := model.User{SsoID: ssoID} if err := db.Where(user).First(&user).Error; err != nil { - return nil, errors.Wrapf(err, "The Github ID is not associated with a user") + return nil, errors.Wrapf(err, "The single sign on platform is not bound to any users") } return &user, nil } diff --git a/internal/model/setting.go b/internal/model/setting.go index 883a8534d70..f4202ee022c 100644 --- a/internal/model/setting.go +++ b/internal/model/setting.go @@ -8,7 +8,7 @@ const ( GLOBAL ARIA2 INDEX - GITHUB + SSO ) const ( diff --git a/internal/model/user.go b/internal/model/user.go index 7ce4f2d696f..2cde7a545ce 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -33,7 +33,7 @@ type User struct { // 10: can add qbittorrent tasks Permission int32 `json:"permission"` OtpSecret string `json:"-"` - GithubID int `json:"github_id"` + SsoID string `json:"sso_id"` } func (u User) IsGuest() bool { diff --git a/server/handles/auth.go b/server/handles/auth.go index 70a4d299f9e..9bbbf3e83a9 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -101,7 +101,7 @@ func UpdateCurrent(c *gin.Context) { if req.Password != "" { user.Password = req.Password } - user.GithubID = req.GithubID + user.SsoID = req.SsoID if err := op.UpdateUser(user); err != nil { common.ErrorResp(c, err, 500) } else { diff --git a/server/handles/githublogin.go b/server/handles/githublogin.go deleted file mode 100644 index 426be8fa16f..00000000000 --- a/server/handles/githublogin.go +++ /dev/null @@ -1,101 +0,0 @@ -package handles - -import ( - "errors" - "net/url" - "strconv" - - "github.com/alist-org/alist/v3/internal/db" - "github.com/alist-org/alist/v3/pkg/utils" - "github.com/alist-org/alist/v3/server/common" - "github.com/gin-gonic/gin" - "github.com/go-resty/resty/v2" -) - -func GithubLoginRedirect(c *gin.Context) { - method := c.Query("method") - callbackURL := c.Query("callback_url") - withParams := c.Query("with_params") - enabled, err := db.GetSettingItemByKey("github_login_enabled") - clientId, err := db.GetSettingItemByKey("github_client_id") - if err != nil { - common.ErrorResp(c, err, 400) - return - } else if enabled.Value == "true" { - urlValues := url.Values{} - urlValues.Add("client_id", clientId.Value) - if method == "get_github_id" { - urlValues.Add("allow_signup", "true") - } else if method == "github_callback_login" { - urlValues.Add("allow_signup", "false") - } - if method == "" { - common.ErrorStrResp(c, "no method provided", 400) - return - } - if withParams != "" { - urlValues.Add("redirect_uri", common.GetApiUrl(c.Request)+"/api/auth/github_callback"+"?method="+method+"&callback_url="+callbackURL+"&with_params="+withParams) - } else { - urlValues.Add("redirect_uri", common.GetApiUrl(c.Request)+"/api/auth/github_callback"+"?method="+method+"&callback_url="+callbackURL) - } - c.Redirect(302, "https://github.com/login/oauth/authorize?"+urlValues.Encode()) - } else { - common.ErrorStrResp(c, "github Login not enabled", 403) - } -} - -var githubClient = resty.New().SetRetryCount(3) - -func GithubLoginCallback(c *gin.Context) { - argument := c.Query("method") - callbackUrl := c.Query("callback_url") - if argument == "get_github_id" || argument == "github_login" { - enabled, err := db.GetSettingItemByKey("github_login_enabled") - clientId, err := db.GetSettingItemByKey("github_client_id") - clientSecret, err := db.GetSettingItemByKey("github_client_secrets") - if err != nil { - common.ErrorResp(c, err, 400) - return - } else if enabled.Value == "true" { - callbackCode := c.Query("code") - if callbackCode == "" { - common.ErrorStrResp(c, "No code provided", 400) - return - } - resp, err := githubClient.R().SetHeader("content-type", "application/json"). - SetBody(map[string]string{ - "client_id": clientId.Value, - "client_secret": clientSecret.Value, - "code": callbackCode, - "redirect_uri": common.GetApiUrl(c.Request) + "/api/auth/github_callback", - }).Post("https://github.com/login/oauth/access_token") - if err != nil { - common.ErrorResp(c, err, 400) - return - } - accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() - resp, err = githubClient.R().SetHeader("Authorization", "Bearer "+accessToken). - Get("https://api.github.com/user") - ghUserID := utils.Json.Get(resp.Body(), "id").ToInt() - if argument == "get_github_id" { - c.Redirect(302, callbackUrl+"?githubID="+strconv.Itoa(ghUserID)) - } - if argument == "github_login" { - user, err := db.GetUserByGithubID(ghUserID) - if err != nil { - common.ErrorResp(c, err, 400) - } - token, err := common.GenerateToken(user.Username) - withParams := c.Query("with_params") - if withParams == "true" { - c.Redirect(302, callbackUrl+"&token="+token) - } else if withParams == "false" { - c.Redirect(302, callbackUrl+"?token="+token) - } - return - } - } else { - common.ErrorResp(c, errors.New("invalid request"), 500) - } - } -} diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go new file mode 100644 index 00000000000..aa6adde3a72 --- /dev/null +++ b/server/handles/ssologin.go @@ -0,0 +1,181 @@ +package handles + +import ( + "errors" + "fmt" + "net/url" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/go-resty/resty/v2" +) + +func SSOLoginRedirect(c *gin.Context) { + method := c.Query("method") + enabled := setting.GetBool(conf.SSOLoginEnabled) + clientId := setting.GetStr(conf.SSOClientId) + platform := setting.GetStr(conf.SSOLoginplatform) + var r_url string + var redirect_uri string + if enabled { + urlValues := url.Values{} + if method == "" { + common.ErrorStrResp(c, "no method provided", 400) + return + } + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method + urlValues.Add("response_type", "code") + urlValues.Add("redirect_uri", redirect_uri) + urlValues.Add("client_id", clientId) + switch platform { + case "Github": + r_url = "https://github.com/login/oauth/authorize?" + urlValues.Add("scope", "read:user") + case "Microsoft": + r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" + urlValues.Add("scope", "user.read") + urlValues.Add("response_mode", "query") + case "Google": + r_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") + case "Dingtalk": + r_url = "https://login.dingtalk.com/oauth2/auth?" + urlValues.Add("scope", "openid") + urlValues.Add("prompt", "consent") + urlValues.Add("response_type", "code") + default: + common.ErrorStrResp(c, "invalid platform", 400) + return + } + c.Redirect(302, r_url+urlValues.Encode()) + } else { + common.ErrorStrResp(c, "Single sign-on is not enabled", 403) + } +} + +var ssoClient = resty.New().SetRetryCount(3) + +func SSOLoginCallback(c *gin.Context) { + argument := c.Query("method") + if argument == "get_sso_id" || argument == "sso_get_token" { + enabled := setting.GetBool(conf.SSOLoginEnabled) + clientId := setting.GetStr(conf.SSOClientId) + platform := setting.GetStr(conf.SSOLoginplatform) + clientSecret := setting.GetStr(conf.SSOClientSecret) + var url1, url2, additionalbody, scope, authstring, idstring string + switch platform { + case "Github": + url1 = "https://github.com/login/oauth/access_token" + url2 = "https://api.github.com/user" + additionalbody = "" + authstring = "code" + scope = "read:user" + idstring = "id" + case "Microsoft": + url1 = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + url2 = "https://graph.microsoft.com/v1.0/me" + additionalbody = "&grant_type=authorization_code" + scope = "user.read" + authstring = "code" + idstring = "id" + case "Google": + url1 = "https://oauth2.googleapis.com/token" + url2 = "https://www.googleapis.com/oauth2/v1/userinfo" + additionalbody = "&grant_type=authorization_code" + scope = "https://www.googleapis.com/auth/userinfo.profile" + authstring = "code" + idstring = "id" + case "Dingtalk": + url1 = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken" + url2 = "https://api.dingtalk.com/v1.0/contact/users/me" + authstring = "authCode" + idstring = "unionId" + default: + common.ErrorStrResp(c, "invalid platform", 400) + return + } + if enabled { + callbackCode := c.Query(authstring) + if callbackCode == "" { + common.ErrorStrResp(c, "No code provided", 400) + return + } + var resp *resty.Response + var err error + if platform == "Dingtalk" { + resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json"). + SetBody(map[string]string{ + "clientId": clientId, + "clientSecret": clientSecret, + "code": callbackCode, + "grantType": "authorization_code", + }). + Post(url1) + } else { + resp, err = ssoClient.R().SetHeader("content-type", "application/x-www-form-urlencoded").SetHeader("Accept", "application/json"). + SetBody("client_id=" + clientId + "&client_secret=" + clientSecret + "&code=" + callbackCode + "&redirect_uri=" + common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument + "&scope=" + scope + additionalbody). + Post(url1) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if platform == "Dingtalk" { + accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString() + resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken). + Get(url2) + } else { + accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() + resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken). + Get(url2) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + UserID := utils.Json.Get(resp.Body(), idstring).ToString() + if UserID == "0" { + common.ErrorResp(c, errors.New("error occured"), 400) + return + } + if argument == "get_sso_id" { + html := fmt.Sprintf(` + + + + `, UserID) + c.Data(200, "text/html; charset=utf-8", []byte(html)) + return + } + if argument == "sso_get_token" { + user, err := db.GetUserBySSOID(UserID) + if err != nil { + common.ErrorResp(c, err, 400) + } + token, err := common.GenerateToken(user.Username) + if err != nil { + common.ErrorResp(c, err, 400) + } + html := fmt.Sprintf(` + + + + `, token) + c.Data(200, "text/html; charset=utf-8", []byte(html)) + return + } + } else { + common.ErrorResp(c, errors.New("invalid request"), 500) + } + } +} diff --git a/server/router.go b/server/router.go index 39c2cfbd828..19f3615e4b3 100644 --- a/server/router.go +++ b/server/router.go @@ -43,8 +43,8 @@ func Init(e *gin.Engine) { auth.POST("/auth/2fa/verify", handles.Verify2FA) // github auth - api.GET("/auth/github", handles.GithubLoginRedirect) - api.GET("/auth/github_callback", handles.GithubLoginCallback) + api.GET("/auth/sso", handles.SSOLoginRedirect) + api.GET("/auth/sso_callback", handles.SSOLoginCallback) // no need auth public := api.Group("/public")