Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions service/entityresolution/keycloak/v2/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *EntityResolutionServiceV2) ResolveEntities(ctx context.Context, req *co
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand All @@ -99,7 +99,7 @@ func (s *EntityResolutionServiceV2) CreateEntityChainsFromTokens(ctx context.Con
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand Down Expand Up @@ -137,11 +137,12 @@ func CreateEntityChainsFromTokens(
kcConfig Config,
connector *Connector,
logger *logger.Logger,
svcCache *cache.Cache,
) (entityresolutionV2.CreateEntityChainsFromTokensResponse, error) {
entityChains := []*entity.EntityChain{}
// for each token in the tokens form an entity chain
for _, tok := range req.GetTokens() {
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger)
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger, svcCache)
if err != nil {
return entityresolutionV2.CreateEntityChainsFromTokensResponse{}, err
}
Expand All @@ -152,7 +153,7 @@ func CreateEntityChainsFromTokens(
}

func EntityResolution(ctx context.Context,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache,
) (entityresolutionV2.ResolveEntitiesResponse, error) {
payload := req.GetEntities() // connector is now passed in

Expand All @@ -167,9 +168,7 @@ func EntityResolution(ctx context.Context,
case *entity.Entity_ClientId:
logger.Debug("looking up", slog.Any("type", ident.GetEntityType()), slog.String("client_id", ident.GetClientId()))
clientID := ident.GetClientId()
clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
ClientID: &clientID,
})
clients, err := retrieveClients(ctx, logger, clientID, kcConfig.Realm, svcCache, connector)
if err != nil {
logger.Error("error getting client info", slog.String("error", err.Error()))
return entityresolutionV2.ResolveEntitiesResponse{},
Expand Down Expand Up @@ -220,7 +219,7 @@ func EntityResolution(ctx context.Context,
}

var jsonEntities []*structpb.Struct
users, err := connector.client.GetUsers(ctx, connector.token.AccessToken, kcConfig.Realm, getUserParams)
users, err := retrieveUsers(ctx, logger, getUserParams, kcConfig.Realm, svcCache, connector)
switch {
case err != nil:
logger.Error(err.Error())
Expand All @@ -236,12 +235,7 @@ func EntityResolution(ctx context.Context,
logger.Error("no user found for", slog.Any("entity", ident))
if ident.GetEmailAddress() != "" { //nolint:nestif // this case has many possible outcomes to handle
// try by group
groups, groupErr := connector.client.GetGroups(
ctx,
connector.token.AccessToken,
kcConfig.Realm,
gocloak.GetGroupsParams{Search: func() *string { t := ident.GetEmailAddress(); return &t }()},
)
groups, groupErr := retrieveGroupsByEmail(ctx, logger, ident.GetEmailAddress(), kcConfig.Realm, svcCache, connector)
switch {
case groupErr != nil:
logger.Error("error getting group", slog.String("group", groupErr.Error()))
Expand All @@ -250,7 +244,7 @@ func EntityResolution(ctx context.Context,
case len(groups) == 1:
logger.Info("group found for", slog.String("entity", ident.String()))
group := groups[0]
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger)
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger, svcCache)
if exErr != nil {
return entityresolutionV2.ResolveEntitiesResponse{},
connect.NewError(connect.CodeNotFound, ErrNotFound)
Expand Down Expand Up @@ -353,14 +347,13 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg
return genericMap, nil
}

func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger) ([]*gocloak.User, error) {
func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger, svcCache *cache.Cache) ([]*gocloak.User, error) {
logger.Info("expanding group", slog.String("groupID", groupID))
var entityRepresentations []*gocloak.User

grp, err := kcConnector.client.GetGroup(ctx, kcConnector.token.AccessToken, kcConfig.Realm, groupID)
grp, err := retrieveGroupByID(ctx, logger, groupID, kcConfig.Realm, svcCache, kcConnector)
if err == nil {
grpMembers, memberErr := kcConnector.client.GetGroupMembers(ctx, kcConnector.token.AccessToken, kcConfig.Realm,
*grp.ID, gocloak.GetGroupsParams{})
grpMembers, memberErr := retrieveGroupMembers(ctx, logger, *grp.ID, kcConfig.Realm, svcCache, kcConnector)
if memberErr == nil {
logger.Debug("adding members", slog.Int("amount", len(grpMembers)), slog.String("from group", *grp.Name))
for i := 0; i < len(grpMembers); i++ {
Expand All @@ -377,7 +370,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kc
return entityRepresentations, nil
}

func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger) ([]*entity.Entity, error) {
func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger, svcCache *cache.Cache) ([]*entity.Entity, error) {
token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false))
if err != nil {
return nil, errors.New("error parsing jwt " + err.Error())
Expand Down Expand Up @@ -416,7 +409,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne

// double check for service account
if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) {
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger)
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger, svcCache)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -445,12 +438,10 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne
return entities, nil
}

func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger) (string, error) {
func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache) (string, error) {
expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix)

clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
ClientID: &expectedClientName,
})
clients, err := retrieveClients(ctx, logger, expectedClientName, kcConfig.Realm, svcCache, connector)
switch {
case err != nil:
logger.Error(err.Error())
Expand Down
34 changes: 17 additions & 17 deletions service/entityresolution/keycloak/v2/entity_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func Test_KCEntityResolutionByClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)
_ = json.NewEncoder(os.Stdout).Encode(&resp)
Expand Down Expand Up @@ -236,7 +236,7 @@ func Test_KCEntityResolutionByEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -276,7 +276,7 @@ func Test_KCEntityResolutionByUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -321,7 +321,7 @@ func Test_KCEntityResolutionByGroupEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -362,7 +362,7 @@ func Test_KCEntityResolutionNotFoundError(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand All @@ -385,7 +385,7 @@ func Test_JwtClientAndUsernameClientCredentials(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -408,7 +408,7 @@ func Test_JwtClientAndUsernamePasswordPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -431,7 +431,7 @@ func Test_JwtClientAndUsernamePasswordPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -454,7 +454,7 @@ func Test_JwtClientAndUsernameAuthPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -477,7 +477,7 @@ func Test_JwtClientAndUsernameAuthPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -500,7 +500,7 @@ func Test_JwtClientAndUsernameImplicitPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -523,7 +523,7 @@ func Test_JwtClientAndUsernameImplicitPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -549,7 +549,7 @@ func Test_JwtClientAndClientTokenExchange(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -575,7 +575,7 @@ func Test_JwtMultiple(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -617,7 +617,7 @@ func Test_KCEntityResolutionNotFoundInferEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -651,7 +651,7 @@ func Test_KCEntityResolutionNotFoundInferClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -684,7 +684,7 @@ func Test_KCEntityResolutionNotFoundNotInferUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand Down
Loading
Loading