Skip to content

Commit f8409dd

Browse files
committed
revert google.go prior to adding my own tests and add better ones
1 parent ff9f1ad commit f8409dd

File tree

2 files changed

+131
-45
lines changed

2 files changed

+131
-45
lines changed

connector/google/google.go

+31-38
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
7171
scopes = append(scopes, "profile", "email")
7272
}
7373

74-
srv, err := createDirectoryService(c.ServiceAccountFilePath, c.AdminEmail, logger)
74+
srv, err := createDirectoryService(c.ServiceAccountFilePath, c.AdminEmail)
7575
if err != nil {
7676
cancel()
7777
return nil, fmt.Errorf("could not create directory service: %v", err)
@@ -220,7 +220,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
220220
var groups []string
221221
if s.Groups && c.adminSrv != nil {
222222
checkedGroups := make(map[string]struct{})
223-
groups, err = getGroups(c.getGroupsList, claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
223+
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
224224
if err != nil {
225225
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
226226
}
@@ -244,22 +244,15 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
244244
return identity, nil
245245
}
246246

247-
// getGroupsList returns a list of Groups from google
248-
func (c *googleConnector) getGroupsList(email string, nextPageToken string) (*admin.Groups, error) {
249-
groupsList, err := c.adminSrv.Groups.List().
250-
UserKey(email).PageToken(nextPageToken).Do()
251-
return groupsList, err
252-
}
253-
254247
// getGroups creates a connection to the admin directory service and lists
255248
// all groups the user is a member of
256-
// to test functionality, first parameter is the function you want to run to fetch groups
257-
func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
249+
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
258250
var userGroups []string
259251
var err error
260252
groupsList := &admin.Groups{}
261253
for {
262-
groupsList, err = getGroupsListFunc(email, groupsList.NextPageToken)
254+
groupsList, err = c.adminSrv.Groups.List().
255+
UserKey(email).PageToken(groupsList.NextPageToken).Do()
263256
if err != nil {
264257
return nil, fmt.Errorf("could not list groups: %v", err)
265258
}
@@ -278,7 +271,7 @@ func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), em
278271
}
279272

280273
// getGroups takes a user's email/alias as well as a group's email/alias
281-
transitiveGroups, err := getGroups(getGroupsListFunc, group.Email, fetchTransitiveGroupMembership, checkedGroups)
274+
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
282275
if err != nil {
283276
return nil, fmt.Errorf("could not list transitive groups: %v", err)
284277
}
@@ -294,35 +287,35 @@ func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), em
294287
return userGroups, nil
295288
}
296289

297-
// createDirectoryService sets up super user impersonation and creates an admin client for calling
298-
// the google admin api. If no serviceAccountFilePath is defined, the application default credential
299-
// is used.
300-
func createDirectoryService(serviceAccountFilePath, email string, logger log.Logger) (*admin.Service, error) {
301-
if email == "" {
302-
return nil, fmt.Errorf("directory service requires adminEmail")
290+
// createDirectoryService loads a google service account credentials file,
291+
// sets up super user impersonation and creates an admin client for calling
292+
// the google admin api
293+
func createDirectoryService(serviceAccountFilePath string, email string) (*admin.Service, error) {
294+
if serviceAccountFilePath == "" && email == "" {
295+
return nil, nil
303296
}
304-
305-
var jsonCredentials []byte
306-
var err error
307-
308-
ctx := context.Background()
309-
if serviceAccountFilePath == "" {
310-
logger.Warn("the application default credential is used since the service account file path is not used")
311-
credential, err := google.FindDefaultCredentials(ctx)
312-
if err != nil {
313-
return nil, fmt.Errorf("failed to fetch application default credentials: %w", err)
314-
}
315-
jsonCredentials = credential.JSON
316-
} else {
317-
jsonCredentials, err = os.ReadFile(serviceAccountFilePath)
318-
if err != nil {
319-
return nil, fmt.Errorf("error reading credentials from file: %v", err)
320-
}
297+
if serviceAccountFilePath == "" || email == "" {
298+
return nil, fmt.Errorf("directory service requires both serviceAccountFilePath and adminEmail")
299+
}
300+
jsonCredentials, err := os.ReadFile(serviceAccountFilePath)
301+
if err != nil {
302+
return nil, fmt.Errorf("error reading credentials from file: %v", err)
321303
}
304+
322305
config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope)
323306
if err != nil {
324-
return nil, fmt.Errorf("unable to parse credentials to config: %v", err)
307+
return nil, fmt.Errorf("unable to parse client secret file to config: %v", err)
325308
}
309+
310+
// Impersonate an admin. This is mandatory for the admin APIs.
326311
config.Subject = email
327-
return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))
312+
313+
ctx := context.Background()
314+
client := config.Client(ctx)
315+
316+
srv, err := admin.NewService(ctx, option.WithHTTPClient(client))
317+
if err != nil {
318+
return nil, fmt.Errorf("unable to create directory service %v", err)
319+
}
320+
return srv, nil
328321
}

