@@ -8,10 +8,14 @@ import (
88 "log/slog"
99 "strconv"
1010 "strings"
11+ "sync"
12+ "time"
1113
1214 "connectrpc.com/connect"
1315 "github.com/Nerzal/gocloak/v13"
16+ "github.com/creasty/defaults"
1417 "github.com/go-viper/mapstructure/v2"
18+
1519 "github.com/lestrrat-go/jwx/v2/jwt"
1620 "github.com/opentdf/platform/protocol/go/authorization"
1721 "github.com/opentdf/platform/protocol/go/entityresolution"
@@ -43,6 +47,8 @@ type KeycloakEntityResolutionService struct { //nolint:revive // Too late! Alrea
4347 idpConfig KeycloakConfig
4448 logger * logger.Logger
4549 trace.Tracer
50+ connector * KeyCloakConnector
51+ connectorMu sync.Mutex
4652}
4753
4854type KeycloakConfig struct { //nolint:revive // yeah but what if we want to embed multiple configs?
@@ -53,10 +59,16 @@ type KeycloakConfig struct { //nolint:revive // yeah but what if we want to embe
5359 LegacyKeycloak bool `mapstructure:"legacykeycloak" json:"legacykeycloak" default:"false"`
5460 SubGroups bool `mapstructure:"subgroups" json:"subgroups" default:"false"`
5561 InferID InferredIdentityConfig `mapstructure:"inferid,omitempty" json:"inferid,omitempty"`
62+ TokenBuffer time.Duration `mapstructure:"token_buffer_seconds" json:"token_buffer_seconds" default:"120s"`
5663}
5764
5865func RegisterKeycloakERS (config config.ServiceConfig , logger * logger.Logger ) (* KeycloakEntityResolutionService , serviceregistry.HandlerServer ) {
5966 var inputIdpConfig KeycloakConfig
67+
68+ if err := defaults .Set (& inputIdpConfig ); err != nil {
69+ panic (err )
70+ }
71+
6072 if err := mapstructure .Decode (config , & inputIdpConfig ); err != nil {
6173 panic (err )
6274 }
@@ -65,19 +77,31 @@ func RegisterKeycloakERS(config config.ServiceConfig, logger *logger.Logger) (*K
6577 return keycloakSVC , nil
6678}
6779
68- func (s KeycloakEntityResolutionService ) ResolveEntities (ctx context.Context , req * connect.Request [entityresolution.ResolveEntitiesRequest ]) (* connect.Response [entityresolution.ResolveEntitiesResponse ], error ) {
80+ func (s * KeycloakEntityResolutionService ) ResolveEntities (ctx context.Context , req * connect.Request [entityresolution.ResolveEntitiesRequest ]) (* connect.Response [entityresolution.ResolveEntitiesResponse ], error ) {
6981 ctx , span := s .Tracer .Start (ctx , "ResolveEntities" )
7082 defer span .End ()
7183
72- resp , err := EntityResolution (ctx , req .Msg , s .idpConfig , s .logger )
84+ connector , err := s .getConnector (ctx , s .idpConfig .TokenBuffer )
85+ if err != nil {
86+ s .logger .ErrorContext (ctx , "error getting keycloak connector" , slog .String ("error" , err .Error ()))
87+ return nil , connect .NewError (connect .CodeInternal , fmt .Errorf ("%w: %w" , ErrCreationFailed , err ))
88+ }
89+ resp , err := EntityResolution (ctx , req .Msg , s .idpConfig , connector , s .logger )
90+
7391 return connect .NewResponse (& resp ), err
7492}
7593
76- func (s KeycloakEntityResolutionService ) CreateEntityChainFromJwt (ctx context.Context , req * connect.Request [entityresolution.CreateEntityChainFromJwtRequest ]) (* connect.Response [entityresolution.CreateEntityChainFromJwtResponse ], error ) {
94+ func (s * KeycloakEntityResolutionService ) CreateEntityChainFromJwt (ctx context.Context , req * connect.Request [entityresolution.CreateEntityChainFromJwtRequest ]) (* connect.Response [entityresolution.CreateEntityChainFromJwtResponse ], error ) {
7795 ctx , span := s .Tracer .Start (ctx , "CreateEntityChainFromJwt" )
7896 defer span .End ()
7997
80- resp , err := CreateEntityChainFromJwt (ctx , req .Msg , s .idpConfig , s .logger )
98+ connector , err := s .getConnector (ctx , s .idpConfig .TokenBuffer )
99+ if err != nil {
100+ s .logger .ErrorContext (ctx , "error getting keycloak connector" , slog .String ("error" , err .Error ()))
101+ return nil , connect .NewError (connect .CodeInternal , fmt .Errorf ("%w: %w" , ErrCreationFailed , err ))
102+ }
103+ resp , err := CreateEntityChainFromJwt (ctx , req .Msg , s .idpConfig , connector , s .logger )
104+
81105 return connect .NewResponse (& resp ), err
82106}
83107
@@ -104,20 +128,22 @@ type EntityImpliedFrom struct {
104128}
105129
106130type KeyCloakConnector struct { //nolint:revive // Too late! Already exported
107- token * gocloak.JWT
108- client * gocloak.GoCloak
131+ token * gocloak.JWT
132+ client * gocloak.GoCloak
133+ expiresAt time.Time
109134}
110135
111136func CreateEntityChainFromJwt (
112137 ctx context.Context ,
113138 req * entityresolution.CreateEntityChainFromJwtRequest ,
114139 kcConfig KeycloakConfig ,
140+ connector * KeyCloakConnector ,
115141 logger * logger.Logger ,
116142) (entityresolution.CreateEntityChainFromJwtResponse , error ) {
117143 entityChains := []* authorization.EntityChain {}
118144 // for each token in the tokens form an entity chain
119145 for _ , tok := range req .GetTokens () {
120- entities , err := getEntitiesFromToken (ctx , kcConfig , tok .GetJwt (), logger )
146+ entities , err := getEntitiesFromToken (ctx , kcConfig , tok .GetJwt (), connector , logger )
121147 if err != nil {
122148 return entityresolution.CreateEntityChainFromJwtResponse {}, err
123149 }
@@ -128,13 +154,8 @@ func CreateEntityChainFromJwt(
128154}
129155
130156func EntityResolution (ctx context.Context ,
131- req * entityresolution.ResolveEntitiesRequest , kcConfig KeycloakConfig , logger * logger.Logger ,
157+ req * entityresolution.ResolveEntitiesRequest , kcConfig KeycloakConfig , connector * KeyCloakConnector , logger * logger.Logger ,
132158) (entityresolution.ResolveEntitiesResponse , error ) {
133- connector , err := getKCClient (ctx , kcConfig , logger )
134- if err != nil {
135- return entityresolution.ResolveEntitiesResponse {},
136- connect .NewError (connect .CodeInternal , ErrCreationFailed )
137- }
138159 payload := req .GetEntities ()
139160
140161 var resolvedEntities []* entityresolution.EntityRepresentation
@@ -334,35 +355,6 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg
334355 return genericMap , nil
335356}
336357
337- func getKCClient (ctx context.Context , kcConfig KeycloakConfig , logger * logger.Logger ) (* KeyCloakConnector , error ) {
338- var client * gocloak.GoCloak
339- if kcConfig .LegacyKeycloak {
340- logger .Warn ("using legacy connection mode for Keycloak < 17.x.x" )
341- client = gocloak .NewClient (kcConfig .URL )
342- } else {
343- client = gocloak .NewClient (kcConfig .URL , gocloak .SetAuthAdminRealms ("admin/realms" ), gocloak .SetAuthRealms ("realms" ))
344- }
345- // If needed, ability to disable tls checks for testing
346- // restyClient := client.RestyClient()
347- // restyClient.SetDebug(true)
348- // restyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
349- // client.SetRestyClient(restyClient)
350-
351- // For debugging
352- // logger.Debug(kcConfig.ClientID)
353- // logger.Debug(kcConfig.ClientSecret)
354- // logger.Debug(kcConfig.URL)
355- // logger.Debug(kcConfig.Realm)
356- token , err := client .LoginClient (ctx , kcConfig .ClientID , kcConfig .ClientSecret , kcConfig .Realm )
357- if err != nil {
358- logger .Error ("error connecting to keycloak!" , slog .String ("error" , err .Error ()))
359- return nil , err
360- }
361- keycloakConnector := KeyCloakConnector {token : token , client : client }
362-
363- return & keycloakConnector , nil
364- }
365-
366358func expandGroup (ctx context.Context , groupID string , kcConnector * KeyCloakConnector , kcConfig * KeycloakConfig , logger * logger.Logger ) ([]* gocloak.User , error ) {
367359 logger .Info ("expanding group" , slog .String ("groupID" , groupID ))
368360 var entityRepresentations []* gocloak.User
@@ -387,7 +379,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *KeyCloakConne
387379 return entityRepresentations , nil
388380}
389381
390- func getEntitiesFromToken (ctx context.Context , kcConfig KeycloakConfig , jwtString string , logger * logger.Logger ) ([]* authorization.Entity , error ) {
382+ func getEntitiesFromToken (ctx context.Context , kcConfig KeycloakConfig , jwtString string , connector * KeyCloakConnector , logger * logger.Logger ) ([]* authorization.Entity , error ) {
391383 token , err := jwt .ParseString (jwtString , jwt .WithVerify (false ), jwt .WithValidate (false ))
392384 if err != nil {
393385 return nil , errors .New ("error parsing jwt " + err .Error ())
@@ -426,7 +418,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin
426418
427419 // double check for service account
428420 if strings .HasPrefix (extractedValueUsernameCasted , serviceAccountUsernamePrefix ) {
429- clientid , err := getServiceAccountClient (ctx , extractedValueUsernameCasted , kcConfig , logger )
421+ clientid , err := getServiceAccountClient (ctx , extractedValueUsernameCasted , kcConfig , connector , logger )
430422 if err != nil {
431423 return nil , err
432424 }
@@ -455,11 +447,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin
455447 return entities , nil
456448}
457449
458- func getServiceAccountClient (ctx context.Context , username string , kcConfig KeycloakConfig , logger * logger.Logger ) (string , error ) {
459- connector , err := getKCClient (ctx , kcConfig , logger )
460- if err != nil {
461- return "" , err
462- }
450+ func getServiceAccountClient (ctx context.Context , username string , kcConfig KeycloakConfig , connector * KeyCloakConnector , logger * logger.Logger ) (string , error ) {
463451 expectedClientName := strings .TrimPrefix (username , serviceAccountUsernamePrefix )
464452
465453 clients , err := connector .client .GetClients (ctx , connector .token .AccessToken , kcConfig .Realm , gocloak.GetClientsParams {
@@ -494,3 +482,39 @@ func entityToStructPb(ident *authorization.Entity) (*structpb.Struct, error) {
494482 }
495483 return & entityStruct , nil
496484}
485+
486+ // getConnector ensures a valid Keycloak connector is available, refreshing the token if necessary.
487+ func (s * KeycloakEntityResolutionService ) getConnector (ctx context.Context , tokenBuffer time.Duration ) (* KeyCloakConnector , error ) {
488+ s .connectorMu .Lock ()
489+ defer s .connectorMu .Unlock ()
490+
491+ // Refresh token if it's nil, expired, or about to expire.
492+
493+ if s .connector == nil || s .connector .token == nil || time .Now ().After (s .connector .expiresAt .Add (- tokenBuffer )) {
494+ s .logger .InfoContext (ctx , "Keycloak connector is nil or token expired/expiring soon. Fetching new token." )
495+
496+ var gocloakClient * gocloak.GoCloak
497+ if s .idpConfig .LegacyKeycloak {
498+ s .logger .WarnContext (ctx , "Using legacy connection mode for Keycloak < 17.x.x" )
499+ gocloakClient = gocloak .NewClient (s .idpConfig .URL )
500+ } else {
501+ gocloakClient = gocloak .NewClient (s .idpConfig .URL , gocloak .SetAuthAdminRealms ("admin/realms" ), gocloak .SetAuthRealms ("realms" ))
502+ }
503+
504+ token , err := gocloakClient .LoginClient (ctx , s .idpConfig .ClientID , s .idpConfig .ClientSecret , s .idpConfig .Realm )
505+ if err != nil {
506+ s .logger .ErrorContext (ctx , "Error connecting to Keycloak or logging in" , slog .String ("error" , err .Error ()))
507+ return nil , fmt .Errorf ("failed to login to Keycloak: %w" , err )
508+ }
509+
510+ s .connector = & KeyCloakConnector {
511+ token : token ,
512+ client : gocloakClient ,
513+ expiresAt : time .Now ().Add (time .Duration (token .ExpiresIn ) * time .Second ),
514+ }
515+ s .logger .InfoContext (ctx , "Successfully fetched new Keycloak token." , "expires_in_seconds" , token .ExpiresIn )
516+ } else {
517+ s .logger .DebugContext (ctx , "Using existing Keycloak token." )
518+ }
519+ return s .connector , nil
520+ }
0 commit comments