Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions service/entityresolution/keycloak/v2/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand All @@ -99,7 +99,7 @@
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand Down Expand Up @@ -137,11 +137,12 @@
kcConfig Config,
connector *Connector,
logger *logger.Logger,
svcCache *cache.Cache,
) (entityresolutionV2.CreateEntityChainsFromTokensResponse, error) {
entityChains := []*entity.EntityChain{}
// for each token in the tokens form an entity chain
for _, tok := range req.GetTokens() {
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger)
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger, svcCache)
if err != nil {
return entityresolutionV2.CreateEntityChainsFromTokensResponse{}, err
}
Expand All @@ -152,7 +153,7 @@
}

func EntityResolution(ctx context.Context,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache,
) (entityresolutionV2.ResolveEntitiesResponse, error) {
payload := req.GetEntities() // connector is now passed in

Expand All @@ -167,9 +168,7 @@
case *entity.Entity_ClientId:
logger.Debug("looking up", slog.Any("type", ident.GetEntityType()), slog.String("client_id", ident.GetClientId()))
clientID := ident.GetClientId()
clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
ClientID: &clientID,
})
clients, err := retrieveClients(ctx, logger, clientID, kcConfig.Realm, svcCache, connector)
if err != nil {
logger.Error("error getting client info", slog.String("error", err.Error()))
return entityresolutionV2.ResolveEntitiesResponse{},
Expand Down Expand Up @@ -220,7 +219,7 @@
}

var jsonEntities []*structpb.Struct
users, err := connector.client.GetUsers(ctx, connector.token.AccessToken, kcConfig.Realm, getUserParams)
users, err := retrieveUsers(ctx, logger, getUserParams, kcConfig.Realm, svcCache, connector)
switch {
case err != nil:
logger.Error(err.Error())
Expand All @@ -236,12 +235,7 @@
logger.Error("no user found for", slog.Any("entity", ident))
if ident.GetEmailAddress() != "" { //nolint:nestif // this case has many possible outcomes to handle
// try by group
groups, groupErr := connector.client.GetGroups(
ctx,
connector.token.AccessToken,
kcConfig.Realm,
gocloak.GetGroupsParams{Search: func() *string { t := ident.GetEmailAddress(); return &t }()},
)
groups, groupErr := retrieveGroupsByEmail(ctx, logger, ident.GetEmailAddress(), kcConfig.Realm, svcCache, connector)
switch {
case groupErr != nil:
logger.Error("error getting group", slog.String("group", groupErr.Error()))
Expand All @@ -250,7 +244,7 @@
case len(groups) == 1:
logger.Info("group found for", slog.String("entity", ident.String()))
group := groups[0]
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger)
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger, svcCache)
if exErr != nil {
return entityresolutionV2.ResolveEntitiesResponse{},
connect.NewError(connect.CodeNotFound, ErrNotFound)
Expand Down Expand Up @@ -353,14 +347,13 @@
return genericMap, nil
}

func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger) ([]*gocloak.User, error) {
func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger, svcCache *cache.Cache) ([]*gocloak.User, error) {
logger.Info("expanding group", slog.String("groupID", groupID))
var entityRepresentations []*gocloak.User

grp, err := kcConnector.client.GetGroup(ctx, kcConnector.token.AccessToken, kcConfig.Realm, groupID)
grp, err := retrieveGroupByID(ctx, logger, groupID, kcConfig.Realm, svcCache, kcConnector)
if err == nil {
grpMembers, memberErr := kcConnector.client.GetGroupMembers(ctx, kcConnector.token.AccessToken, kcConfig.Realm,
*grp.ID, gocloak.GetGroupsParams{})
grpMembers, memberErr := retrieveGroupMembers(ctx, logger, *grp.ID, kcConfig.Realm, svcCache, kcConnector)
if memberErr == nil {
logger.Debug("adding members", slog.Int("amount", len(grpMembers)), slog.String("from group", *grp.Name))
for i := 0; i < len(grpMembers); i++ {
Expand All @@ -377,7 +370,7 @@
return entityRepresentations, nil
}

func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger) ([]*entity.Entity, error) {
func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger, svcCache *cache.Cache) ([]*entity.Entity, error) {
token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false))
if err != nil {
return nil, errors.New("error parsing jwt " + err.Error())
Expand Down Expand Up @@ -416,7 +409,7 @@

// double check for service account
if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) {
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger)
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger, svcCache)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -445,7 +438,7 @@
return entities, nil
}

func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger) (string, error) {
func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache) (string, error) {

Check failure on line 441 in service/entityresolution/keycloak/v2/entity_resolution.go

View workflow job for this annotation

GitHub Actions / go (service)

unused-parameter: parameter 'svcCache' seems to be unused, consider removing or renaming it as _ (revive)
expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix)

clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
Expand Down
34 changes: 17 additions & 17 deletions service/entityresolution/keycloak/v2/entity_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func Test_KCEntityResolutionByClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)
_ = json.NewEncoder(os.Stdout).Encode(&resp)
Expand Down Expand Up @@ -236,7 +236,7 @@ func Test_KCEntityResolutionByEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -276,7 +276,7 @@ func Test_KCEntityResolutionByUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -321,7 +321,7 @@ func Test_KCEntityResolutionByGroupEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -362,7 +362,7 @@ func Test_KCEntityResolutionNotFoundError(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand All @@ -385,7 +385,7 @@ func Test_JwtClientAndUsernameClientCredentials(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -408,7 +408,7 @@ func Test_JwtClientAndUsernamePasswordPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -431,7 +431,7 @@ func Test_JwtClientAndUsernamePasswordPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -454,7 +454,7 @@ func Test_JwtClientAndUsernameAuthPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -477,7 +477,7 @@ func Test_JwtClientAndUsernameAuthPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -500,7 +500,7 @@ func Test_JwtClientAndUsernameImplicitPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -523,7 +523,7 @@ func Test_JwtClientAndUsernameImplicitPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -549,7 +549,7 @@ func Test_JwtClientAndClientTokenExchange(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -575,7 +575,7 @@ func Test_JwtMultiple(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -617,7 +617,7 @@ func Test_KCEntityResolutionNotFoundInferEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -651,7 +651,7 @@ func Test_KCEntityResolutionNotFoundInferClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -684,7 +684,7 @@ func Test_KCEntityResolutionNotFoundNotInferUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand Down
Loading
Loading