Skip to content

Commit 7ed3369

Browse files
strantalisgithub-actions[bot]
authored andcommitted
fix: only request a token when near expiration (#2370)
### Proposed Changes * ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions (cherry picked from commit 556d95e)
1 parent 7e32748 commit 7ed3369

File tree

4 files changed

+476
-209
lines changed

4 files changed

+476
-209
lines changed

service/entityresolution/keycloak/entity_resolution.go

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ import (
88
"log/slog"
99
"strconv"
1010
"strings"
11+
"sync"
12+
"time"
1113

1214
"connectrpc.com/connect"
1315
"github.com/Nerzal/gocloak/v13"
16+
"github.com/creasty/defaults"
1417
"github.com/go-viper/mapstructure/v2"
18+
1519
"github.com/lestrrat-go/jwx/v2/jwt"
1620
"github.com/opentdf/platform/protocol/go/authorization"
1721
"github.com/opentdf/platform/protocol/go/entityresolution"
@@ -43,6 +47,8 @@ type KeycloakEntityResolutionService struct { //nolint:revive // Too late! Alrea
4347
idpConfig KeycloakConfig
4448
logger *logger.Logger
4549
trace.Tracer
50+
connector *KeyCloakConnector
51+
connectorMu sync.Mutex
4652
}
4753

4854
type KeycloakConfig struct { //nolint:revive // yeah but what if we want to embed multiple configs?
@@ -53,10 +59,16 @@ type KeycloakConfig struct { //nolint:revive // yeah but what if we want to embe
5359
LegacyKeycloak bool `mapstructure:"legacykeycloak" json:"legacykeycloak" default:"false"`
5460
SubGroups bool `mapstructure:"subgroups" json:"subgroups" default:"false"`
5561
InferID InferredIdentityConfig `mapstructure:"inferid,omitempty" json:"inferid,omitempty"`
62+
TokenBuffer time.Duration `mapstructure:"token_buffer_seconds" json:"token_buffer_seconds" default:"120s"`
5663
}
5764

5865
func RegisterKeycloakERS(config config.ServiceConfig, logger *logger.Logger) (*KeycloakEntityResolutionService, serviceregistry.HandlerServer) {
5966
var inputIdpConfig KeycloakConfig
67+
68+
if err := defaults.Set(&inputIdpConfig); err != nil {
69+
panic(err)
70+
}
71+
6072
if err := mapstructure.Decode(config, &inputIdpConfig); err != nil {
6173
panic(err)
6274
}
@@ -65,19 +77,31 @@ func RegisterKeycloakERS(config config.ServiceConfig, logger *logger.Logger) (*K
6577
return keycloakSVC, nil
6678
}
6779

68-
func (s KeycloakEntityResolutionService) ResolveEntities(ctx context.Context, req *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
80+
func (s *KeycloakEntityResolutionService) ResolveEntities(ctx context.Context, req *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
6981
ctx, span := s.Tracer.Start(ctx, "ResolveEntities")
7082
defer span.End()
7183

72-
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, s.logger)
84+
connector, err := s.getConnector(ctx, s.idpConfig.TokenBuffer)
85+
if err != nil {
86+
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
87+
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
88+
}
89+
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger)
90+
7391
return connect.NewResponse(&resp), err
7492
}
7593

76-
func (s KeycloakEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) {
94+
func (s *KeycloakEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) {
7795
ctx, span := s.Tracer.Start(ctx, "CreateEntityChainFromJwt")
7896
defer span.End()
7997

80-
resp, err := CreateEntityChainFromJwt(ctx, req.Msg, s.idpConfig, s.logger)
98+
connector, err := s.getConnector(ctx, s.idpConfig.TokenBuffer)
99+
if err != nil {
100+
s.logger.ErrorContext(ctx, "error getting keycloak connector", slog.String("error", err.Error()))
101+
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
102+
}
103+
resp, err := CreateEntityChainFromJwt(ctx, req.Msg, s.idpConfig, connector, s.logger)
104+
81105
return connect.NewResponse(&resp), err
82106
}
83107

@@ -104,20 +128,22 @@ type EntityImpliedFrom struct {
104128
}
105129

106130
type KeyCloakConnector struct { //nolint:revive // Too late! Already exported
107-
token *gocloak.JWT
108-
client *gocloak.GoCloak
131+
token *gocloak.JWT
132+
client *gocloak.GoCloak
133+
expiresAt time.Time
109134
}
110135

111136
func CreateEntityChainFromJwt(
112137
ctx context.Context,
113138
req *entityresolution.CreateEntityChainFromJwtRequest,
114139
kcConfig KeycloakConfig,
140+
connector *KeyCloakConnector,
115141
logger *logger.Logger,
116142
) (entityresolution.CreateEntityChainFromJwtResponse, error) {
117143
entityChains := []*authorization.EntityChain{}
118144
// for each token in the tokens form an entity chain
119145
for _, tok := range req.GetTokens() {
120-
entities, err := getEntitiesFromToken(ctx, kcConfig, tok.GetJwt(), logger)
146+
entities, err := getEntitiesFromToken(ctx, kcConfig, tok.GetJwt(), connector, logger)
121147
if err != nil {
122148
return entityresolution.CreateEntityChainFromJwtResponse{}, err
123149
}
@@ -128,13 +154,8 @@ func CreateEntityChainFromJwt(
128154
}
129155

130156
func EntityResolution(ctx context.Context,
131-
req *entityresolution.ResolveEntitiesRequest, kcConfig KeycloakConfig, logger *logger.Logger,
157+
req *entityresolution.ResolveEntitiesRequest, kcConfig KeycloakConfig, connector *KeyCloakConnector, logger *logger.Logger,
132158
) (entityresolution.ResolveEntitiesResponse, error) {
133-
connector, err := getKCClient(ctx, kcConfig, logger)
134-
if err != nil {
135-
return entityresolution.ResolveEntitiesResponse{},
136-
connect.NewError(connect.CodeInternal, ErrCreationFailed)
137-
}
138159
payload := req.GetEntities()
139160

140161
var resolvedEntities []*entityresolution.EntityRepresentation
@@ -334,35 +355,6 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg
334355
return genericMap, nil
335356
}
336357

337-
func getKCClient(ctx context.Context, kcConfig KeycloakConfig, logger *logger.Logger) (*KeyCloakConnector, error) {
338-
var client *gocloak.GoCloak
339-
if kcConfig.LegacyKeycloak {
340-
logger.Warn("using legacy connection mode for Keycloak < 17.x.x")
341-
client = gocloak.NewClient(kcConfig.URL)
342-
} else {
343-
client = gocloak.NewClient(kcConfig.URL, gocloak.SetAuthAdminRealms("admin/realms"), gocloak.SetAuthRealms("realms"))
344-
}
345-
// If needed, ability to disable tls checks for testing
346-
// restyClient := client.RestyClient()
347-
// restyClient.SetDebug(true)
348-
// restyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
349-
// client.SetRestyClient(restyClient)
350-
351-
// For debugging
352-
// logger.Debug(kcConfig.ClientID)
353-
// logger.Debug(kcConfig.ClientSecret)
354-
// logger.Debug(kcConfig.URL)
355-
// logger.Debug(kcConfig.Realm)
356-
token, err := client.LoginClient(ctx, kcConfig.ClientID, kcConfig.ClientSecret, kcConfig.Realm)
357-
if err != nil {
358-
logger.Error("error connecting to keycloak!", slog.String("error", err.Error()))
359-
return nil, err
360-
}
361-
keycloakConnector := KeyCloakConnector{token: token, client: client}
362-
363-
return &keycloakConnector, nil
364-
}
365-
366358
func expandGroup(ctx context.Context, groupID string, kcConnector *KeyCloakConnector, kcConfig *KeycloakConfig, logger *logger.Logger) ([]*gocloak.User, error) {
367359
logger.Info("expanding group", slog.String("groupID", groupID))
368360
var entityRepresentations []*gocloak.User
@@ -387,7 +379,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *KeyCloakConne
387379
return entityRepresentations, nil
388380
}
389381

390-
func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtString string, logger *logger.Logger) ([]*authorization.Entity, error) {
382+
func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtString string, connector *KeyCloakConnector, logger *logger.Logger) ([]*authorization.Entity, error) {
391383
token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false))
392384
if err != nil {
393385
return nil, errors.New("error parsing jwt " + err.Error())
@@ -426,7 +418,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin
426418

427419
// double check for service account
428420
if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) {
429-
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, logger)
421+
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger)
430422
if err != nil {
431423
return nil, err
432424
}
@@ -455,11 +447,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin
455447
return entities, nil
456448
}
457449

458-
func getServiceAccountClient(ctx context.Context, username string, kcConfig KeycloakConfig, logger *logger.Logger) (string, error) {
459-
connector, err := getKCClient(ctx, kcConfig, logger)
460-
if err != nil {
461-
return "", err
462-
}
450+
func getServiceAccountClient(ctx context.Context, username string, kcConfig KeycloakConfig, connector *KeyCloakConnector, logger *logger.Logger) (string, error) {
463451
expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix)
464452

465453
clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
@@ -494,3 +482,39 @@ func entityToStructPb(ident *authorization.Entity) (*structpb.Struct, error) {
494482
}
495483
return &entityStruct, nil
496484
}
485+
486+
// getConnector ensures a valid Keycloak connector is available, refreshing the token if necessary.
487+
func (s *KeycloakEntityResolutionService) getConnector(ctx context.Context, tokenBuffer time.Duration) (*KeyCloakConnector, error) {
488+
s.connectorMu.Lock()
489+
defer s.connectorMu.Unlock()
490+
491+
// Refresh token if it's nil, expired, or about to expire.
492+
493+
if s.connector == nil || s.connector.token == nil || time.Now().After(s.connector.expiresAt.Add(-tokenBuffer)) {
494+
s.logger.InfoContext(ctx, "Keycloak connector is nil or token expired/expiring soon. Fetching new token.")
495+
496+
var gocloakClient *gocloak.GoCloak
497+
if s.idpConfig.LegacyKeycloak {
498+
s.logger.WarnContext(ctx, "Using legacy connection mode for Keycloak < 17.x.x")
499+
gocloakClient = gocloak.NewClient(s.idpConfig.URL)
500+
} else {
501+
gocloakClient = gocloak.NewClient(s.idpConfig.URL, gocloak.SetAuthAdminRealms("admin/realms"), gocloak.SetAuthRealms("realms"))
502+
}
503+
504+
token, err := gocloakClient.LoginClient(ctx, s.idpConfig.ClientID, s.idpConfig.ClientSecret, s.idpConfig.Realm)
505+
if err != nil {
506+
s.logger.ErrorContext(ctx, "Error connecting to Keycloak or logging in", slog.String("error", err.Error()))
507+
return nil, fmt.Errorf("failed to login to Keycloak: %w", err)
508+
}
509+
510+
s.connector = &KeyCloakConnector{
511+
token: token,
512+
client: gocloakClient,
513+
expiresAt: time.Now().Add(time.Duration(token.ExpiresIn) * time.Second),
514+
}
515+
s.logger.InfoContext(ctx, "Successfully fetched new Keycloak token.", "expires_in_seconds", token.ExpiresIn)
516+
} else {
517+
s.logger.DebugContext(ctx, "Using existing Keycloak token.")
518+
}
519+
return s.connector, nil
520+
}

0 commit comments

Comments
 (0)