Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 66 additions & 49 deletions service/entityresolution/keycloak/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"log/slog"
"strconv"
"strings"
"sync"
"time"

"connectrpc.com/connect"
"github.com/Nerzal/gocloak/v13"
"github.com/go-viper/mapstructure/v2"

"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/protocol/go/authorization"
"github.com/opentdf/platform/protocol/go/entityresolution"
Expand Down Expand Up @@ -43,6 +46,8 @@ type KeycloakEntityResolutionService struct { //nolint:revive // Too late! Alrea
idpConfig KeycloakConfig
logger *logger.Logger
trace.Tracer
connector *KeyCloakConnector
connectorMu sync.Mutex
}

type KeycloakConfig struct { //nolint:revive // yeah but what if we want to embed multiple configs?
Expand All @@ -65,19 +70,29 @@ func RegisterKeycloakERS(config config.ServiceConfig, logger *logger.Logger) (*K
return keycloakSVC, nil
}

func (s KeycloakEntityResolutionService) ResolveEntities(ctx context.Context, req *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
func (s *KeycloakEntityResolutionService) ResolveEntities(ctx context.Context, req *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
ctx, span := s.Tracer.Start(ctx, "ResolveEntities")
defer span.End()

resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, s.logger)
connector, err := s.getConnector(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := EntityResolution(ctx, req.Msg, s.idpConfig, connector, s.logger)

return connect.NewResponse(&resp), err
}

func (s KeycloakEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) {
func (s *KeycloakEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) {
ctx, span := s.Tracer.Start(ctx, "CreateEntityChainFromJwt")
defer span.End()

resp, err := CreateEntityChainFromJwt(ctx, req.Msg, s.idpConfig, s.logger)
connector, err := s.getConnector(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("%w: %w", ErrCreationFailed, err))
}
resp, err := CreateEntityChainFromJwt(ctx, req.Msg, s.idpConfig, connector, s.logger)

return connect.NewResponse(&resp), err
}

Expand All @@ -104,20 +119,22 @@ type EntityImpliedFrom struct {
}

type KeyCloakConnector struct { //nolint:revive // Too late! Already exported
token *gocloak.JWT
client *gocloak.GoCloak
token *gocloak.JWT
client *gocloak.GoCloak
expiresAt time.Time
}

func CreateEntityChainFromJwt(
ctx context.Context,
req *entityresolution.CreateEntityChainFromJwtRequest,
kcConfig KeycloakConfig,
connector *KeyCloakConnector,
logger *logger.Logger,
) (entityresolution.CreateEntityChainFromJwtResponse, error) {
entityChains := []*authorization.EntityChain{}
// for each token in the tokens form an entity chain
for _, tok := range req.GetTokens() {
entities, err := getEntitiesFromToken(ctx, kcConfig, tok.GetJwt(), logger)
entities, err := getEntitiesFromToken(ctx, kcConfig, tok.GetJwt(), connector, logger)
if err != nil {
return entityresolution.CreateEntityChainFromJwtResponse{}, err
}
Expand All @@ -128,13 +145,8 @@ func CreateEntityChainFromJwt(
}

func EntityResolution(ctx context.Context,
req *entityresolution.ResolveEntitiesRequest, kcConfig KeycloakConfig, logger *logger.Logger,
req *entityresolution.ResolveEntitiesRequest, kcConfig KeycloakConfig, connector *KeyCloakConnector, logger *logger.Logger,
) (entityresolution.ResolveEntitiesResponse, error) {
connector, err := getKCClient(ctx, kcConfig, logger)
if err != nil {
return entityresolution.ResolveEntitiesResponse{},
connect.NewError(connect.CodeInternal, ErrCreationFailed)
}
payload := req.GetEntities()

var resolvedEntities []*entityresolution.EntityRepresentation
Expand Down Expand Up @@ -334,35 +346,6 @@ func typeToGenericJSONMap[Marshalable any](inputStruct Marshalable, logger *logg
return genericMap, nil
}

func getKCClient(ctx context.Context, kcConfig KeycloakConfig, logger *logger.Logger) (*KeyCloakConnector, error) {
var client *gocloak.GoCloak
if kcConfig.LegacyKeycloak {
logger.Warn("using legacy connection mode for Keycloak < 17.x.x")
client = gocloak.NewClient(kcConfig.URL)
} else {
client = gocloak.NewClient(kcConfig.URL, gocloak.SetAuthAdminRealms("admin/realms"), gocloak.SetAuthRealms("realms"))
}
// If needed, ability to disable tls checks for testing
// restyClient := client.RestyClient()
// restyClient.SetDebug(true)
// restyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
// client.SetRestyClient(restyClient)

// For debugging
// logger.Debug(kcConfig.ClientID)
// logger.Debug(kcConfig.ClientSecret)
// logger.Debug(kcConfig.URL)
// logger.Debug(kcConfig.Realm)
token, err := client.LoginClient(ctx, kcConfig.ClientID, kcConfig.ClientSecret, kcConfig.Realm)
if err != nil {
logger.Error("error connecting to keycloak!", slog.String("error", err.Error()))
return nil, err
}
keycloakConnector := KeyCloakConnector{token: token, client: client}

return &keycloakConnector, nil
}

func expandGroup(ctx context.Context, groupID string, kcConnector *KeyCloakConnector, kcConfig *KeycloakConfig, logger *logger.Logger) ([]*gocloak.User, error) {
logger.Info("expanding group", slog.String("groupID", groupID))
var entityRepresentations []*gocloak.User
Expand All @@ -387,7 +370,7 @@ func expandGroup(ctx context.Context, groupID string, kcConnector *KeyCloakConne
return entityRepresentations, nil
}

func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtString string, logger *logger.Logger) ([]*authorization.Entity, error) {
func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtString string, connector *KeyCloakConnector, logger *logger.Logger) ([]*authorization.Entity, error) {
token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false))
if err != nil {
return nil, errors.New("error parsing jwt " + err.Error())
Expand Down Expand Up @@ -426,7 +409,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin

// double check for service account
if strings.HasPrefix(extractedValueUsernameCasted, serviceAccountUsernamePrefix) {
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, logger)
clientid, err := getServiceAccountClient(ctx, extractedValueUsernameCasted, kcConfig, connector, logger)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -455,11 +438,7 @@ func getEntitiesFromToken(ctx context.Context, kcConfig KeycloakConfig, jwtStrin
return entities, nil
}

func getServiceAccountClient(ctx context.Context, username string, kcConfig KeycloakConfig, logger *logger.Logger) (string, error) {
connector, err := getKCClient(ctx, kcConfig, logger)
if err != nil {
return "", err
}
func getServiceAccountClient(ctx context.Context, username string, kcConfig KeycloakConfig, connector *KeyCloakConnector, logger *logger.Logger) (string, error) {
expectedClientName := strings.TrimPrefix(username, serviceAccountUsernamePrefix)

clients, err := connector.client.GetClients(ctx, connector.token.AccessToken, kcConfig.Realm, gocloak.GetClientsParams{
Expand Down Expand Up @@ -494,3 +473,41 @@ func entityToStructPb(ident *authorization.Entity) (*structpb.Struct, error) {
}
return &entityStruct, nil
}

// getConnector ensures a valid Keycloak connector is available, refreshing the token if necessary.
func (s *KeycloakEntityResolutionService) getConnector(ctx context.Context) (*KeyCloakConnector, error) {
s.connectorMu.Lock()
defer s.connectorMu.Unlock()

// Refresh token if it's nil, expired, or about to expire.
// Define a buffer for token refresh, e.g., 60 seconds before actual expiry.
const tokenRefreshBuffer = 60 * time.Second

if s.connector == nil || s.connector.token == nil || time.Now().After(s.connector.expiresAt.Add(-tokenRefreshBuffer)) {
s.logger.InfoContext(ctx, "Keycloak connector is nil or token expired/expiring soon. Fetching new token.")

var gocloakClient *gocloak.GoCloak
if s.idpConfig.LegacyKeycloak {
s.logger.WarnContext(ctx, "Using legacy connection mode for Keycloak < 17.x.x")
gocloakClient = gocloak.NewClient(s.idpConfig.URL)
} else {
gocloakClient = gocloak.NewClient(s.idpConfig.URL, gocloak.SetAuthAdminRealms("admin/realms"), gocloak.SetAuthRealms("realms"))
}

token, err := gocloakClient.LoginClient(ctx, s.idpConfig.ClientID, s.idpConfig.ClientSecret, s.idpConfig.Realm)
if err != nil {
s.logger.ErrorContext(ctx, "Error connecting to Keycloak or logging in", slog.String("error", err.Error()))
return nil, fmt.Errorf("failed to login to Keycloak: %w", err)
}

s.connector = &KeyCloakConnector{
token: token,
client: gocloakClient,
expiresAt: time.Now().Add(time.Duration(token.ExpiresIn) * time.Second),
}
s.logger.InfoContext(ctx, "Successfully fetched new Keycloak token.", "expires_in_seconds", token.ExpiresIn)
} else {
s.logger.DebugContext(ctx, "Using existing Keycloak token.")
}
return s.connector, nil
}
Loading
Loading