From 4d7cf566ea96c3c317946d15a9e01a0fe5012f51 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sat, 22 Dec 2018 15:27:03 -0800 Subject: [PATCH] Fetch groups for GSuite SSO. Fixes #2455 This commit adds support for fetching groups for GSuite SSO logins via OIDC connector interface. If OIDC connector has a special scope: `https://www.googleapis.com/auth/admin.directory.group.readonly` teleport will fetch user's group membership and populate groups claim. --- constants.go | 9 ++++ lib/auth/oidc.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/constants.go b/constants.go index 71ee90e1bfee3..20456fdee17bc 100644 --- a/constants.go +++ b/constants.go @@ -377,6 +377,15 @@ const ( TraitInternalKubeGroupsVariable = "{{internal.kubernetes_groups}}" ) +const ( + // GSuiteIssuerURL is issuer URL used for GSuite provider + GSuiteIssuerURL = "https://accounts.google.com" + // GSuiteGroupsEndpoint is gsuite API endpoint + GSuiteGroupsEndpoint = "https://www.googleapis.com/admin/directory/v1/groups" + // GSuiteGroupsScope is a scope to get access to admin groups API + GSuiteGroupsScope = "https://www.googleapis.com/auth/admin.directory.group.readonly" +) + // SCP is Secure Copy. const SCP = "scp" diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index f137897d3254b..8d5822891e1b5 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -19,6 +19,7 @@ package auth import ( "encoding/json" "fmt" + "io/ioutil" "net/http" "net/url" "time" @@ -175,7 +176,7 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, } // extract claims from both the id token and the userinfo endpoint and merge them - claims, err := a.getClaims(oidcClient, connector.GetIssuerURL(), code) + claims, err := a.getClaims(oidcClient, connector.GetIssuerURL(), connector.GetScope(), code) if err != nil { return nil, trace.OAuth2( oauth2.ErrorUnsupportedResponseType, "unable to construct claims", q) @@ -476,6 +477,111 @@ func claimsFromUserInfo(oidcClient *oidc.Client, issuerURL string, accessToken s return claims, nil } +func (a *AuthServer) claimsFromGSuite(oidcClient *oidc.Client, issuerURL string, userEmail string, accessToken string) (jose.Claims, error) { + err := isHTTPS(issuerURL) + if err != nil { + return nil, trace.Wrap(err) + } + + oac, err := oidcClient.OAuthClient() + if err != nil { + return nil, trace.Wrap(err) + } + hc := oac.HttpClient() + + u, err := url.Parse(teleport.GSuiteGroupsEndpoint) + if err != nil { + return nil, trace.Wrap(err) + } + + fetchGroups := func(pageToken string) (*gsuiteGroups, error) { + q := u.Query() + q.Set("userKey", userEmail) + if pageToken != "" { + q.Set("pageToken", pageToken) + } + u.RawQuery = q.Encode() + endpoint := u.String() + + log.Debugf("Fetching OIDC claims from GSuite groups endpoint: %q.", endpoint) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, trace.Wrap(err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + + resp, err := hc.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, trace.Wrap(err) + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, trace.AccessDenied("bad status code: %v %v", resp.StatusCode, string(bytes)) + } + + var response gsuiteGroups + if err := json.Unmarshal(bytes, &response); err != nil { + return nil, trace.BadParameter("failed to parse response: %v", err) + } + return &response, nil + } + + count := 0 + var groups []string + var nextPageToken string +collect: + for { + if count > MaxPages { + warningMessage := "Truncating list of teams used to populate claims: " + + "hit maximum number pages that can be fetched from GitHub." + + // Print warning to Teleport logs as well as the Audit Log. + log.Warnf(warningMessage) + a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{ + events.LoginMethod: events.LoginMethodOIDC, + events.AuthAttemptMessage: warningMessage, + }) + break collect + } + response, err := fetchGroups(nextPageToken) + if err != nil { + return nil, trace.Wrap(err) + } + groups = append(groups, response.groups()...) + if response.NextPageToken == "" { + break collect + } + count++ + nextPageToken = response.NextPageToken + } + + return jose.Claims{"groups": groups}, nil +} + +type gsuiteGroups struct { + NextPageToken string `json:"nextPageToken"` + Groups []gsuiteGroup `json:"groups"` +} + +func (g gsuiteGroups) groups() []string { + groups := make([]string, len(g.Groups)) + for i, group := range g.Groups { + groups[i] = group.Email + } + return groups +} + +type gsuiteGroup struct { + Email string `json:"email"` +} + // mergeClaims merges b into a. func mergeClaims(a jose.Claims, b jose.Claims) (jose.Claims, error) { for k, v := range b { @@ -489,7 +595,7 @@ func mergeClaims(a jose.Claims, b jose.Claims) (jose.Claims, error) { } // getClaims gets claims from ID token and UserInfo and returns UserInfo claims merged into ID token claims. -func (a *AuthServer) getClaims(oidcClient *oidc.Client, issuerURL string, code string) (jose.Claims, error) { +func (a *AuthServer) getClaims(oidcClient *oidc.Client, issuerURL string, scope []string, code string) (jose.Claims, error) { var err error oac, err := oidcClient.OAuthClient() @@ -545,6 +651,30 @@ func (a *AuthServer) getClaims(oidcClient *oidc.Client, issuerURL string, code s return nil, trace.Wrap(err) } + // for GSuite users, fetch extra data from the proprietary google API + // only if scope includes admin groups readonly scope + if issuerURL == teleport.GSuiteIssuerURL && utils.SliceContainsStr(scope, teleport.GSuiteGroupsScope) { + email, _, err := claims.StringClaim("email") + if err != nil { + return nil, trace.Wrap(err) + } + gsuiteClaims, err := a.claimsFromGSuite(oidcClient, issuerURL, email, t.AccessToken) + if err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + log.Debugf("Found no GSuite claims.") + } else { + if gsuiteClaims != nil { + log.Debugf("Got GSuiteclaims: %v.", gsuiteClaims) + } + claims, err = mergeClaims(claims, gsuiteClaims) + if err != nil { + return nil, trace.Wrap(err) + } + } + } + return claims, nil }