Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions service/entityresolution/keycloak/v2/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *EntityResolutionServiceV2) ResolveEntities(ctx context.Context, req *co
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand All @@ -99,7 +99,7 @@ func (s *EntityResolutionServiceV2) CreateEntityChainsFromTokens(ctx context.Con
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger)
resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache)
return connect.NewResponse(&resp), err
}

Expand Down Expand Up @@ -137,11 +137,12 @@ func CreateEntityChainsFromTokens(
kcConfig Config,
connector *Connector,
logger *logger.Logger,
svcCache *cache.Cache,
) (entityresolutionV2.CreateEntityChainsFromTokensResponse, error) {
entityChains := []*entity.EntityChain{}
// for each token in the tokens form an entity chain
for _, tok := range req.GetTokens() {
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger)
entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger, svcCache)
if err != nil {
return entityresolutionV2.CreateEntityChainsFromTokensResponse{}, err
}
Expand All @@ -152,7 +153,7 @@ func CreateEntityChainsFromTokens(
}

func EntityResolution(ctx context.Context,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger,
req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache,
) (entityresolutionV2.ResolveEntitiesResponse, error) {
payload := req.GetEntities() // connector is now passed in

Expand All @@ -176,9 +177,7 @@ func EntityResolution(ctx context.Context,
)

clientID := ident.GetClientId()
clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
ClientID: &clientID,
})
clients, err := retrieveClients(ctx, logger, clientID, kcConfig.Realm, svcCache, connector)
if err != nil {
logger.Error("error getting client info", slog.String("error", err.Error()))
return entityresolutionV2.ResolveEntitiesResponse{},
Expand Down Expand Up @@ -229,7 +228,7 @@ func EntityResolution(ctx context.Context,
}

var jsonEntities []*structpb.Struct
users, err := connector.client.GetUsers(ctx, connector.token.AccessToken, kcConfig.Realm, getUserParams)
users, err := retrieveUsers(ctx, logger, getUserParams, kcConfig.Realm, svcCache, connector)
switch {
case err != nil:
logger.ErrorContext(ctx, "error getting users", slog.Any("error", err))
Expand All @@ -247,12 +246,7 @@ func EntityResolution(ctx context.Context,
logger.ErrorContext(ctx, "no user found", slog.Any("entity", ident))
if ident.GetEmailAddress() != "" { //nolint:nestif // this case has many possible outcomes to handle
// try by group
groups, groupErr := connector.client.GetGroups(
ctx,
connector.token.AccessToken,
kcConfig.Realm,
gocloak.GetGroupsParams{Search: func() *string { t := ident.GetEmailAddress(); return &t }()},
)
groups, groupErr := retrieveGroupsByEmail(ctx, logger, ident.GetEmailAddress(), kcConfig.Realm, svcCache, connector)
switch {
case groupErr != nil:
logger.Error("error getting group", slog.String("group", groupErr.Error()))
Expand All @@ -261,7 +255,7 @@ func EntityResolution(ctx context.Context,
case len(groups) == 1:
logger.Info("group found for", slog.String("entity", ident.String()))
group := groups[0]
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger)
expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger, svcCache)
if exErr != nil {
return entityresolutionV2.ResolveEntitiesResponse{},
connect.NewError(connect.CodeNotFound, ErrNotFound)
Expand Down Expand Up @@ -365,14 +359,13 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg
return genericMap, nil
}

func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger) ([]*gocloak.User, error) {
func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger, svcCache *cache.Cache) ([]*gocloak.User, error) {
logger.Info("expanding group", slog.String("group_id", groupID))
var entityRepresentations []*gocloak.User

grp, err := kcConnector.client.GetGroup(ctx, kcConnector.token.AccessToken, kcConfig.Realm, groupID)
grp, err := retrieveGroupByID(ctx, logger, groupID, kcConfig.Realm, svcCache, kcConnector)
if err == nil {
grpMembers, memberErr := kcConnector.client.GetGroupMembers(ctx, kcConnector.token.AccessToken, kcConfig.Realm,
*grp.ID, gocloak.GetGroupsParams{})
grpMembers, memberErr := retrieveGroupMembers(ctx, logger, *grp.ID, kcConfig.Realm, svcCache, kcConnector)
if memberErr == nil {
logger.DebugContext(ctx,
"adding members",
Expand All @@ -393,7 +386,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kc
return entityRepresentations, nil
}

func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger) ([]*entity.Entity, error) {
func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger, svcCache *cache.Cache) ([]*entity.Entity, error) {
token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false))
if err != nil {
return nil, errors.New("error parsing jwt " + err.Error())
Expand Down Expand Up @@ -432,7 +425,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne

// double check for service account
if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) {
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger)
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger, svcCache)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -461,12 +454,10 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne
return entities, nil
}

func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger) (string, error) {
func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache) (string, error) {
expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix)

clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
ClientID: &expectedClientName,
})
clients, err := retrieveClients(ctx, logger, expectedClientName, kcConfig.Realm, svcCache, connector)
switch {
case err != nil:
logger.ErrorContext(ctx, "connector client error", slog.Any("error", err))
Expand Down
34 changes: 17 additions & 17 deletions service/entityresolution/keycloak/v2/entity_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func Test_KCEntityResolutionByClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)
_ = json.NewEncoder(os.Stdout).Encode(&resp)
Expand Down Expand Up @@ -236,7 +236,7 @@ func Test_KCEntityResolutionByEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -276,7 +276,7 @@ func Test_KCEntityResolutionByUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -321,7 +321,7 @@ func Test_KCEntityResolutionByGroupEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -362,7 +362,7 @@ func Test_KCEntityResolutionNotFoundError(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand All @@ -385,7 +385,7 @@ func Test_JwtClientAndUsernameClientCredentials(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -408,7 +408,7 @@ func Test_JwtClientAndUsernamePasswordPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -431,7 +431,7 @@ func Test_JwtClientAndUsernamePasswordPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -454,7 +454,7 @@ func Test_JwtClientAndUsernameAuthPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -477,7 +477,7 @@ func Test_JwtClientAndUsernameAuthPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -500,7 +500,7 @@ func Test_JwtClientAndUsernameImplicitPub(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -523,7 +523,7 @@ func Test_JwtClientAndUsernameImplicitPriv(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -549,7 +549,7 @@ func Test_JwtClientAndClientTokenExchange(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand All @@ -575,7 +575,7 @@ func Test_JwtMultiple(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -617,7 +617,7 @@ func Test_KCEntityResolutionNotFoundInferEmail(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -651,7 +651,7 @@ func Test_KCEntityResolutionNotFoundInferClientId(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.NoError(t, reserr)

Expand Down Expand Up @@ -684,7 +684,7 @@ func Test_KCEntityResolutionNotFoundNotInferUsername(t *testing.T) {
token: &gocloak.JWT{AccessToken: "dummy_token"},
client: gocloak.NewClient(server.URL),
}
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger())
resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil)

require.Error(t, reserr)
assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp)
Expand Down
101 changes: 101 additions & 0 deletions service/entityresolution/keycloak/v2/retrieve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package keycloak

import (
"context"
"errors"
"fmt"
"log/slog"

"github.com/Nerzal/gocloak/v13"
"github.com/opentdf/platform/service/logger"
"github.com/opentdf/platform/service/pkg/cache"
)

// Cache Key formats
// Client: {realm}::client::{clientid}
// User: {realm}::user::{emailaddress or username}
// Group: {realm}::group::{emailaddress or id}
// Group members: {realm}::group::{groupid}::members

func retrieveClients(ctx context.Context, logger *logger.Logger, clientID string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.Client, error) {
cacheKey := fmt.Sprintf("%s::client::%s", realm, clientID)
retrievalFunc := func() ([]*gocloak.Client, error) {
return connector.client.GetClients(ctx, connector.token.AccessToken, realm, gocloak.GetClientsParams{
ClientID: &clientID,
})
}
clients, err := retrieveWithKey[[]*gocloak.Client](ctx, cacheKey, svcCache, logger, retrievalFunc)
if err != nil {
return nil, err
}
return clients, nil
}

func retrieveUsers(ctx context.Context, logger *logger.Logger, getUserParams gocloak.GetUsersParams, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.User, error) {
var cacheKey string
switch {
case getUserParams.Email != nil:
cacheKey = fmt.Sprintf("%s::user::%s", realm, *getUserParams.Email)
case getUserParams.Username != nil:
cacheKey = fmt.Sprintf("%s::user::%s", realm, *getUserParams.Username)
default:
return nil, errors.New("either email or username must be provided")
}

retrievalFunc := func() ([]*gocloak.User, error) {
return connector.client.GetUsers(ctx, connector.token.AccessToken, realm, getUserParams)
}
return retrieveWithKey[[]*gocloak.User](ctx, cacheKey, svcCache, logger, retrievalFunc)
}

func retrieveGroupsByEmail(ctx context.Context, logger *logger.Logger, groupEmail string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.Group, error) {
cacheKey := fmt.Sprintf("%s::group::%s", realm, groupEmail)
retrievalFunc := func() ([]*gocloak.Group, error) {
return connector.client.GetGroups(
ctx,
connector.token.AccessToken,
realm,
gocloak.GetGroupsParams{Search: func() *string { t := groupEmail; return &t }()},
)
}
return retrieveWithKey[[]*gocloak.Group](ctx, cacheKey, svcCache, logger, retrievalFunc)
}

func retrieveGroupByID(ctx context.Context, logger *logger.Logger, groupID string, realm string, svcCache *cache.Cache, connector *Connector) (*gocloak.Group, error) {
cacheKey := fmt.Sprintf("%s::group::%s", realm, groupID)
retrievalFunc := func() (*gocloak.Group, error) {
return connector.client.GetGroup(ctx, connector.token.AccessToken, realm, groupID)
}
return retrieveWithKey[*gocloak.Group](ctx, cacheKey, svcCache, logger, retrievalFunc)
}

func retrieveGroupMembers(ctx context.Context, logger *logger.Logger, groupID string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.User, error) {
cacheKey := fmt.Sprintf("%s::group::%s::members", realm, groupID)
retrievalFunc := func() ([]*gocloak.User, error) {
return connector.client.GetGroupMembers(ctx, connector.token.AccessToken, realm, groupID, gocloak.GetGroupsParams{})
}
return retrieveWithKey[[]*gocloak.User](ctx, cacheKey, svcCache, logger, retrievalFunc)
}

func retrieveWithKey[T any](ctx context.Context, cacheKey string, svcCache *cache.Cache, logger *logger.Logger, retrieveFunc func() (T, error)) (T, error) {
if svcCache != nil {
cachedData, err := svcCache.Get(ctx, cacheKey)
if err == nil {
if retrieved, ok := cachedData.(T); ok {
return retrieved, nil
}
logger.Error("cache data type assertion failed")
} else if !errors.Is(err, cache.ErrCacheMiss) {
var zero T
return zero, err
}
}
retrieved, err := retrieveFunc()
if svcCache != nil && err == nil {
cacheErr := svcCache.Set(ctx, cacheKey, retrieved, []string{})
if cacheErr != nil {
logger.Error("error setting cache", slog.String("error", cacheErr.Error()))
}
}
return retrieved, err
}
Loading
Loading