1
1
/*
2
- Copyright 2017 Gravitational, Inc.
2
+ Copyright 2017-2019 Gravitational, Inc.
3
3
4
4
Licensed under the Apache License, Version 2.0 (the "License");
5
5
you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package auth
19
19
import (
20
20
"encoding/json"
21
21
"fmt"
22
+ "io/ioutil"
22
23
"net/http"
23
24
"net/url"
24
25
"time"
@@ -29,6 +30,7 @@ import (
29
30
"github.com/gravitational/teleport/lib/services"
30
31
"github.com/gravitational/teleport/lib/utils"
31
32
33
+ phttp "github.com/coreos/go-oidc/http"
32
34
"github.com/coreos/go-oidc/jose"
33
35
"github.com/coreos/go-oidc/oauth2"
34
36
"github.com/coreos/go-oidc/oidc"
@@ -175,7 +177,7 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse,
175
177
}
176
178
177
179
// extract claims from both the id token and the userinfo endpoint and merge them
178
- claims , err := a .getClaims (oidcClient , connector .GetIssuerURL (), code )
180
+ claims , err := a .getClaims (oidcClient , connector .GetIssuerURL (), connector . GetScope (), code )
179
181
if err != nil {
180
182
return nil , trace .OAuth2 (
181
183
oauth2 .ErrorUnsupportedResponseType , "unable to construct claims" , q )
@@ -476,6 +478,138 @@ func claimsFromUserInfo(oidcClient *oidc.Client, issuerURL string, accessToken s
476
478
return claims , nil
477
479
}
478
480
481
+ func (a * AuthServer ) claimsFromGSuite (oidcClient * oidc.Client , issuerURL string , userEmail string , accessToken string ) (jose.Claims , error ) {
482
+ client , err := a .newGsuiteClient (oidcClient , issuerURL , userEmail , accessToken )
483
+ if err != nil {
484
+ return nil , trace .Wrap (err )
485
+ }
486
+ return client .fetchGroups ()
487
+ }
488
+
489
+ func (a * AuthServer ) newGsuiteClient (oidcClient * oidc.Client , issuerURL string , userEmail string , accessToken string ) (* gsuiteClient , error ) {
490
+ err := isHTTPS (issuerURL )
491
+ if err != nil {
492
+ return nil , trace .Wrap (err )
493
+ }
494
+
495
+ oac , err := oidcClient .OAuthClient ()
496
+ if err != nil {
497
+ return nil , trace .Wrap (err )
498
+ }
499
+
500
+ u , err := url .Parse (teleport .GSuiteGroupsEndpoint )
501
+ if err != nil {
502
+ return nil , trace .Wrap (err )
503
+ }
504
+
505
+ return & gsuiteClient {
506
+ Client : oac .HttpClient (),
507
+ url : * u ,
508
+ userEmail : userEmail ,
509
+ accessToken : accessToken ,
510
+ auditLog : a ,
511
+ }, nil
512
+ }
513
+
514
+ type gsuiteClient struct {
515
+ phttp.Client
516
+ url url.URL
517
+ userEmail string
518
+ accessToken string
519
+ auditLog events.IAuditLog
520
+ }
521
+
522
+ // fetchGroups fetches GSuite groups a user belongs to and returns
523
+ // "groups" claim with
524
+ func (g * gsuiteClient ) fetchGroups () (jose.Claims , error ) {
525
+ count := 0
526
+ var groups []string
527
+ var nextPageToken string
528
+ collect:
529
+ for {
530
+ if count > MaxPages {
531
+ warningMessage := "Truncating list of teams used to populate claims: " +
532
+ "hit maximum number pages that can be fetched from GSuite."
533
+
534
+ // Print warning to Teleport logs as well as the Audit Log.
535
+ log .Warnf (warningMessage )
536
+ g .auditLog .EmitAuditEvent (events .UserLoginEvent , events.EventFields {
537
+ events .LoginMethod : events .LoginMethodOIDC ,
538
+ events .AuthAttemptMessage : warningMessage ,
539
+ })
540
+ break collect
541
+ }
542
+ response , err := g .fetchGroupsPage (nextPageToken )
543
+ if err != nil {
544
+ return nil , trace .Wrap (err )
545
+ }
546
+ groups = append (groups , response .groups ()... )
547
+ if response .NextPageToken == "" {
548
+ break collect
549
+ }
550
+ count ++
551
+ nextPageToken = response .NextPageToken
552
+ }
553
+ return jose.Claims {"groups" : groups }, nil
554
+ }
555
+
556
+ func (g * gsuiteClient ) fetchGroupsPage (pageToken string ) (* gsuiteGroups , error ) {
557
+ // copy URL to avoid modifying the same url
558
+ // with query parameters
559
+ u := g .url
560
+ q := u .Query ()
561
+ q .Set ("userKey" , g .userEmail )
562
+ if pageToken != "" {
563
+ q .Set ("pageToken" , pageToken )
564
+ }
565
+ u .RawQuery = q .Encode ()
566
+ endpoint := u .String ()
567
+
568
+ log .Debugf ("Fetching OIDC claims from GSuite groups endpoint: %q." , endpoint )
569
+
570
+ req , err := http .NewRequest ("GET" , endpoint , nil )
571
+ if err != nil {
572
+ return nil , trace .Wrap (err )
573
+ }
574
+ req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , g .accessToken ))
575
+
576
+ resp , err := g .Do (req )
577
+ if err != nil {
578
+ return nil , trace .Wrap (err )
579
+ }
580
+ defer resp .Body .Close ()
581
+
582
+ bytes , err := ioutil .ReadAll (resp .Body )
583
+ if err != nil {
584
+ return nil , trace .Wrap (err )
585
+ }
586
+ if resp .StatusCode < 200 || resp .StatusCode > 299 {
587
+ return nil , trace .AccessDenied ("bad status code: %v %v" , resp .StatusCode , string (bytes ))
588
+ }
589
+ var response gsuiteGroups
590
+ if err := json .Unmarshal (bytes , & response ); err != nil {
591
+ return nil , trace .BadParameter ("failed to parse response: %v" , err )
592
+ }
593
+ return & response , nil
594
+ }
595
+
596
+ type gsuiteGroups struct {
597
+ NextPageToken string `json:"nextPageToken"`
598
+ Groups []gsuiteGroup `json:"groups"`
599
+ }
600
+
601
+ func (g gsuiteGroups ) groups () []string {
602
+ groups := make ([]string , len (g .Groups ))
603
+ for i , group := range g .Groups {
604
+ groups [i ] = group .Email
605
+ }
606
+ return groups
607
+ }
608
+
609
+ type gsuiteGroup struct {
610
+ Email string `json:"email"`
611
+ }
612
+
479
613
// mergeClaims merges b into a.
480
614
func mergeClaims (a jose.Claims , b jose.Claims ) (jose.Claims , error ) {
481
615
for k , v := range b {
@@ -489,7 +623,7 @@ func mergeClaims(a jose.Claims, b jose.Claims) (jose.Claims, error) {
489
623
}
490
624
491
625
// getClaims gets claims from ID token and UserInfo and returns UserInfo claims merged into ID token claims.
492
- func (a * AuthServer ) getClaims (oidcClient * oidc.Client , issuerURL string , code string ) (jose.Claims , error ) {
626
+ func (a * AuthServer ) getClaims (oidcClient * oidc.Client , issuerURL string , scope [] string , code string ) (jose.Claims , error ) {
493
627
var err error
494
628
495
629
oac , err := oidcClient .OAuthClient ()
@@ -545,6 +679,30 @@ func (a *AuthServer) getClaims(oidcClient *oidc.Client, issuerURL string, code s
545
679
return nil , trace .Wrap (err )
546
680
}
547
681
682
+ // for GSuite users, fetch extra data from the proprietary google API
683
+ // only if scope includes admin groups readonly scope
684
+ if issuerURL == teleport .GSuiteIssuerURL && utils .SliceContainsStr (scope , teleport .GSuiteGroupsScope ) {
685
+ email , _ , err := claims .StringClaim ("email" )
686
+ if err != nil {
687
+ return nil , trace .Wrap (err )
688
+ }
689
+ gsuiteClaims , err := a .claimsFromGSuite (oidcClient , issuerURL , email , t .AccessToken )
690
+ if err != nil {
691
+ if ! trace .IsNotFound (err ) {
692
+ return nil , trace .Wrap (err )
693
+ }
694
+ log .Debugf ("Found no GSuite claims." )
695
+ } else {
696
+ if gsuiteClaims != nil {
697
+ log .Debugf ("Got GSuiteclaims: %v." , gsuiteClaims )
698
+ }
699
+ claims , err = mergeClaims (claims , gsuiteClaims )
700
+ if err != nil {
701
+ return nil , trace .Wrap (err )
702
+ }
703
+ }
704
+ }
705
+
548
706
return claims , nil
549
707
}
550
708
0 commit comments