Skip to content

Commit f7af8ca

Browse files
snuggie12michaelliau
authored andcommitted
Move unique functionality into getGroups to reduce calls to google (dexidp#2628)
Signed-off-by: Matt Hoey <[email protected]>
1 parent 88c97c5 commit f7af8ca

File tree

2 files changed

+118
-31
lines changed

2 files changed

+118
-31
lines changed

connector/google/google.go

+19-24
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
227227

228228
var groups []string
229229
if s.Groups && c.adminSrv != nil {
230-
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership)
230+
checkedGroups := make(map[string]struct{})
231+
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
231232
if err != nil {
232233
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
233234
}
@@ -253,7 +254,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
253254

254255
// getGroups creates a connection to the admin directory service and lists
255256
// all groups the user is a member of
256-
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool) ([]string, error) {
257+
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
257258
var userGroups []string
258259
var err error
259260
groupsList := &admin.Groups{}
@@ -265,26 +266,33 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
265266
}
266267

267268
for _, group := range groupsList.Groups {
269+
if _, exists := checkedGroups[group.Email]; exists {
270+
continue
271+
}
272+
273+
checkedGroups[group.Email] = struct{}{}
268274
// TODO (joelspeed): Make desired group key configurable
269275
userGroups = append(userGroups, group.Email)
270276

271-
// getGroups takes a user's email/alias as well as a group's email/alias
272-
if fetchTransitiveGroupMembership {
273-
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership)
274-
if err != nil {
275-
return nil, fmt.Errorf("could not list transitive groups: %v", err)
276-
}
277+
if !fetchTransitiveGroupMembership {
278+
continue
279+
}
277280

278-
userGroups = append(userGroups, transitiveGroups...)
281+
// getGroups takes a user's email/alias as well as a group's email/alias
282+
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
283+
if err != nil {
284+
return nil, fmt.Errorf("could not list transitive groups: %v", err)
279285
}
286+
287+
userGroups = append(userGroups, transitiveGroups...)
280288
}
281289

282290
if groupsList.NextPageToken == "" {
283291
break
284292
}
285293
}
286294

287-
return uniqueGroups(userGroups), nil
295+
return userGroups, nil
288296
}
289297

290298
// createDirectoryService sets up super user impersonation and creates an admin client for calling
@@ -316,7 +324,7 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log
316324
}
317325
config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope)
318326
if err != nil {
319-
return nil, fmt.Errorf("unable to parse credentials to config: %v", err)
327+
return nil, fmt.Errorf("unable to parse client secret file to config: %v", err)
320328
}
321329

322330
// Only attempt impersonation when there is a user configured
@@ -326,16 +334,3 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log
326334

327335
return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))
328336
}
329-
330-
// uniqueGroups returns the unique groups of a slice
331-
func uniqueGroups(groups []string) []string {
332-
keys := make(map[string]struct{})
333-
unique := []string{}
334-
for _, group := range groups {
335-
if _, exists := keys[group]; !exists {
336-
keys[group] = struct{}{}
337-
unique = append(unique, group)
338-
}
339-
}
340-
return unique
341-
}

connector/google/google_test.go

+99-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package google
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"net/http"
@@ -10,17 +11,38 @@ import (
1011

1112
"github.com/sirupsen/logrus"
1213
"github.com/stretchr/testify/assert"
14+
admin "google.golang.org/api/admin/directory/v1"
15+
"google.golang.org/api/option"
16+
)
17+
18+
var (
19+
// groups_0
20+
// ┌───────┤
21+
// groups_2 groups_1
22+
// │ ├────────┐
23+
// └── user_1 user_2
24+
testGroups = map[string][]*admin.Group{
25+
26+
27+
28+
29+
30+
}
31+
callCounter = make(map[string]int)
1332
)
1433

1534
func testSetup(t *testing.T) *httptest.Server {
1635
mux := http.NewServeMux()
17-
// TODO: mock calls
18-
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
19-
// w.Header().Add("Content-Type", "application/json")
20-
// json.NewEncoder(w).Encode(&admin.Groups{
21-
// Groups: []*admin.Group{},
22-
// })
23-
// })
36+
37+
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
38+
w.Header().Add("Content-Type", "application/json")
39+
userKey := r.URL.Query().Get("userKey")
40+
if groups, ok := testGroups[userKey]; ok {
41+
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
42+
callCounter[userKey]++
43+
}
44+
})
45+
2446
return httptest.NewServer(mux)
2547
}
2648

@@ -144,3 +166,73 @@ func TestOpen(t *testing.T) {
144166
})
145167
}
146168
}
169+
170+
func TestGetGroups(t *testing.T) {
171+
ts := testSetup(t)
172+
defer ts.Close()
173+
174+
serviceAccountFilePath, err := tempServiceAccountKey()
175+
assert.Nil(t, err)
176+
177+
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
178+
conn, err := newConnector(&Config{
179+
ClientID: "testClient",
180+
ClientSecret: "testSecret",
181+
RedirectURI: ts.URL + "/callback",
182+
Scopes: []string{"openid", "groups"},
183+
AdminEmail: "[email protected]",
184+
}, ts.URL)
185+
assert.Nil(t, err)
186+
187+
conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
188+
assert.Nil(t, err)
189+
type testCase struct {
190+
userKey string
191+
fetchTransitiveGroupMembership bool
192+
shouldErr bool
193+
expectedGroups []string
194+
}
195+
196+
for name, testCase := range map[string]testCase{
197+
"user1_non_transitive_lookup": {
198+
userKey: "[email protected]",
199+
fetchTransitiveGroupMembership: false,
200+
shouldErr: false,
201+
expectedGroups: []string{"[email protected]", "[email protected]"},
202+
},
203+
"user1_transitive_lookup": {
204+
userKey: "[email protected]",
205+
fetchTransitiveGroupMembership: true,
206+
shouldErr: false,
207+
expectedGroups: []string{"[email protected]", "[email protected]", "[email protected]"},
208+
},
209+
"user2_non_transitive_lookup": {
210+
userKey: "[email protected]",
211+
fetchTransitiveGroupMembership: false,
212+
shouldErr: false,
213+
expectedGroups: []string{"[email protected]"},
214+
},
215+
"user2_transitive_lookup": {
216+
userKey: "[email protected]",
217+
fetchTransitiveGroupMembership: true,
218+
shouldErr: false,
219+
expectedGroups: []string{"[email protected]", "[email protected]"},
220+
},
221+
} {
222+
testCase := testCase
223+
callCounter = map[string]int{}
224+
t.Run(name, func(t *testing.T) {
225+
assert := assert.New(t)
226+
lookup := make(map[string]struct{})
227+
228+
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
229+
if testCase.shouldErr {
230+
assert.NotNil(err)
231+
} else {
232+
assert.Nil(err)
233+
}
234+
assert.ElementsMatch(testCase.expectedGroups, groups)
235+
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
236+
})
237+
}
238+
}

0 commit comments

Comments
 (0)