From 2f48369afcf94db8c0f3de1e37b070c0018a53e6 Mon Sep 17 00:00:00 2001 From: Ryan Schumacher Date: Mon, 29 Dec 2025 17:48:06 -0600 Subject: [PATCH] feat(authz): add multi group claims support - support multi group claims - refactor for efficiency - refactor for future casbin adapter support --- service/internal/auth/authn.go | 34 +- service/internal/auth/casbin.go | 283 ++++----- service/internal/auth/casbin_csv.go | 117 ++++ service/internal/auth/casbin_test.go | 662 +++++++++++++++------- service/internal/auth/config.go | 38 +- service/internal/auth/dotnotation.go | 22 - service/internal/auth/dotnotation_test.go | 28 - service/pkg/config/config.go | 4 +- service/pkg/util/dotnotation.go | 37 ++ service/pkg/util/dotnotation_test.go | 38 ++ 10 files changed, 838 insertions(+), 425 deletions(-) create mode 100644 service/internal/auth/casbin_csv.go delete mode 100644 service/internal/auth/dotnotation.go delete mode 100644 service/internal/auth/dotnotation_test.go create mode 100644 service/pkg/util/dotnotation.go create mode 100644 service/pkg/util/dotnotation_test.go diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index e4cd0bdf47..3cf58b3728 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -25,9 +25,9 @@ import ( "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/logger/audit" - "google.golang.org/grpc/metadata" - ctxAuth "github.com/opentdf/platform/service/pkg/auth" + "github.com/opentdf/platform/service/pkg/util" + "google.golang.org/grpc/metadata" ) var ( @@ -277,20 +277,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { default: action = ActionUnsafe } - if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil { - if err.Error() == "permission denied" { - log.WarnContext( - ctx, - "permission denied", - slog.String("azp", accessTok.Subject()), - slog.Any("error", err), - ) - http.Error(w, "permission denied", http.StatusForbidden) - return - } - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } else if !allow { + if !a.enforcer.Enforce(accessTok, nil, r.URL.Path, action) { log.WarnContext( ctx, "permission denied", @@ -367,18 +354,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor } // Check if the token is allowed to access the resource - if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil { - if err.Error() == "permission denied" { - log.WarnContext( - ctxWithJWK, - "permission denied", - slog.String("azp", token.Subject()), - slog.Any("error", err), - ) - return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) - } - return nil, err - } else if !allowed { + if !a.enforcer.Enforce(token, nil, resource, action) { log.WarnContext(ctxWithJWK, "permission denied", slog.String("azp", token.Subject())) return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) } @@ -786,7 +762,7 @@ func (a *Authentication) getClientIDFromToken(ctx context.Context, tok jwt.Token if err != nil { return "", fmt.Errorf("failed to parse token as a map and find claim at [%s]: %w", clientIDClaim, err) } - found := dotNotation(claimsMap, clientIDClaim) + found := util.Dotnotation(claimsMap, clientIDClaim) if found == nil { return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotFound, clientIDClaim) } diff --git a/service/internal/auth/casbin.go b/service/internal/auth/casbin.go index ac9a40f598..e8cc9b56b3 100644 --- a/service/internal/auth/casbin.go +++ b/service/internal/auth/casbin.go @@ -1,7 +1,7 @@ package auth import ( - "errors" + "encoding/json" "fmt" "log/slog" "strings" @@ -11,6 +11,7 @@ import ( stringadapter "github.com/casbin/casbin/v2/persist/string-adapter" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/service/logger" + "github.com/opentdf/platform/service/pkg/util" _ "embed" ) @@ -26,14 +27,15 @@ var builtinPolicy string //go:embed casbin_model.conf var defaultModel string +// Enforcer is the custom Casbin enforcer with additional functionality type Enforcer struct { *casbin.Enforcer - Config CasbinConfig - Policy string - logger *logger.Logger - - isDefaultPolicy bool - isDefaultModel bool + Config CasbinConfig + Policy string // CSV policy string (empty if using non-CSV adapter) + logger *logger.Logger + isDefaultPolicy bool + isDefaultModel bool + groupClaimSelectors [][]string // precomputed selectors for GroupsClaim } type casbinSubject []string @@ -42,7 +44,8 @@ type CasbinConfig struct { PolicyConfig } -// newCasbinEnforcer creates a new casbin enforcer +// NewCasbinEnforcer creates a new Casbin enforcer with the provided configuration and logger. +// It sets up the Casbin model, policy, and adapter, and returns an Enforcer instance. func NewCasbinEnforcer(c CasbinConfig, logger *logger.Logger) (*Enforcer, error) { // Set Casbin config defaults if not provided isDefaultModel := false @@ -51,47 +54,23 @@ func NewCasbinEnforcer(c CasbinConfig, logger *logger.Logger) (*Enforcer, error) isDefaultModel = true } - isDefaultPolicy := false - if c.Csv == "" { - // Set the Builtin Policy if provided - if c.Builtin != "" { - c.Csv = c.Builtin - } else { - c.Csv = builtinPolicy - } - isDefaultPolicy = true - } - - if c.RoleMap != nil { - for k, v := range c.RoleMap { - c.Csv = strings.Join([]string{ - c.Csv, - strings.Join([]string{"g", v, "role:" + k}, ", "), - }, "\n") - } + // Precompute group claim selectors for efficiency (adapter-agnostic) + groupClaimSelectors := make([][]string, len(c.GroupsClaim)) + for i, claim := range c.GroupsClaim { + groupClaimSelectors[i] = strings.Split(claim, ".") } + // Track whether we're using the CSV string adapter (vs custom adapter like SQL) + usingCSVAdapter := c.Adapter == nil + isDefaultPolicy := false isPolicyExtended := false - if c.Extension != "" { - c.Csv = strings.Join([]string{c.Csv, c.Extension}, "\n") - isPolicyExtended = true - } + csvPolicy := "" - // Because we provided built in group mappings we need to add them - // if extensions and rolemap are not provided - if c.RoleMap == nil && c.Extension == "" { - c.Csv = strings.Join([]string{ - c.Csv, - "g, opentdf-admin, role:admin", - "g, opentdf-standard, role:standard", - }, "\n") - } - - isDefaultAdapter := false - // If adapter is not provided, use the default string adapter - if c.Adapter == nil { - isDefaultAdapter = true - c.Adapter = stringadapter.NewAdapter(c.Csv) + // CSV policy building - only when using the default string adapter + // When a custom adapter (e.g., SQL) is provided, skip CSV-specific logic + if usingCSVAdapter { + csvPolicy, isDefaultPolicy, isPolicyExtended = buildCSVPolicy(c) + c.Adapter = stringadapter.NewAdapter(csvPolicy) } logger.Debug("creating casbin enforcer", @@ -99,7 +78,7 @@ func NewCasbinEnforcer(c CasbinConfig, logger *logger.Logger) (*Enforcer, error) slog.Bool("isDefaultModel", isDefaultModel), slog.Bool("isBuiltinPolicy", isDefaultPolicy), slog.Bool("isPolicyExtended", isPolicyExtended), - slog.Bool("isDefaultAdapter", isDefaultAdapter), + slog.Bool("usingCSVAdapter", usingCSVAdapter), ) m, err := casbinModel.NewModelFromString(c.Model) @@ -112,130 +91,164 @@ func NewCasbinEnforcer(c CasbinConfig, logger *logger.Logger) (*Enforcer, error) return nil, fmt.Errorf("failed to create casbin enforcer: %w", err) } + // Explicitly load the policy from the adapter + if err := e.LoadPolicy(); err != nil { + return nil, fmt.Errorf("failed to load casbin policy: %w", err) + } + + // CSV validation - only for CSV/string adapters + // Skip validation for custom adapters (e.g., SQL) which have their own validation + if usingCSVAdapter { + if err := validateCSVPolicy(csvPolicy); err != nil { + return nil, err + } + } + return &Enforcer{ - Enforcer: e, - Config: c, - Policy: c.Csv, - isDefaultPolicy: isDefaultPolicy, - isDefaultModel: isDefaultModel, - logger: logger, + Enforcer: e, + Config: c, + Policy: csvPolicy, // Empty string if using non-CSV adapter + isDefaultPolicy: isDefaultPolicy, + isDefaultModel: isDefaultModel, + logger: logger, + groupClaimSelectors: groupClaimSelectors, }, nil } -// casbinEnforce is a helper function to enforce the policy with casbin -// TODO implement a common type so this can be used for both http and grpc -func (e *Enforcer) Enforce(token jwt.Token, resource, action string) (bool, error) { - // extract the role claim from the token - s := e.buildSubjectFromToken(token) - s = append(s, rolePrefix+defaultRole) +// Enforce checks if the given token and userInfo are allowed to perform the action on the resource. +// It extracts roles from both the token and userInfo, then checks against the Casbin policy. +func (e *Enforcer) Enforce(token jwt.Token, userInfo []byte, resource, action string) bool { + // Fail-safe: deny if resource or action is empty + if resource == "" || action == "" { + e.logger.Debug("permission denied: empty resource or action", slog.String("resource", resource), slog.String("action", action)) + return false + } + + // extract the role claim from the token and userInfo + s := e.buildSubjectFromTokenAndUserInfo(token, userInfo) + + // Assign the default role if no roles are found + if len(s) == 0 { + s = append(s, rolePrefix+defaultRole) + } for _, info := range s { allowed, err := e.Enforcer.Enforce(info, resource, action) if err != nil { - e.logger.Error("enforce by role error", - slog.String("subject_info", info), - slog.String("action", action), - slog.String("resource", resource), - slog.Any("error", err), - ) + e.logger.Error("enforce by role error", slog.String("subject info", info), slog.String("resource", resource), slog.String("action", action), slog.String("error", err.Error())) } if allowed { - e.logger.Debug("allowed by policy", - slog.String("subject_info", info), - slog.String("action", action), - slog.String("resource", resource), - ) - return true, nil + e.logger.Debug("allowed by policy", slog.String("subject info", info), slog.String("resource", resource), slog.String("action", action)) + return true } } - e.logger.Debug("permission denied by policy", - slog.Any("subject_info", s), - slog.String("action", action), - slog.String("resource", resource), - ) - return false, errors.New("permission denied") + e.logger.Debug("permission denied by policy", slog.Any("subject.info", s), slog.String("resource", resource), slog.String("action", action)) + return false } -func (e *Enforcer) buildSubjectFromToken(t jwt.Token) casbinSubject { +// buildSubjectFromTokenAndUserInfo combines roles from both token and userInfo +func (e *Enforcer) buildSubjectFromTokenAndUserInfo(t jwt.Token, userInfo []byte) casbinSubject { var subject string info := casbinSubject{} - e.logger.Debug("building subject from token") + e.logger.Debug("building subject from token and userInfo", slog.Any("token", t), slog.Any("userInfo", userInfo)) roles := e.extractRolesFromToken(t) + roles = append(roles, e.extractRolesFromUserInfo(userInfo)...) + + for _, r := range roles { + if r != "" { + info = append(info, r) + } + } if claim, found := t.Get(e.Config.UserNameClaim); found { sub, ok := claim.(string) subject = sub if !ok { - e.logger.Warn("username claim not of type string", - slog.String("claim", e.Config.UserNameClaim), - slog.Any("claims", claim), - ) + e.logger.Warn("username claim not of type string", slog.String("claim", e.Config.UserNameClaim), slog.Any("claims", claim)) subject = "" } } - info = append(info, roles...) - info = append(info, subject) + if subject != "" { + info = append(info, subject) + } + e.logger.Debug("built subject info", slog.Any("info", info)) return info } -func (e *Enforcer) extractRolesFromToken(t jwt.Token) []string { - e.logger.Debug("extracting roles from token") - roles := []string{} - - roleClaim := e.Config.GroupsClaim - // roleMap := e.Config.RoleMap - - selectors := strings.Split(roleClaim, ".") - claim, exists := t.Get(selectors[0]) - if !exists { - e.logger.Warn("claim not found", - slog.String("claim", roleClaim), - slog.Any("claims", claim), - ) - return nil - } - e.logger.Debug("root claim found", - slog.String("claim", roleClaim), - slog.Any("claims", claim), - ) - // use dotnotation if the claim is nested - if len(selectors) > 1 { - claimMap, ok := claim.(map[string]interface{}) - if !ok { - e.logger.Warn("claim is not of type map[string]interface{}", - slog.String("claim", roleClaim), - slog.Any("claims", claim), - ) - return nil +const ( + // defaultRolesCapacity is the default capacity for roles slice in extractRolesFromToken + defaultRolesCapacity = 4 +) + +// extractRolesFromToken extracts roles from a jwt.Token based on the configured claim path +func (e *Enforcer) extractRolesFromToken(token jwt.Token) []string { + roles := make([]string, 0, defaultRolesCapacity) // preallocate for common case + for _, selectors := range e.groupClaimSelectors { + if len(selectors) == 0 { + continue } - claim = dotNotation(claimMap, strings.Join(selectors[1:], ".")) - if claim == nil { - e.logger.Warn("claim not found", - slog.String("claim", roleClaim), - slog.Any("claims", claim), - ) - return nil + claim, exists := token.Get(selectors[0]) + if !exists { + continue // skip missing claim, don't log on hot path + } + if len(selectors) > 1 { + claimMap, ok := claim.(map[string]interface{}) + if !ok { + continue // skip invalid type + } + claim = util.Dotnotation(claimMap, strings.Join(selectors[1:], ".")) + if claim == nil { + continue + } + } + // Inline extractRolesFromClaim for efficiency + switch v := claim.(type) { + case string: + roles = append(roles, v) + case []interface{}: + for _, rr := range v { + if r, ok := rr.(string); ok { + roles = append(roles, r) + } + } + case []string: + roles = append(roles, v...) } } + return roles +} - // check the type of the role claim - switch v := claim.(type) { - case string: - roles = append(roles, v) - case []interface{}: - for _, rr := range v { - if r, ok := rr.(string); ok { - roles = append(roles, r) +// extractRolesFromUserInfo extracts roles from a userInfo JSON ([]byte) based on the configured claim path +func (e *Enforcer) extractRolesFromUserInfo(userInfo []byte) []string { + roles := make([]string, 0, defaultRolesCapacity) + if userInfo == nil || len(userInfo) == 0 { + return roles + } + var userInfoMap map[string]interface{} + if err := json.Unmarshal(userInfo, &userInfoMap); err != nil { + return roles // skip logging on hot path + } + for _, selectors := range e.groupClaimSelectors { + if len(selectors) == 0 { + continue + } + claim := util.Dotnotation(userInfoMap, strings.Join(selectors, ".")) + if claim == nil { + continue + } + switch v := claim.(type) { + case string: + roles = append(roles, v) + case []interface{}: + for _, rr := range v { + if r, ok := rr.(string); ok { + roles = append(roles, r) + } } + case []string: + roles = append(roles, v...) } - default: - e.logger.Warn("could not get claim type", - slog.String("selector", roleClaim), - slog.Any("claims", claim), - ) - return nil } - return roles } diff --git a/service/internal/auth/casbin_csv.go b/service/internal/auth/casbin_csv.go new file mode 100644 index 0000000000..16fa4e49e6 --- /dev/null +++ b/service/internal/auth/casbin_csv.go @@ -0,0 +1,117 @@ +package auth + +import ( + "fmt" + "strings" +) + +// csvPolicyBuilder handles building CSV policy strings from configuration. +// This is separated to make it easier to support other adapters (e.g., SQL) in the future. +type csvPolicyBuilder struct { + basePolicy string + lines []string +} + +// newCSVPolicyBuilder creates a new CSV policy builder with the base policy. +func newCSVPolicyBuilder(basePolicy string) *csvPolicyBuilder { + return &csvPolicyBuilder{ + basePolicy: basePolicy, + lines: []string{basePolicy}, + } +} + +// addRoleMapping adds a group-to-role mapping line (g, user, role). +func (b *csvPolicyBuilder) addRoleMapping(user, role string) { + b.lines = append(b.lines, fmt.Sprintf("g, %s, role:%s", user, role)) +} + +// addExtension appends extension policy lines. +func (b *csvPolicyBuilder) addExtension(extension string) { + if extension != "" { + b.lines = append(b.lines, extension) + } +} + +// build returns the complete CSV policy string. +func (b *csvPolicyBuilder) build() string { + return strings.Join(b.lines, "\n") +} + +// validateCSVPolicy validates a CSV policy string for correct format. +// This validation is specific to CSV/string adapters and should be skipped for other adapters (e.g., SQL). +func validateCSVPolicy(csv string) error { + policyLines := strings.Split(csv, "\n") + for i, line := range policyLines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue // skip empty/comment lines + } + fields := strings.Split(line, ",") + for j := range fields { + fields[j] = strings.TrimSpace(fields[j]) + } + switch fields[0] { + case "p": + // Policy line: expect at least 5 fields: p, sub, obj, act, eft + const expectedFields = 5 + if len(fields) < expectedFields { + return fmt.Errorf("malformed casbin policy line %d: %q (expected at least 5 fields)", i+1, line) + } + + sub, obj, act, eft := fields[1], fields[2], fields[3], fields[4] + if sub == "" || obj == "" || act == "" { + return fmt.Errorf("malformed casbin policy line %d: %q (resource and action fields must not be empty)", i+1, line) + } + if eft != "allow" && eft != "deny" { + return fmt.Errorf("malformed casbin policy line %d: %q (effect must be 'allow' or 'deny')", i+1, line) + } + case "g": + const expectedFields = 3 + // Grouping line: expect at least 3 fields: g, user, role + if len(fields) < expectedFields { + return fmt.Errorf("malformed casbin grouping line %d: %q (expected at least 3 fields)", i+1, line) + } + default: + // Unknown line type, fail-safe: error + return fmt.Errorf("malformed casbin policy line %d: %q (unknown line type, must start with 'p' or 'g')", i+1, line) + } + } + return nil +} + +// buildCSVPolicy constructs the CSV policy string from configuration. +// Returns the policy string, whether it's the default policy, and whether it was extended. +func buildCSVPolicy(c CasbinConfig) (policy string, isDefault bool, isExtended bool) { + // Determine base policy + basePolicy := c.Csv + isDefault = false + if basePolicy == "" { + if c.Builtin != "" { + basePolicy = c.Builtin + } else { + basePolicy = builtinPolicy + } + isDefault = true + } + + builder := newCSVPolicyBuilder(basePolicy) + + // Add role mappings from RoleMap + for role, user := range c.RoleMap { + builder.addRoleMapping(user, role) + } + + // Add extension policy + if c.Extension != "" { + builder.addExtension(c.Extension) + isExtended = true + } + + // Add default group mappings if no RoleMap or Extension provided + if c.RoleMap == nil && c.Extension == "" { + builder.addRoleMapping("opentdf-admin", "admin") + builder.addRoleMapping("opentdf-standard", "standard") + } + + return builder.build(), isDefault, isExtended +} diff --git a/service/internal/auth/casbin_test.go b/service/internal/auth/casbin_test.go index a67f45fb0b..cb4c19551b 100644 --- a/service/internal/auth/casbin_test.go +++ b/service/internal/auth/casbin_test.go @@ -80,9 +80,7 @@ func (s *AuthnCasbinSuite) Test_NewEnforcerWithCustomModel() { }) s.Require().NoError(err) - allowed, err := enforcer.Enforce(tok, "", "") - s.Require().NoError(err) - s.True(allowed) + s.True(enforcer.Enforce(tok, nil, "res", "act")) } func (s *AuthnCasbinSuite) Test_NewEnforcerWithBadCustomModel() { @@ -243,31 +241,29 @@ func (s *AuthnCasbinSuite) Test_Enforcement() { slog.Info("running test w/ default claim", slog.String("name", name)) enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) s.Require().NoError(err, name) - tok := s.newTokWithDefaultClaim(test.roles[0], test.roles[1], "", "") - allowed, err := enforcer.Enforce(tok, test.resource, test.action) - if !test.allowed { - s.Require().Error(err, name) + tok := s.newTokWithDefaultClaim(test.roles[0], test.roles[1], "") + allowed := enforcer.Enforce(tok, nil, test.resource, test.action) + if test.allowed { + s.True(allowed, name) } else { - s.Require().NoError(err, name) + s.False(allowed, name) } - s.Equal(test.allowed, allowed, name) slog.Info("running test w/ custom claim", slog.String("name", name)) - policyCfg.GroupsClaim = "test.test_roles.roles" + policyCfg.GroupsClaim = GroupsClaimList{"test.test_roles.roles"} enforcer, err = NewCasbinEnforcer(CasbinConfig{ PolicyConfig: policyCfg, }, logger.CreateTestLogger()) s.Require().NoError(err, name) _, tok = s.newTokenWithCustomClaim(test.roles[0], test.roles[1]) - allowed, err = enforcer.Enforce(tok, test.resource, test.action) - if !test.allowed { - s.Require().Error(err, name) + allowed = enforcer.Enforce(tok, nil, test.resource, test.action) + if test.allowed { + s.True(allowed, name) } else { - s.Require().NoError(err, name) + s.False(allowed, name) } - s.Equal(test.allowed, allowed, name) slog.Info("running test w/ custom rolemap", slog.String("name", name)) @@ -275,20 +271,19 @@ func (s *AuthnCasbinSuite) Test_Enforcement() { "admin": "test-admin", "standard": "test-standard", } - policyCfg.GroupsClaim = "realm_access.roles" + policyCfg.GroupsClaim = GroupsClaimList{"realm_access.roles"} enforcer, err = NewCasbinEnforcer(CasbinConfig{ PolicyConfig: policyCfg, }, logger.CreateTestLogger()) s.Require().NoError(err, name) _, tok = s.newTokenWithCustomRoleMap(test.roles[0], test.roles[1]) - allowed, err = enforcer.Enforce(tok, test.resource, test.action) - if !test.allowed { - s.Require().Error(err, name) + allowed = enforcer.Enforce(tok, nil, test.resource, test.action) + if test.allowed { + s.True(allowed, name) } else { - s.Require().NoError(err, name) + s.False(allowed, name) } - s.Equal(test.allowed, allowed) slog.Info("running test w/ client_id", slog.String("name", name)) roleMap := make(map[string]string) @@ -307,13 +302,12 @@ func (s *AuthnCasbinSuite) Test_Enforcement() { }, logger.CreateTestLogger()) s.Require().NoError(err, name) _, tok = s.newTokenWithCilentID() - allowed, err = enforcer.Enforce(tok, test.resource, test.action) - if !test.allowed { - s.Require().Error(err, name) + allowed = enforcer.Enforce(tok, nil, test.resource, test.action) + if test.allowed { + s.True(allowed, name) } else { - s.Require().NoError(err, name) + s.False(allowed, name) } - s.Equal(test.allowed, allowed, name) } } @@ -331,93 +325,90 @@ func (s *AuthnCasbinSuite) Test_ExtendDefaultPolicies() { enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) s.Require().NoError(err) // other roles denied new policy: admin - tok := s.newTokWithDefaultClaim(true, false, "", "") - allowed, err := enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().NoError(err) + tok := s.newTokWithDefaultClaim(true, false, "") + allowed := enforcer.Enforce(tok, nil, "new.service.DoSomething", "read") s.True(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") - s.Require().NoError(err) + allowed = enforcer.Enforce(tok, nil, "new.service.DoSomething", "write") s.True(allowed) // other roles denied new policy: standard - tok = s.newTokWithDefaultClaim(false, true, "", "") - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().NoError(err) + tok = s.newTokWithDefaultClaim(false, true, "") + allowed = enforcer.Enforce(tok, nil, "new.service.DoSomething", "read") s.True(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") - s.Require().Error(err) + allowed = enforcer.Enforce(tok, nil, "new.service.DoSomething", "write") s.False(allowed) } func (s *AuthnCasbinSuite) Test_ExtendDefaultPolicies_MalformedErrors() { - policyCfg := PolicyConfig{} - err := defaults.Set(&policyCfg) - s.Require().NoError(err) - - enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) - s.Require().NoError(err) - tok := s.newTokWithDefaultClaim(true, false, "", "") - allowed, err := enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) - - // missing 'p' - policyCfg.Extension = strings.Join([]string{ - "g, opentdf-admin, role:admin", - "g, opentdf-standard, role:standard", - "role:admin, new.service.DoSomething, *", - }, "\n") - enforcer, err = NewCasbinEnforcer(CasbinConfig{ - PolicyConfig: policyCfg, - }, logger.CreateTestLogger()) - s.Require().NoError(err) - tok = s.newTokWithDefaultClaim(true, false, "", "") - allowed, err = enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) - - // missing effect - policyCfg.Extension = strings.Join([]string{ - "g, opentdf-admin, role:admin", - "g, opentdf-standard, role:standard", - "p, role:admin, new.service.DoSomething, *", - }, "\n") - enforcer, err = NewCasbinEnforcer(CasbinConfig{ - PolicyConfig: policyCfg, - }, logger.CreateTestLogger()) - s.Require().NoError(err) - tok = s.newTokWithDefaultClaim(true, false, "", "") - allowed, err = enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) - - // empty - policyCfg.Extension = strings.Join([]string{ - "", - }, "\n") - enforcer, err = NewCasbinEnforcer(CasbinConfig{ - PolicyConfig: policyCfg, - }, logger.CreateTestLogger()) - s.Require().NoError(err) - tok = s.newTokWithDefaultClaim(true, false, "", "") - allowed, err = enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) + testCases := []struct { + name string + extension string + expectErr bool + allowed bool // expected result from enforce + }{ + { + name: "admin no extension, empty resource or action", + extension: "", + expectErr: false, + allowed: true, + }, + { + name: "missing 'p' in policy line", + extension: strings.Join([]string{ + "g, opentdf-admin, role:admin", + "g, opentdf-standard, role:standard", + "role:admin, new.service.DoSomething, *", + }, "\n"), + expectErr: true, // now expect error due to malformed policy line + allowed: false, // fail-safe: should deny + }, + { + name: "missing effect", + extension: strings.Join([]string{ + "g, opentdf-admin, role:admin", + "g, opentdf-standard, role:standard", + "p, role:admin, new.service.DoSomething, *", + }, "\n"), + expectErr: true, // now expect error due to malformed policy line + allowed: false, // fail-safe: should deny + }, + { + name: "missing role prefix", + extension: strings.Join([]string{ + "g, opentdf-admin, admin", + "g, opentdf-standard, standard", + "p, admin, new.service.DoSomething, *", + }, "\n"), + expectErr: true, // now expect error due to malformed policy line + allowed: false, + }, + } - // missing role prefix - policyCfg.Extension = strings.Join([]string{ - "g, opentdf-admin, role:admin", - "g, opentdf-standard, role:standard", - "p, admin, new.service.DoSomething, *", - }, "\n") - enforcer, err = NewCasbinEnforcer(CasbinConfig{ - PolicyConfig: policyCfg, - }, logger.CreateTestLogger()) - s.Require().NoError(err) - tok = s.newTokWithDefaultClaim(true, false, "", "") - allowed, err = enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) + for _, tc := range testCases { + s.Run(tc.name, func() { + policyCfg := PolicyConfig{} + err := defaults.Set(&policyCfg) + s.Require().NoError(err) + policyCfg.Extension = tc.extension + enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) + if tc.expectErr { + s.Require().Error(err) + s.Nil(enforcer) + return + } + + s.Require().NoError(err) + s.NotNil(enforcer) + + tok := s.newTokWithDefaultClaim(true, false, "") + allowed := enforcer.Enforce(tok, nil, "policy.attributes.DoSomething", "read") + if tc.allowed { + s.True(allowed) + } else { + s.False(allowed) + } + }) + } } func (s *AuthnCasbinSuite) Test_SetBuiltinPolicy() { @@ -433,118 +424,378 @@ func (s *AuthnCasbinSuite) Test_SetBuiltinPolicy() { "g, opentdf-standard, role:standard", }, "\n") - enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) - s.Require().NoError(err) - - // unauthorized role - tok := s.newTokWithDefaultClaim(false, false, "", "") - allowed, err := enforcer.Enforce(tok, "new.hello.World", "read") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.hello.World", "write") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") - s.Require().Error(err) - s.False(allowed) - - // other roles denied new policy: admin - tok = s.newTokWithDefaultClaim(true, false, "", "") - allowed, err = enforcer.Enforce(tok, "new.hello.World", "read") - s.Require().NoError(err) - s.True(allowed) - allowed, err = enforcer.Enforce(tok, "new.hello.World", "write") - s.Require().NoError(err) - s.True(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") - s.Require().Error(err) - s.False(allowed) - - // other roles denied new policy: standard - tok = s.newTokWithDefaultClaim(false, true, "", "") - allowed, err = enforcer.Enforce(tok, "new.hello.World", "read") - s.Require().NoError(err) - s.True(allowed) - allowed, err = enforcer.Enforce(tok, "new.hello.World", "write") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().Error(err) - s.False(allowed) - allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") - s.Require().Error(err) - s.False(allowed) -} - -func (s *AuthnCasbinSuite) Test_Username_Policy() { - policyCfg := PolicyConfig{} - err := defaults.Set(&policyCfg) - s.Require().NoError(err) - - policyCfg.Extension = strings.Join([]string{ - "p, casbin-user, new.service.*, read, allow", - }, "\n") + testCases := []struct { + name string + admin bool + standard bool + resource string + action string + allowed bool + }{ + { + name: "unauthorized role cannot read new.hello.World", + admin: false, + standard: false, + resource: "new.hello.World", + action: "read", + allowed: false, + }, + { + name: "unauthorized role cannot write new.hello.World", + admin: false, + standard: false, + resource: "new.hello.World", + action: "write", + allowed: false, + }, + { + name: "unauthorized role cannot read new.service.DoSomething", + admin: false, + standard: false, + resource: "new.service.DoSomething", + action: "read", + allowed: false, + }, + { + name: "unauthorized role cannot write new.service.DoSomething", + admin: false, + standard: false, + resource: "new.service.DoSomething", + action: "write", + allowed: false, + }, + { + name: "admin can read new.hello.World", + admin: true, + standard: false, + resource: "new.hello.World", + action: "read", + allowed: true, + }, + { + name: "admin can write new.hello.World", + admin: true, + standard: false, + resource: "new.hello.World", + action: "write", + allowed: true, + }, + { + name: "admin cannot read new.service.DoSomething", + admin: true, + standard: false, + resource: "new.service.DoSomething", + action: "read", + allowed: false, + }, + { + name: "admin cannot write new.service.DoSomething", + admin: true, + standard: false, + resource: "new.service.DoSomething", + action: "write", + allowed: false, + }, + { + name: "standard can read new.hello.World", + admin: false, + standard: true, + resource: "new.hello.World", + action: "read", + allowed: true, + }, + { + name: "standard cannot write new.hello.World", + admin: false, + standard: true, + resource: "new.hello.World", + action: "write", + allowed: false, + }, + { + name: "standard cannot read new.service.DoSomething", + admin: false, + standard: true, + resource: "new.service.DoSomething", + action: "read", + allowed: false, + }, + { + name: "standard cannot write new.service.DoSomething", + admin: false, + standard: true, + resource: "new.service.DoSomething", + action: "write", + allowed: false, + }, + } enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) s.Require().NoError(err) - tok := s.newTokWithDefaultClaim(true, false, "preferred_username", "") - allowed, err := enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) - - allowed, err = enforcer.Enforce(tok, "policy.attributes.List", "read") - s.Require().Error(err) - s.False(allowed) + for _, tc := range testCases { + s.Run(tc.name, func() { + tok := s.newTokWithDefaultClaim(tc.admin, tc.standard, "") + allowed := enforcer.Enforce(tok, nil, tc.resource, tc.action) + if tc.allowed { + s.True(allowed, tc.name) + } else { + s.False(allowed, tc.name) + } + }) + } } -func (s *AuthnCasbinSuite) Test_Override_Of_Username_Claim() { - policyCfg := PolicyConfig{} - err := defaults.Set(&policyCfg) - s.Require().NoError(err) +func (s *AuthnCasbinSuite) Test_Username_Claim_Enforcement() { + tests := []struct { + name string + usernameClaim string + resource string + action string + shouldAllow bool + setClaim bool // whether to set the username claim in the token + }{ + { + name: "Allow with correct username claim (override)", + usernameClaim: "username", + resource: "new.service.DoSomething", + action: "read", + shouldAllow: true, + setClaim: true, + }, + { + name: "Deny with incorrect resource (override)", + usernameClaim: "username", + resource: "policy.attributes.List", + action: "read", + shouldAllow: false, + setClaim: true, + }, + { + name: "Allow with correct username claim (default)", + usernameClaim: "preferred_username", + resource: "new.service.DoSomething", + action: "read", + shouldAllow: true, + setClaim: true, + }, + { + name: "Deny with incorrect resource (default)", + usernameClaim: "preferred_username", + resource: "policy.attributes.List", + action: "read", + shouldAllow: false, + setClaim: true, + }, + { + name: "Deny when username claim not set in token", + usernameClaim: "username", + resource: "new.service.DoSomething", + action: "read", + shouldAllow: false, + setClaim: false, + }, + } - policyCfg.UserNameClaim = "username" - policyCfg.Extension = strings.Join([]string{ - "p, casbin-user, new.service.*, read, allow", - }, "\n") + for _, tc := range tests { + policyCfg := PolicyConfig{} + err := defaults.Set(&policyCfg) + s.Require().NoError(err, tc.name) - enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) - s.Require().NoError(err) + policyCfg.UserNameClaim = tc.usernameClaim + policyCfg.Extension = strings.Join([]string{ + "p, casbin-user, new.service.*, read, allow", + }, "\n") - tok := s.newTokWithDefaultClaim(true, false, "username", "") - allowed, err := enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().NoError(err) - s.True(allowed) + enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) + s.Require().NoError(err, tc.name) - allowed, err = enforcer.Enforce(tok, "policy.attributes.List", "read") - s.Require().Error(err) - s.False(allowed) + var tok jwt.Token + if tc.setClaim { + tok = s.newTokWithDefaultClaim(true, false, tc.usernameClaim) + } else { + tok = s.newTokWithDefaultClaim(true, false, "") + } + + allowed := enforcer.Enforce(tok, nil, tc.resource, tc.action) + if tc.shouldAllow && allowed { + s.True(allowed, tc.name) + } else { + s.False(allowed, tc.name) + } + } } -func (s *AuthnCasbinSuite) Test_Override_Of_Groups_Claim() { +func (s *AuthnCasbinSuite) Test_Casbin_Claims_Matrix() { + type scenario struct { + name string + groupsClaim GroupsClaimList + tokenClaims map[string]interface{} + userInfo []byte + shouldAllow bool + description string + } + policyCfg := PolicyConfig{} err := defaults.Set(&policyCfg) s.Require().NoError(err) + policyCfg.Extension = "p, role:admin, resource, read, allow" - policyCfg.GroupsClaim = "realm_access.groups" - - enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: policyCfg}, logger.CreateTestLogger()) - s.Require().NoError(err) - - tok := s.newTokWithDefaultClaim(false, true, "", "groups") - allowed, err := enforcer.Enforce(tok, "new.service.DoSomething", "read") - s.Require().Error(err) - s.False(allowed) + testMatrix := []scenario{ + { + name: "One claim supported (token)", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: true, + description: "Token with one supported claim should allow", + }, + { + name: "Multiple claims supported (token)", + groupsClaim: GroupsClaimList{"realm_access.roles", "custom.roles"}, + tokenClaims: map[string]interface{}{"custom": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: true, + description: "Token with any supported claim should allow", + }, + { + name: "No claims in token or userInfo", + groupsClaim: GroupsClaimList{"realm_access.roles", "custom.roles"}, + tokenClaims: map[string]interface{}{}, + shouldAllow: false, + description: "No claims present should deny", + }, + { + name: "Access token contains claim", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: true, + description: "Access token contains claim should allow", + }, + { + name: "User info contains claim", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: []byte(`{"realm_access": {"roles": ["admin"]}}`), + shouldAllow: true, + description: "User info contains claim should allow", + }, + { + name: "Both token and userInfo have matching claim", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}}, + userInfo: []byte(`{"realm_access": {"roles": ["admin"]}}`), + shouldAllow: true, + description: "Should allow if either token or userInfo matches", + }, + { + name: "Token has non-matching, userInfo has matching claim", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"other"}}}, + userInfo: []byte(`{"realm_access": {"roles": ["admin"]}}`), + shouldAllow: true, + description: "Should allow if userInfo matches even if token does not", + }, + { + name: "Both token and userInfo have non-matching claims", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"other"}}}, + userInfo: []byte(`{"realm_access": {"roles": ["other2"]}}`), + shouldAllow: false, + description: "Should deny if neither token nor userInfo matches", + }, + { + name: "Malformed userInfo JSON", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: []byte(`not-a-json`), + shouldAllow: false, + description: "Should deny and not panic on malformed userInfo JSON", + }, + { + name: "GroupsClaim with nested path that doesn't exist", + groupsClaim: GroupsClaimList{"nonexistent.path.roles"}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: false, + description: "Should deny if claim path does not exist", + }, + { + name: "GroupsClaim as empty list", + groupsClaim: GroupsClaimList{}, + tokenClaims: map[string]interface{}{"realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: false, + description: "Should deny if no groups claim configured", + }, + { + name: "GroupsClaim with multiple, only one matches", + groupsClaim: GroupsClaimList{"realm_access.roles", "custom.roles"}, + tokenClaims: map[string]interface{}{"custom": map[string]interface{}{"roles": []interface{}{"admin"}}}, + shouldAllow: true, + description: "Should allow if any claim in GroupsClaim matches", + }, + { + name: "GroupsClaim with multiple, all match", + groupsClaim: GroupsClaimList{"realm_access.roles", "custom.roles"}, + tokenClaims: map[string]interface{}{ + "realm_access": map[string]interface{}{"roles": []interface{}{"admin"}}, + "custom": map[string]interface{}{"roles": []interface{}{"admin"}}, + }, + shouldAllow: true, + description: "Should allow if all claims in GroupsClaim match", + }, + { + name: "UserInfo present but empty", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: []byte(`{}`), + shouldAllow: false, + description: "Should deny if userInfo is present but empty", + }, + { + name: "Token and UserInfo both empty", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: nil, + shouldAllow: false, + description: "Should deny if both token and userInfo are empty", + }, + { + name: "UserInfo nil or empty length", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: nil, + shouldAllow: false, + description: "Should deny if userInfo is nil", + }, + { + name: "UserInfo empty slice", + groupsClaim: GroupsClaimList{"realm_access.roles"}, + tokenClaims: map[string]interface{}{}, + userInfo: []byte{}, + shouldAllow: false, + description: "Should deny if userInfo is empty slice", + }, + } - allowed, err = enforcer.Enforce(tok, "policy.attributes.List", "read") - s.Require().NoError(err) - s.True(allowed) + for _, tc := range testMatrix { + s.Run(tc.name, func() { + cfg := policyCfg + cfg.GroupsClaim = tc.groupsClaim + enforcer, err := NewCasbinEnforcer(CasbinConfig{PolicyConfig: cfg}, logger.CreateTestLogger()) + s.Require().NoError(err) + + tok := jwt.New() + for k, v := range tc.tokenClaims { + if err := tok.Set(k, v); err != nil { + s.Fail("Failed to set token claim", err) + } + } + allowed := enforcer.Enforce(tok, tc.userInfo, "resource", "read") + if tc.shouldAllow && allowed { + s.True(allowed, tc.description) + } else { + s.False(allowed, tc.description) + } + }) + } } func (s *AuthnCasbinSuite) buildTokenRoles(admin bool, standard bool, roleMaps []string) []interface{} { @@ -557,26 +808,23 @@ func (s *AuthnCasbinSuite) buildTokenRoles(admin bool, standard bool, roleMaps [ standardRole = roleMaps[1] } - i := 0 - roles := make([]interface{}, 2) + roles := make([]interface{}, 0, 2) if admin { - roles[i] = adminRole - i++ + roles = append(roles, adminRole) } if standard { - roles[i] = standardRole + roles = append(roles, standardRole) } return roles } -func (s *AuthnCasbinSuite) newTokWithDefaultClaim(admin bool, standard bool, usernameClaimName, groupClaimName string) jwt.Token { +func (s *AuthnCasbinSuite) newTokWithDefaultClaim(admin bool, standard bool, usernameClaimName string) jwt.Token { tok := jwt.New() - if groupClaimName == "" { - groupClaimName = "roles" - } + // Always using "roles" as the group claim name + groupClaimName := "roles" tokenRoles := s.buildTokenRoles(admin, standard, nil) if err := tok.Set("realm_access", map[string]interface{}{groupClaimName: tokenRoles}); err != nil { diff --git a/service/internal/auth/config.go b/service/internal/auth/config.go index 5e48877cf1..ed09bab8a6 100644 --- a/service/internal/auth/config.go +++ b/service/internal/auth/config.go @@ -1,6 +1,7 @@ package auth import ( + "encoding/json" "errors" "time" @@ -28,15 +29,46 @@ type AuthNConfig struct { //nolint:revive // AuthNConfig is a valid name TokenSkew time.Duration `mapstructure:"skew" json:"skew" default:"1m"` } +// GroupsClaimList is a custom type to support unmarshalling from string or []string +// for backward compatibility in config files. +type GroupsClaimList []string + +func (g *GroupsClaimList) UnmarshalJSON(data []byte) error { + var single string + if err := json.Unmarshal(data, &single); err == nil { + *g = GroupsClaimList{single} + return nil + } + var multi []string + if err := json.Unmarshal(data, &multi); err == nil { + *g = GroupsClaimList(multi) + return nil + } + return errors.New("invalid groups_claim: must be string or array of strings") +} + +func (g *GroupsClaimList) UnmarshalText(text []byte) error { + s := string(text) + // Try parsing as JSON array first (e.g., '["claim1","claim2"]' from env var) + var multi []string + if err := json.Unmarshal([]byte(s), &multi); err == nil { + *g = GroupsClaimList(multi) + return nil + } + // Fallback: treat as single string value + *g = GroupsClaimList{s} + return nil +} + type PolicyConfig struct { Builtin string `mapstructure:"-" json:"-"` // Username claim to use for user information UserNameClaim string `mapstructure:"username_claim" json:"username_claim" default:"preferred_username"` - // Claim to use for group/role information - GroupsClaim string `mapstructure:"groups_claim" json:"groups_claim" default:"realm_access.roles"` + // Claims to use for group/role information (supports multiple claims) + GroupsClaim GroupsClaimList `mapstructure:"groups_claim" json:"groups_claim" default:"[\"realm_access.roles\"]"` // Claim to use to reference idP clientID ClientIDClaim string `mapstructure:"client_id_claim" json:"client_id_claim" default:"azp"` - // Deprecated: Use GroupClain instead + // Deprecated: Use GroupsClaim instead RoleClaim string `mapstructure:"claim" json:"claim" default:"realm_access.roles"` // Deprecated: Use Casbin grouping statements g, , RoleMap map[string]string `mapstructure:"map" json:"map"` diff --git a/service/internal/auth/dotnotation.go b/service/internal/auth/dotnotation.go deleted file mode 100644 index 0b5923c6f6..0000000000 --- a/service/internal/auth/dotnotation.go +++ /dev/null @@ -1,22 +0,0 @@ -package auth - -import "strings" - -// dotNotation retrieves a value from a nested map using dot notation keys. -func dotNotation(m map[string]any, key string) any { - keys := strings.Split(key, ".") - for i, k := range keys { - if i == len(keys)-1 { - return m[k] - } - if m[k] == nil { - return nil - } - var ok bool - m, ok = m[k].(map[string]any) - if !ok { - return nil - } - } - return nil -} diff --git a/service/internal/auth/dotnotation_test.go b/service/internal/auth/dotnotation_test.go deleted file mode 100644 index a40ca21eb0..0000000000 --- a/service/internal/auth/dotnotation_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package auth - -import ( - "testing" -) - -func TestDotNotation(t *testing.T) { - tests := []struct { - name string - input map[string]any - key string - expected any - }{ - {name: "valid key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.b", expected: 1}, - {name: "non-existent key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.c", expected: nil}, - {name: "nested map", input: map[string]any{"a": map[string]any{"b": map[string]any{"c": 2}}}, key: "a.b.c", expected: 2}, - {name: "invalid key type", input: map[string]any{"a": 1}, key: "a.b", expected: nil}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := dotNotation(tt.input, tt.key) - if result != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, result) - } - }) - } -} diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index b3bcdff64d..3b8d371edf 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/go-playground/validator/v10" + "github.com/go-viper/mapstructure/v2" "github.com/opentdf/platform/service/internal/server" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" @@ -275,7 +276,8 @@ func (c *Config) Reload(ctx context.Context) error { // Unmarshal the merged configuration into the main config struct `c` // so it's available for the next iteration of the dependency loop. - if err := orderedViper.Unmarshal(c); err != nil { + // TextUnmarshallerHookFunc enables custom types with UnmarshalText to decode from strings. + if err := orderedViper.Unmarshal(c, viper.DecodeHook(mapstructure.TextUnmarshallerHookFunc())); err != nil { return errors.Join(err, ErrUnmarshallingConfig) } diff --git a/service/pkg/util/dotnotation.go b/service/pkg/util/dotnotation.go new file mode 100644 index 0000000000..737dda4cd6 --- /dev/null +++ b/service/pkg/util/dotnotation.go @@ -0,0 +1,37 @@ +package util + +import "strings" + +// Dotnotation retrieves a value from a nested map using dot notation keys. +// Returns nil for empty keys, malformed paths (leading/trailing/double dots), +// or if the path doesn't exist in the map. +func Dotnotation(m map[string]interface{}, key string) interface{} { + if key == "" { + return nil + } + keys := strings.Split(key, ".") + // Filter out empty segments from leading/trailing/double dots + filtered := keys[:0] + for _, k := range keys { + if k != "" { + filtered = append(filtered, k) + } + } + if len(filtered) == 0 { + return nil + } + for i, k := range filtered { + if i == len(filtered)-1 { + return m[k] + } + if m[k] == nil { + return nil + } + var ok bool + m, ok = m[k].(map[string]interface{}) + if !ok { + return nil + } + } + return nil +} diff --git a/service/pkg/util/dotnotation_test.go b/service/pkg/util/dotnotation_test.go new file mode 100644 index 0000000000..fa0b4855fa --- /dev/null +++ b/service/pkg/util/dotnotation_test.go @@ -0,0 +1,38 @@ +package util + +import ( + "testing" +) + +func TestDotnotation(t *testing.T) { + tests := []struct { + name string + input map[string]any + key string + expected any + }{ + // Basic cases + {name: "valid key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.b", expected: 1}, + {name: "non-existent key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.c", expected: nil}, + {name: "nested map", input: map[string]any{"a": map[string]any{"b": map[string]any{"c": 2}}}, key: "a.b.c", expected: 2}, + {name: "invalid key type", input: map[string]any{"a": 1}, key: "a.b", expected: nil}, + {name: "top level key", input: map[string]any{"a": "value"}, key: "a", expected: "value"}, + {name: "nil map value", input: map[string]any{"a": nil}, key: "a.b", expected: nil}, + // Edge cases for malformed keys + {name: "empty key", input: map[string]any{"a": 1}, key: "", expected: nil}, + {name: "trailing dot", input: map[string]any{"a": 1}, key: "a.", expected: 1}, + {name: "leading dot", input: map[string]any{"a": 1}, key: ".a", expected: 1}, + {name: "double dot", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a..b", expected: 1}, + {name: "only dots", input: map[string]any{"a": 1}, key: "...", expected: nil}, + {name: "whitespace key", input: map[string]any{" ": 1}, key: " ", expected: 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Dotnotation(tt.input, tt.key) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +}