diff --git a/service/entityresolution/claims/claims_entity_resolution.go b/service/entityresolution/claims/claims_entity_resolution.go new file mode 100644 index 0000000000..cde559cc2b --- /dev/null +++ b/service/entityresolution/claims/claims_entity_resolution.go @@ -0,0 +1,141 @@ +package entityresolution + +import ( + "context" + "fmt" + "log/slog" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/opentdf/platform/protocol/go/authorization" + "github.com/opentdf/platform/protocol/go/entityresolution" + auth "github.com/opentdf/platform/service/authorization" + "github.com/opentdf/platform/service/logger" + "github.com/opentdf/platform/service/pkg/serviceregistry" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" +) + +type ClaimsEntityResolutionService struct { + entityresolution.UnimplementedEntityResolutionServiceServer + logger *logger.Logger +} + +func RegisterClaimsERS(_ serviceregistry.ServiceConfig, logger *logger.Logger) (any, serviceregistry.HandlerServer) { + return &ClaimsEntityResolutionService{logger: logger}, + func(ctx context.Context, mux *runtime.ServeMux, server any) error { + return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, server.(entityresolution.EntityResolutionServiceServer)) //nolint:forcetypeassert // allow type assert, following other services + } +} + +func (s ClaimsEntityResolutionService) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { + resp, err := EntityResolution(ctx, req, s.logger) + return &resp, err +} + +func (s ClaimsEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + resp, err := CreateEntityChainFromJwt(ctx, req, s.logger) + return &resp, err +} + +func CreateEntityChainFromJwt( + _ context.Context, + req *entityresolution.CreateEntityChainFromJwtRequest, + _ *logger.Logger, +) (entityresolution.CreateEntityChainFromJwtResponse, error) { + entityChains := []*authorization.EntityChain{} + // for each token in the tokens form an entity chain + for _, tok := range req.GetTokens() { + entities, err := getEntitiesFromToken(tok.GetJwt()) + if err != nil { + return entityresolution.CreateEntityChainFromJwtResponse{}, err + } + entityChains = append(entityChains, &authorization.EntityChain{Id: tok.GetId(), Entities: entities}) + } + + return entityresolution.CreateEntityChainFromJwtResponse{EntityChains: entityChains}, nil +} + +func EntityResolution(_ context.Context, + req *entityresolution.ResolveEntitiesRequest, logger *logger.Logger, +) (entityresolution.ResolveEntitiesResponse, error) { + payload := req.GetEntities() + var resolvedEntities []*entityresolution.EntityRepresentation + + for idx, ident := range payload { + var entityStruct = &structpb.Struct{} + switch ident.GetEntityType().(type) { + case *authorization.Entity_Claims: + claims := ident.GetClaims() + if claims != nil { + err := claims.UnmarshalTo(entityStruct) + if err != nil { + return entityresolution.ResolveEntitiesResponse{}, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err) + } + } + default: + retrievedStruct, err := entityToStructPb(ident) + if err != nil { + logger.Error("unable to make entity struct", slog.String("error", err.Error())) + return entityresolution.ResolveEntitiesResponse{}, fmt.Errorf("unable to make entity struct: %w", err) + } + entityStruct = retrievedStruct + } + // make sure the id field is populated + originialID := ident.GetId() + if originialID == "" { + originialID = auth.EntityIDPrefix + fmt.Sprint(idx) + } + resolvedEntities = append( + resolvedEntities, + &entityresolution.EntityRepresentation{ + OriginalId: originialID, + AdditionalProps: []*structpb.Struct{entityStruct}, + }, + ) + } + return entityresolution.ResolveEntitiesResponse{EntityRepresentations: resolvedEntities}, nil +} + +func getEntitiesFromToken(jwtString string) ([]*authorization.Entity, error) { + token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + return nil, fmt.Errorf("error parsing jwt: %w", err) + } + + claims := token.PrivateClaims() + entities := []*authorization.Entity{} + + // Convert map[string]interface{} to *structpb.Struct + structClaims, err := structpb.NewStruct(claims) + if err != nil { + return nil, fmt.Errorf("error converting to structpb.Struct: %w", err) + } + + // Wrap the struct in an *anypb.Any message + anyClaims, err := anypb.New(structClaims) + if err != nil { + return nil, fmt.Errorf("error wrapping in anypb.Any: %w", err) + } + + entities = append(entities, &authorization.Entity{ + EntityType: &authorization.Entity_Claims{Claims: anyClaims}, + Id: "jwtentity-claims", + Category: authorization.Entity_CATEGORY_SUBJECT, + }) + return entities, nil +} + +func entityToStructPb(ident *authorization.Entity) (*structpb.Struct, error) { + entityBytes, err := protojson.Marshal(ident) + if err != nil { + return nil, err + } + var entityStruct structpb.Struct + err = entityStruct.UnmarshalJSON(entityBytes) + if err != nil { + return nil, err + } + return &entityStruct, nil +} diff --git a/service/entityresolution/claims/claims_entity_resolution_test.go b/service/entityresolution/claims/claims_entity_resolution_test.go new file mode 100644 index 0000000000..3354d8f332 --- /dev/null +++ b/service/entityresolution/claims/claims_entity_resolution_test.go @@ -0,0 +1,125 @@ +package entityresolution_test + +import ( + "context" + "testing" + + "github.com/opentdf/platform/protocol/go/authorization" + "github.com/opentdf/platform/protocol/go/entityresolution" + claims "github.com/opentdf/platform/service/entityresolution/claims" + "github.com/opentdf/platform/service/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" +) + +const samplejwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6ImhlbGxvd29ybGQiLCJpYXQiOjE1MTYyMzkwMjJ9.EAOittOMzKENEAs44eaMuZe-xas7VNVsgBxhwmxYiIw" + +func Test_ClientResolveEntity(t *testing.T) { + var validBody []*authorization.Entity + validBody = append(validBody, &authorization.Entity{Id: "1234", EntityType: &authorization.Entity_ClientId{ClientId: "random"}}) + + var ctxb = context.Background() + + var req = entityresolution.ResolveEntitiesRequest{} + req.Entities = validBody + + var resp, reserr = claims.EntityResolution(ctxb, &req, logger.CreateTestLogger()) + + require.NoError(t, reserr) + + var entityRepresentations = resp.GetEntityRepresentations() + assert.NotNil(t, entityRepresentations) + assert.Len(t, entityRepresentations, 1) + + assert.Equal(t, "1234", entityRepresentations[0].GetOriginalId()) + assert.Len(t, entityRepresentations[0].GetAdditionalProps(), 1) + var propMap = entityRepresentations[0].GetAdditionalProps()[0].AsMap() + assert.Equal(t, "random", propMap["clientId"]) + assert.Equal(t, "1234", propMap["id"]) +} + +func Test_EmailResolveEntity(t *testing.T) { + var validBody []*authorization.Entity + validBody = append(validBody, &authorization.Entity{Id: "1234", EntityType: &authorization.Entity_EmailAddress{EmailAddress: "random"}}) + + var ctxb = context.Background() + + var req = entityresolution.ResolveEntitiesRequest{} + req.Entities = validBody + + var resp, reserr = claims.EntityResolution(ctxb, &req, logger.CreateTestLogger()) + + require.NoError(t, reserr) + + var entityRepresentations = resp.GetEntityRepresentations() + assert.NotNil(t, entityRepresentations) + assert.Len(t, entityRepresentations, 1) + + assert.Equal(t, "1234", entityRepresentations[0].GetOriginalId()) + assert.Len(t, entityRepresentations[0].GetAdditionalProps(), 1) + var propMap = entityRepresentations[0].GetAdditionalProps()[0].AsMap() + assert.Equal(t, "random", propMap["emailAddress"]) + assert.Equal(t, "1234", propMap["id"]) +} + +func Test_ClaimsResolveEntity(t *testing.T) { + customclaims := map[string]interface{}{ + "foo": "bar", + "baz": 42, + } + // Convert map[string]interface{} to *structpb.Struct + structClaims, err := structpb.NewStruct(customclaims) + require.NoError(t, err) + + // Wrap the struct in an *anypb.Any + anyClaims, err := anypb.New(structClaims) + require.NoError(t, err) + + var validBody []*authorization.Entity + validBody = append(validBody, &authorization.Entity{Id: "1234", EntityType: &authorization.Entity_Claims{Claims: anyClaims}}) + + var ctxb = context.Background() + + var req = entityresolution.ResolveEntitiesRequest{} + req.Entities = validBody + + var resp, reserr = claims.EntityResolution(ctxb, &req, logger.CreateTestLogger()) + + require.NoError(t, reserr) + + var entityRepresentations = resp.GetEntityRepresentations() + assert.NotNil(t, entityRepresentations) + assert.Len(t, entityRepresentations, 1) + + assert.Equal(t, "1234", entityRepresentations[0].GetOriginalId()) + assert.Len(t, entityRepresentations[0].GetAdditionalProps(), 1) + var propMap = entityRepresentations[0].GetAdditionalProps()[0].AsMap() + assert.Equal(t, "bar", propMap["foo"]) + assert.EqualValues(t, 42, propMap["baz"]) +} + +func Test_JWTToEntityChainClaims(t *testing.T) { + var ctxb = context.Background() + + validBody := []*authorization.Token{{Jwt: samplejwt}} + + var resp, reserr = claims.CreateEntityChainFromJwt(ctxb, &entityresolution.CreateEntityChainFromJwtRequest{Tokens: validBody}, logger.CreateTestLogger()) + + require.NoError(t, reserr) + + assert.Len(t, resp.GetEntityChains(), 1) + assert.Len(t, resp.GetEntityChains()[0].GetEntities(), 1) + assert.IsType(t, &authorization.Entity_Claims{}, resp.GetEntityChains()[0].GetEntities()[0].GetEntityType()) + assert.Equal(t, authorization.Entity_CATEGORY_SUBJECT, resp.GetEntityChains()[0].GetEntities()[0].GetCategory()) + + var unpackedStruct structpb.Struct + err := resp.GetEntityChains()[0].GetEntities()[0].GetClaims().UnmarshalTo(&unpackedStruct) + require.NoError(t, err) + + // Convert structpb.Struct to map[string]interface{} + claimsMap := unpackedStruct.AsMap() + + assert.Equal(t, "helloworld", claimsMap["name"]) +} diff --git a/service/entityresolution/entityresolution.go b/service/entityresolution/entityresolution.go index a539e57f63..9fd6c81e28 100644 --- a/service/entityresolution/entityresolution.go +++ b/service/entityresolution/entityresolution.go @@ -1,48 +1,36 @@ package entityresolution import ( - "context" - - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/mitchellh/mapstructure" "github.com/opentdf/platform/protocol/go/entityresolution" + claims "github.com/opentdf/platform/service/entityresolution/claims" keycloak "github.com/opentdf/platform/service/entityresolution/keycloak" - "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/serviceregistry" ) -type EntityResolutionService struct { //nolint:revive // allow for simple naming - entityresolution.UnimplementedEntityResolutionServiceServer - idpConfig keycloak.KeycloakConfig - logger *logger.Logger +type ERSConfig struct { + Mode string `mapstructure:"mode" json:"mode"` } +const KeycloakMode = "keycloak" +const ClaimsMode = "claims" + func NewRegistration() serviceregistry.Registration { return serviceregistry.Registration{ Namespace: "entityresolution", ServiceDesc: &entityresolution.EntityResolutionService_ServiceDesc, RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - var inputIdpConfig keycloak.KeycloakConfig + var inputConfig ERSConfig - if err := mapstructure.Decode(srp.Config, &inputIdpConfig); err != nil { + if err := mapstructure.Decode(srp.Config, &inputConfig); err != nil { panic(err) } - - srp.Logger.Debug("entity_resolution configuration", "config", inputIdpConfig) - - return &EntityResolutionService{idpConfig: inputIdpConfig, logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, server.(entityresolution.EntityResolutionServiceServer)) //nolint:forcetypeassert // allow type assert, following other services + if inputConfig.Mode == ClaimsMode { + return claims.RegisterClaimsERS(srp.Config, srp.Logger) } + + // Default to keyclaok ERS + return keycloak.RegisterKeycloakERS(srp.Config, srp.Logger) }, } } - -func (s EntityResolutionService) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { - resp, err := keycloak.EntityResolution(ctx, req, s.idpConfig, s.logger) - return &resp, err -} - -func (s EntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { - resp, err := keycloak.CreateEntityChainFromJwt(ctx, req, s.idpConfig, s.logger) - return &resp, err -} diff --git a/service/entityresolution/keycloak/keycloak_entity_resolution.go b/service/entityresolution/keycloak/keycloak_entity_resolution.go index c282e3c7e2..a2a9b484cd 100644 --- a/service/entityresolution/keycloak/keycloak_entity_resolution.go +++ b/service/entityresolution/keycloak/keycloak_entity_resolution.go @@ -9,11 +9,14 @@ import ( "strings" "github.com/Nerzal/gocloak/v13" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/mitchellh/mapstructure" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/entityresolution" auth "github.com/opentdf/platform/service/authorization" "github.com/opentdf/platform/service/logger" + "github.com/opentdf/platform/service/pkg/serviceregistry" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -33,6 +36,12 @@ const ( const serviceAccountUsernamePrefix = "service-account-" +type KeycloakEntityResolutionService struct { + entityresolution.UnimplementedEntityResolutionServiceServer + idpConfig KeycloakConfig + logger *logger.Logger +} + type KeycloakConfig struct { URL string `mapstructure:"url" json:"url"` Realm string `mapstructure:"realm" json:"realm"` @@ -43,6 +52,29 @@ type KeycloakConfig struct { InferID InferredIdentityConfig `mapstructure:"inferid,omitempty" json:"inferid,omitempty"` } +func RegisterKeycloakERS(config serviceregistry.ServiceConfig, logger *logger.Logger) (any, serviceregistry.HandlerServer) { + var inputIdpConfig KeycloakConfig + if err := mapstructure.Decode(config, &inputIdpConfig); err != nil { + panic(err) + } + logger.Debug("entity_resolution configuration", "config", inputIdpConfig) + + return &KeycloakEntityResolutionService{idpConfig: inputIdpConfig, logger: logger}, + func(ctx context.Context, mux *runtime.ServeMux, server any) error { + return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, server.(entityresolution.EntityResolutionServiceServer)) //nolint:forcetypeassert // allow type assert, following other services + } +} + +func (s KeycloakEntityResolutionService) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { + resp, err := EntityResolution(ctx, req, s.idpConfig, s.logger) + return &resp, err +} + +func (s KeycloakEntityResolutionService) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + resp, err := CreateEntityChainFromJwt(ctx, req, s.idpConfig, s.logger) + return &resp, err +} + func (c KeycloakConfig) LogValue() slog.Value { return slog.GroupValue( slog.String("url", c.URL),