connector/google/google_test.go

+100-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package google
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"net/http"
@@ -10,17 +11,39 @@ import (
1011

1112
"github.com/sirupsen/logrus"
1213
"github.com/stretchr/testify/assert"
14+
15+
admin "google.golang.org/api/admin/directory/v1"
16+
"google.golang.org/api/option"
17+
)
18+
19+
var (
20+
// groups_0
21+
// ┌───────┤
22+
// groups_2 groups_1
23+
// │ ├────────┐
24+
// └── user_1 user_2
25+
testGroups = map[string][]*admin.Group{
26+
27+
28+
29+
30+
31+
}
32+
callCounter = make(map[string]int)
1333
)
1434

1535
func testSetup(t *testing.T) *httptest.Server {
1636
mux := http.NewServeMux()
17-
// TODO: mock calls
18-
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
19-
// w.Header().Add("Content-Type", "application/json")
20-
// json.NewEncoder(w).Encode(&admin.Groups{
21-
// Groups: []*admin.Group{},
22-
// })
23-
// })
37+
38+
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
39+
w.Header().Add("Content-Type", "application/json")
40+
userKey := r.URL.Query().Get("userKey")
41+
if groups, ok := testGroups[userKey]; ok {
42+
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
43+
callCounter[userKey]++
44+
}
45+
})
46+
2447
return httptest.NewServer(mux)
2548
}
2649

@@ -143,3 +166,73 @@ func TestOpen(t *testing.T) {
143166
})
144167
}
145168
}
169+
170+
func TestGetGroups(t *testing.T) {
171+
ts := testSetup(t)
172+
defer ts.Close()
173+
174+
serviceAccountFilePath, err := tempServiceAccountKey()
175+
assert.Nil(t, err)
176+
177+
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
178+
conn, err := newConnector(&Config{
179+
ClientID: "testClient",
180+
ClientSecret: "testSecret",
181+
RedirectURI: ts.URL + "/callback",
182+
Scopes: []string{"openid", "groups"},
183+
AdminEmail: "[email protected]",
184+
}, ts.URL)
185+
assert.Nil(t, err)
186+
187+
conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
188+
assert.Nil(t, err)
189+
type testCase struct {
190+
userKey string
191+
fetchTransitiveGroupMembership bool
192+
shouldErr bool
193+
expectedGroups []string
194+
}
195+
196+
for name, testCase := range map[string]testCase{
197+
"user1_non_transitive_lookup": {
198+
userKey: "[email protected]",
199+
fetchTransitiveGroupMembership: false,
200+
shouldErr: false,
201+
expectedGroups: []string{"[email protected]", "[email protected]"},
202+
},
203+
"user1_transitive_lookup": {
204+
userKey: "[email protected]",
205+
fetchTransitiveGroupMembership: true,
206+
shouldErr: false,
207+
expectedGroups: []string{"[email protected]", "[email protected]", "[email protected]"},
208+
},
209+
"user2_non_transitive_lookup": {
210+
userKey: "[email protected]",
211+
fetchTransitiveGroupMembership: false,
212+
shouldErr: false,
213+
expectedGroups: []string{"[email protected]"},
214+
},
215+
"user2_transitive_lookup": {
216+
userKey: "[email protected]",
217+
fetchTransitiveGroupMembership: true,
218+
shouldErr: false,
219+
expectedGroups: []string{"[email protected]", "[email protected]"},
220+
},
221+
} {
222+
testCase := testCase
223+
callCounter = map[string]int{}
224+
t.Run(name, func(t *testing.T) {
225+
assert := assert.New(t)
226+
lookup := make(map[string]struct{})
227+
228+
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
229+
if testCase.shouldErr {
230+
assert.NotNil(err)
231+
} else {
232+
assert.Nil(err)
233+
}
234+
assert.ElementsMatch(testCase.expectedGroups, groups)
235+
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
236+
})
237+
}
238+
}

0 commit comments

Comments
 (0)