diff --git a/service/entityresolution/keycloak/v2/entity_resolution.go b/service/entityresolution/keycloak/v2/entity_resolution.go index e96a449108..dc03dbbbcd 100644 --- a/service/entityresolution/keycloak/v2/entity_resolution.go +++ b/service/entityresolution/keycloak/v2/entity_resolution.go @@ -86,7 +86,7 @@ func (s *EntityResolutionServiceV2) ResolveEntities(ctx context.Context, req *co s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err)) } - resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger) + resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache) return connect.NewResponse(&resp), err } @@ -99,7 +99,7 @@ func (s *EntityResolutionServiceV2) CreateEntityChainsFromTokens(ctx context.Con s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err)) } - resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger) + resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.idpConfig, connector, s.logger, s.svcCache) return connect.NewResponse(&resp), err } @@ -137,11 +137,12 @@ func CreateEntityChainsFromTokens( kcConfig Config, connector *Connector, logger *logger.Logger, + svcCache *cache.Cache, ) (entityresolutionV2.CreateEntityChainsFromTokensResponse, error) { entityChains := []*entity.EntityChain{} // for each token in the tokens form an entity chain for _, tok := range req.GetTokens() { - entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger) + entities, err := getEntitiesFromToken(ctx, kcConfig, connector, tok.GetJwt(), logger, svcCache) if err != nil { return entityresolutionV2.CreateEntityChainsFromTokensResponse{}, err } @@ -152,7 +153,7 @@ func CreateEntityChainsFromTokens( } func EntityResolution(ctx context.Context, - req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger, + req *entityresolutionV2.ResolveEntitiesRequest, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache, ) (entityresolutionV2.ResolveEntitiesResponse, error) { payload := req.GetEntities() // connector is now passed in @@ -176,9 +177,7 @@ func EntityResolution(ctx context.Context, ) clientID := ident.GetClientId() - clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{ - ClientID: &clientID, - }) + clients, err := retrieveClients(ctx, logger, clientID, kcConfig.Realm, svcCache, connector) if err != nil { logger.Error("error getting client info", slog.String("error", err.Error())) return entityresolutionV2.ResolveEntitiesResponse{}, @@ -229,7 +228,7 @@ func EntityResolution(ctx context.Context, } var jsonEntities []*structpb.Struct - users, err := connector.client.GetUsers(ctx, connector.token.AccessToken, kcConfig.Realm, getUserParams) + users, err := retrieveUsers(ctx, logger, getUserParams, kcConfig.Realm, svcCache, connector) switch { case err != nil: logger.ErrorContext(ctx, "error getting users", slog.Any("error", err)) @@ -247,12 +246,7 @@ func EntityResolution(ctx context.Context, logger.ErrorContext(ctx, "no user found", slog.Any("entity", ident)) if ident.GetEmailAddress() != "" { //nolint:nestif // this case has many possible outcomes to handle // try by group - groups, groupErr := connector.client.GetGroups( - ctx, - connector.token.AccessToken, - kcConfig.Realm, - gocloak.GetGroupsParams{Search: func() *string { t := ident.GetEmailAddress(); return &t }()}, - ) + groups, groupErr := retrieveGroupsByEmail(ctx, logger, ident.GetEmailAddress(), kcConfig.Realm, svcCache, connector) switch { case groupErr != nil: logger.Error("error getting group", slog.String("group", groupErr.Error())) @@ -261,7 +255,7 @@ func EntityResolution(ctx context.Context, case len(groups) == 1: logger.Info("group found for", slog.String("entity", ident.String())) group := groups[0] - expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger) + expandedRepresentations, exErr := expandGroup(ctx, *group.ID, connector, &kcConfig, logger, svcCache) if exErr != nil { return entityresolutionV2.ResolveEntitiesResponse{}, connect.NewError(connect.CodeNotFound, ErrNotFound) @@ -365,14 +359,13 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg return genericMap, nil } -func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger) ([]*gocloak.User, error) { +func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kcConfig *Config, logger *logger.Logger, svcCache *cache.Cache) ([]*gocloak.User, error) { logger.Info("expanding group", slog.String("group_id", groupID)) var entityRepresentations []*gocloak.User - grp, err := kcConnector.client.GetGroup(ctx, kcConnector.token.AccessToken, kcConfig.Realm, groupID) + grp, err := retrieveGroupByID(ctx, logger, groupID, kcConfig.Realm, svcCache, kcConnector) if err == nil { - grpMembers, memberErr := kcConnector.client.GetGroupMembers(ctx, kcConnector.token.AccessToken, kcConfig.Realm, - *grp.ID, gocloak.GetGroupsParams{}) + grpMembers, memberErr := retrieveGroupMembers(ctx, logger, *grp.ID, kcConfig.Realm, svcCache, kcConnector) if memberErr == nil { logger.DebugContext(ctx, "adding members", @@ -393,7 +386,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *Connector, kc return entityRepresentations, nil } -func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger) ([]*entity.Entity, error) { +func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Connector, jwtString string, logger *logger.Logger, svcCache *cache.Cache) ([]*entity.Entity, error) { token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false)) if err != nil { return nil, errors.New("error parsing jwt " + err.Error()) @@ -432,7 +425,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne // double check for service account if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) { - clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger) + clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger, svcCache) if err != nil { return nil, err } @@ -461,12 +454,10 @@ func getEntitiesFromToken(ctx context.Context, kcConfig Config, connector *Conne return entities, nil } -func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger) (string, error) { +func getServiceAccountClient(ctx context.Context, username string, kcConfig Config, connector *Connector, logger *logger.Logger, svcCache *cache.Cache) (string, error) { expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix) - clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{ - ClientID: &expectedClientName, - }) + clients, err := retrieveClients(ctx, logger, expectedClientName, kcConfig.Realm, svcCache, connector) switch { case err != nil: logger.ErrorContext(ctx, "connector client error", slog.Any("error", err)) diff --git a/service/entityresolution/keycloak/v2/entity_resolution_test.go b/service/entityresolution/keycloak/v2/entity_resolution_test.go index 58e9261817..b4ecc07099 100644 --- a/service/entityresolution/keycloak/v2/entity_resolution_test.go +++ b/service/entityresolution/keycloak/v2/entity_resolution_test.go @@ -207,7 +207,7 @@ func Test_KCEntityResolutionByClientId(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) _ = json.NewEncoder(os.Stdout).Encode(&resp) @@ -236,7 +236,7 @@ func Test_KCEntityResolutionByEmail(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -276,7 +276,7 @@ func Test_KCEntityResolutionByUsername(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -321,7 +321,7 @@ func Test_KCEntityResolutionByGroupEmail(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -362,7 +362,7 @@ func Test_KCEntityResolutionNotFoundError(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.Error(t, reserr) assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp) @@ -385,7 +385,7 @@ func Test_JwtClientAndUsernameClientCredentials(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -408,7 +408,7 @@ func Test_JwtClientAndUsernamePasswordPub(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -431,7 +431,7 @@ func Test_JwtClientAndUsernamePasswordPriv(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -454,7 +454,7 @@ func Test_JwtClientAndUsernameAuthPub(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -477,7 +477,7 @@ func Test_JwtClientAndUsernameAuthPriv(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -500,7 +500,7 @@ func Test_JwtClientAndUsernameImplicitPub(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -523,7 +523,7 @@ func Test_JwtClientAndUsernameImplicitPriv(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -549,7 +549,7 @@ func Test_JwtClientAndClientTokenExchange(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -575,7 +575,7 @@ func Test_JwtMultiple(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := CreateEntityChainsFromTokens(t.Context(), &entityresolutionV2.CreateEntityChainsFromTokensRequest{Tokens: validBody}, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -617,7 +617,7 @@ func Test_KCEntityResolutionNotFoundInferEmail(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -651,7 +651,7 @@ func Test_KCEntityResolutionNotFoundInferClientId(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.NoError(t, reserr) @@ -684,7 +684,7 @@ func Test_KCEntityResolutionNotFoundNotInferUsername(t *testing.T) { token: &gocloak.JWT{AccessToken: "dummy_token"}, client: gocloak.NewClient(server.URL), } - resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger()) + resp, reserr := EntityResolution(t.Context(), &req, kcconfig, connector, logger.CreateTestLogger(), nil) require.Error(t, reserr) assert.Equal(t, &entityresolutionV2.ResolveEntitiesResponse{}, &resp) diff --git a/service/entityresolution/keycloak/v2/retrieve.go b/service/entityresolution/keycloak/v2/retrieve.go new file mode 100644 index 0000000000..e8890d16b5 --- /dev/null +++ b/service/entityresolution/keycloak/v2/retrieve.go @@ -0,0 +1,101 @@ +package keycloak + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/Nerzal/gocloak/v13" + "github.com/opentdf/platform/service/logger" + "github.com/opentdf/platform/service/pkg/cache" +) + +// Cache Key formats +// Client: {realm}::client::{clientid} +// User: {realm}::user::{emailaddress or username} +// Group: {realm}::group::{emailaddress or id} +// Group members: {realm}::group::{groupid}::members + +func retrieveClients(ctx context.Context, logger *logger.Logger, clientID string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.Client, error) { + cacheKey := fmt.Sprintf("%s::client::%s", realm, clientID) + retrievalFunc := func() ([]*gocloak.Client, error) { + return connector.client.GetClients(ctx, connector.token.AccessToken, realm, gocloak.GetClientsParams{ + ClientID: &clientID, + }) + } + clients, err := retrieveWithKey[[]*gocloak.Client](ctx, cacheKey, svcCache, logger, retrievalFunc) + if err != nil { + return nil, err + } + return clients, nil +} + +func retrieveUsers(ctx context.Context, logger *logger.Logger, getUserParams gocloak.GetUsersParams, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.User, error) { + var cacheKey string + switch { + case getUserParams.Email != nil: + cacheKey = fmt.Sprintf("%s::user::%s", realm, *getUserParams.Email) + case getUserParams.Username != nil: + cacheKey = fmt.Sprintf("%s::user::%s", realm, *getUserParams.Username) + default: + return nil, errors.New("either email or username must be provided") + } + + retrievalFunc := func() ([]*gocloak.User, error) { + return connector.client.GetUsers(ctx, connector.token.AccessToken, realm, getUserParams) + } + return retrieveWithKey[[]*gocloak.User](ctx, cacheKey, svcCache, logger, retrievalFunc) +} + +func retrieveGroupsByEmail(ctx context.Context, logger *logger.Logger, groupEmail string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.Group, error) { + cacheKey := fmt.Sprintf("%s::group::%s", realm, groupEmail) + retrievalFunc := func() ([]*gocloak.Group, error) { + return connector.client.GetGroups( + ctx, + connector.token.AccessToken, + realm, + gocloak.GetGroupsParams{Search: func() *string { t := groupEmail; return &t }()}, + ) + } + return retrieveWithKey[[]*gocloak.Group](ctx, cacheKey, svcCache, logger, retrievalFunc) +} + +func retrieveGroupByID(ctx context.Context, logger *logger.Logger, groupID string, realm string, svcCache *cache.Cache, connector *Connector) (*gocloak.Group, error) { + cacheKey := fmt.Sprintf("%s::group::%s", realm, groupID) + retrievalFunc := func() (*gocloak.Group, error) { + return connector.client.GetGroup(ctx, connector.token.AccessToken, realm, groupID) + } + return retrieveWithKey[*gocloak.Group](ctx, cacheKey, svcCache, logger, retrievalFunc) +} + +func retrieveGroupMembers(ctx context.Context, logger *logger.Logger, groupID string, realm string, svcCache *cache.Cache, connector *Connector) ([]*gocloak.User, error) { + cacheKey := fmt.Sprintf("%s::group::%s::members", realm, groupID) + retrievalFunc := func() ([]*gocloak.User, error) { + return connector.client.GetGroupMembers(ctx, connector.token.AccessToken, realm, groupID, gocloak.GetGroupsParams{}) + } + return retrieveWithKey[[]*gocloak.User](ctx, cacheKey, svcCache, logger, retrievalFunc) +} + +func retrieveWithKey[T any](ctx context.Context, cacheKey string, svcCache *cache.Cache, logger *logger.Logger, retrieveFunc func() (T, error)) (T, error) { + if svcCache != nil { + cachedData, err := svcCache.Get(ctx, cacheKey) + if err == nil { + if retrieved, ok := cachedData.(T); ok { + return retrieved, nil + } + logger.Error("cache data type assertion failed") + } else if !errors.Is(err, cache.ErrCacheMiss) { + var zero T + return zero, err + } + } + retrieved, err := retrieveFunc() + if svcCache != nil && err == nil { + cacheErr := svcCache.Set(ctx, cacheKey, retrieved, []string{}) + if cacheErr != nil { + logger.Error("error setting cache", slog.String("error", cacheErr.Error())) + } + } + return retrieved, err +} diff --git a/service/entityresolution/keycloak/v2/retrieve_test.go b/service/entityresolution/keycloak/v2/retrieve_test.go new file mode 100644 index 0000000000..d17ab3c1c9 --- /dev/null +++ b/service/entityresolution/keycloak/v2/retrieve_test.go @@ -0,0 +1,289 @@ +package keycloak + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Nerzal/gocloak/v13" + "github.com/opentdf/platform/service/logger" + "github.com/opentdf/platform/service/pkg/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var cacheTime = 2 * time.Second + +func newTestCache(t *testing.T) (*cache.Manager, *cache.Cache) { + // Use a short expiration for test + cacheManager, err := cache.NewCacheManager(1000) + require.NoError(t, err, "Failed to create cache manager") + c, err := cacheManager.NewCache("test", logger.CreateTestLogger(), cache.Options{ + Expiration: cacheTime, + }) + require.NoError(t, err, "Failed to create test cache") + return cacheManager, c +} + +func TestRetrieveClients_CacheIntegration(t *testing.T) { + clientID := "myclient" + realm := "tdf" + cacheKey := fmt.Sprintf("%s::client::%s", realm, clientID) + clientsResp := []*gocloak.Client{{ID: gocloak.StringP(clientID)}} + + var called int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // slog.Info("Server called", "path", r.URL.Path, "query", r.URL.RawQuery) + called++ + assert.Equal(t, "/admin/realms/tdf/clients", r.URL.Path) + assert.Contains(t, r.URL.RawQuery, "clientId=myclient") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(clientsResp) + })) + defer server.Close() + + kc := &Connector{ + token: &gocloak.JWT{AccessToken: "dummy"}, + client: gocloak.NewClient(server.URL), + } + l := logger.CreateTestLogger() + cm, c := newTestCache(t) + defer cm.Close() // Ensure cache manager is closed after test + + // First call: cache miss, should hit server + got, err := retrieveClients(t.Context(), l, clientID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, clientsResp, got) + assert.Equal(t, 1, called) + + // Wait + time.Sleep(200 * time.Millisecond) + + // Second call: cache hit, should NOT hit server + called = 0 + got2, err := retrieveClients(t.Context(), l, clientID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, clientsResp, got2) + assert.Equal(t, 0, called, "server should not be called on cache hit") + + // Optionally, check cache directly + val, err := c.Get(t.Context(), cacheKey) + require.NoError(t, err) + assert.Equal(t, clientsResp, val) + + // Wait for cache expiration + time.Sleep(cacheTime) + // After expiration, cache should be empty + _, err = c.Get(t.Context(), cacheKey) + require.Error(t, err, "Cache should be empty after expiration") +} + +func TestRetrieveUsers_CacheIntegration(t *testing.T) { + email := "foo@bar.com" + realm := "tdf" + cacheKey := fmt.Sprintf("%s::user::%s", realm, email) + usersResp := []*gocloak.User{{Email: gocloak.StringP(email)}} + + var called int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called++ + assert.Equal(t, "/admin/realms/tdf/users", r.URL.Path) + assert.Contains(t, r.URL.RawQuery, "email=foo%40bar.com") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(usersResp) + })) + defer server.Close() + + kc := &Connector{ + token: &gocloak.JWT{AccessToken: "dummy"}, + client: gocloak.NewClient(server.URL), + } + l := logger.CreateTestLogger() + cm, c := newTestCache(t) + defer cm.Close() // Ensure cache manager is closed after test + + // First call: cache miss, should hit server + params := gocloak.GetUsersParams{Email: &email} + got, err := retrieveUsers(t.Context(), l, params, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, usersResp, got) + assert.Equal(t, 1, called) + + // Wait + time.Sleep(200 * time.Millisecond) + + called = 0 + // Second call: cache hit, should NOT hit server + got2, err := retrieveUsers(t.Context(), l, params, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, usersResp, got2) + assert.Equal(t, 0, called) + + // Optionally, check cache directly + val, err := c.Get(t.Context(), cacheKey) + require.NoError(t, err) + assert.Equal(t, usersResp, val) + + // Wait for cache expiration + time.Sleep(cacheTime) + // After expiration, cache should be empty + _, err = c.Get(t.Context(), cacheKey) + require.Error(t, err, "Cache should be empty after expiration") +} + +func TestRetrieveGroupsByEmail_CacheIntegration(t *testing.T) { + groupEmail := "group@bar.com" + realm := "tdf" + cacheKey := fmt.Sprintf("%s::group::%s", realm, groupEmail) + groupsResp := []*gocloak.Group{{ID: gocloak.StringP("gid")}} + + var called int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called++ + assert.Equal(t, "/admin/realms/tdf/groups", r.URL.Path) + assert.Contains(t, r.URL.RawQuery, "search=group%40bar.com") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(groupsResp) + })) + defer server.Close() + + kc := &Connector{ + token: &gocloak.JWT{AccessToken: "dummy"}, + client: gocloak.NewClient(server.URL), + } + l := logger.CreateTestLogger() + cm, c := newTestCache(t) + defer cm.Close() // Ensure cache manager is closed after test + + // First call: cache miss, should hit server + got, err := retrieveGroupsByEmail(t.Context(), l, groupEmail, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, groupsResp, got) + assert.Equal(t, 1, called) + + // Wait + time.Sleep(200 * time.Millisecond) + + // Second call: cache hit, should NOT hit server + called = 0 + got2, err := retrieveGroupsByEmail(t.Context(), l, groupEmail, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, groupsResp, got2) + assert.Equal(t, 0, called) + + // Optionally, check cache directly + val, err := c.Get(t.Context(), cacheKey) + require.NoError(t, err) + assert.Equal(t, groupsResp, val) + + // Wait for cache expiration + time.Sleep(cacheTime) + // After expiration, cache should be empty + _, err = c.Get(t.Context(), cacheKey) + require.Error(t, err, "Cache should be empty after expiration") +} + +func TestRetrieveGroupByID_CacheIntegration(t *testing.T) { + groupID := "gid" + realm := "tdf" + cacheKey := fmt.Sprintf("%s::group::%s", realm, groupID) + groupResp := &gocloak.Group{ID: gocloak.StringP(groupID)} + + var called int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called++ + assert.Equal(t, "/admin/realms/tdf/groups/gid", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(groupResp) + })) + defer server.Close() + + kc := &Connector{ + token: &gocloak.JWT{AccessToken: "dummy"}, + client: gocloak.NewClient(server.URL), + } + l := logger.CreateTestLogger() + cm, c := newTestCache(t) + defer cm.Close() // Ensure cache manager is closed after test + + // First call: cache miss, should hit server + got, err := retrieveGroupByID(t.Context(), l, groupID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, groupResp, got) + assert.Equal(t, 1, called) + + // Wait + time.Sleep(200 * time.Millisecond) + + // Second call: cache hit, should NOT hit server + called = 0 + got2, err := retrieveGroupByID(t.Context(), l, groupID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, groupResp, got2) + assert.Equal(t, 0, called) + + // Optionally, check cache directly + val, err := c.Get(t.Context(), cacheKey) + require.NoError(t, err) + assert.Equal(t, groupResp, val) + + // Wait for cache expiration + time.Sleep(cacheTime) + // After expiration, cache should be empty + _, err = c.Get(t.Context(), cacheKey) + require.Error(t, err, "Cache should be empty after expiration") +} + +func TestRetrieveGroupMembers_CacheIntegration(t *testing.T) { + groupID := "gid" + realm := "tdf" + cacheKey := fmt.Sprintf("%s::group::%s::members", realm, groupID) + membersResp := []*gocloak.User{{ID: gocloak.StringP("uid")}} + + var called int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called++ + assert.Equal(t, "/admin/realms/tdf/groups/gid/members", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(membersResp) + })) + defer server.Close() + + kc := &Connector{ + token: &gocloak.JWT{AccessToken: "dummy"}, + client: gocloak.NewClient(server.URL), + } + l := logger.CreateTestLogger() + cm, c := newTestCache(t) + defer cm.Close() // Ensure cache manager is closed after test + + // First call: cache miss, should hit server + got, err := retrieveGroupMembers(t.Context(), l, groupID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, membersResp, got) + assert.Equal(t, 1, called) + + // Wait + time.Sleep(200 * time.Millisecond) + + // Second call: cache hit, should NOT hit server + called = 0 + got2, err := retrieveGroupMembers(t.Context(), l, groupID, realm, c, kc) + require.NoError(t, err) + assert.Equal(t, membersResp, got2) + assert.Equal(t, 0, called) + + // Optionally, check cache directly + val, err := c.Get(t.Context(), cacheKey) + require.NoError(t, err) + assert.Equal(t, membersResp, val) + + // Wait for cache expiration + time.Sleep(cacheTime) + // After expiration, cache should be empty + _, err = c.Get(t.Context(), cacheKey) + require.Error(t, err, "Cache should be empty after expiration") +} diff --git a/service/pkg/cache/cache.go b/service/pkg/cache/cache.go index 31b14e7566..e0e9bb31fd 100644 --- a/service/pkg/cache/cache.go +++ b/service/pkg/cache/cache.go @@ -88,14 +88,14 @@ func (c *Cache) Get(ctx context.Context, key string) (any, error) { val, err := c.manager.cache.Get(ctx, c.getKey(key)) if err != nil { // All errors are a cache miss in the gocache library. - c.logger.DebugContext(ctx, + c.logger.TraceContext(ctx, "cache miss", slog.Any("key", key), slog.Any("error", err), ) return nil, ErrCacheMiss } - c.logger.DebugContext(ctx, + c.logger.TraceContext(ctx, "cache hit", slog.Any("key", key), ) @@ -118,7 +118,7 @@ func (c *Cache) Set(ctx context.Context, key string, object any, tags []string) ) return err } - c.logger.DebugContext(ctx, "set cache", slog.Any("key", key)) + c.logger.TraceContext(ctx, "set cache", slog.Any("key", key)) return nil }