Skip to content

Commit

Permalink
add custom claims to user groups
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 committed Oct 28, 2024
1 parent d5ceb94 commit cb27e8c
Show file tree
Hide file tree
Showing 18 changed files with 225 additions and 54 deletions.
10 changes: 10 additions & 0 deletions backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,13 @@ func (e *ReservedClaimError) Error() string {
}

func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest }

type DuplicateClaimError struct {
Key string
}

func (e *DuplicateClaimError) Error() string {
return fmt.Sprintf("Claim %s is already defined", e.Key)
}

func (e *DuplicateClaimError) HttpStatusCode() int { return http.StatusBadRequest }
31 changes: 28 additions & 3 deletions backend/internal/controller/custom_claim_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) {
wkc := &CustomClaimController{customClaimService: customClaimService}
group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler)
group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.updateUserCustomClaimsHandler)
group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler)
group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler)
}

type CustomClaimController struct {
Expand All @@ -28,7 +29,7 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
c.JSON(http.StatusOK, claims)
}

func (ccc *CustomClaimController) updateUserCustomClaimsHandler(c *gin.Context) {
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto

if err := c.ShouldBindJSON(&input); err != nil {
Expand All @@ -37,7 +38,31 @@ func (ccc *CustomClaimController) updateUserCustomClaimsHandler(c *gin.Context)
}

userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateUserCustomClaims(userId, input)
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
if err != nil {
c.Error(err)
return
}

var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err)
return
}

c.JSON(http.StatusOK, customClaimsDto)
}

func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto

if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

userId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input)
if err != nil {
c.Error(err)
return
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/controller/test_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
}

if err := tc.TestService.ResetAppConfig(); err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

Expand Down
22 changes: 12 additions & 10 deletions backend/internal/dto/user_group_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@ package dto
import "time"

type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
Users []UserDto `json:"users"`
CreatedAt time.Time `json:"createdAt"`
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
Users []UserDto `json:"users"`
CreatedAt time.Time `json:"createdAt"`
}

type UserGroupDtoWithUserCount struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
UserCount int64 `json:"userCount"`
CreatedAt time.Time `json:"createdAt"`
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
UserCount int64 `json:"userCount"`
CreatedAt time.Time `json:"createdAt"`
}

type UserGroupCreateDto struct {
Expand Down
3 changes: 2 additions & 1 deletion backend/internal/model/custom_claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ type CustomClaim struct {
Key string
Value string

UserID string
UserID *string
UserGroupID *string
}
1 change: 1 addition & 0 deletions backend/internal/model/user_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ type UserGroup struct {
FriendlyName string
Name string `gorm:"unique"`
Users []User `gorm:"many2many:user_groups_users;"`
CustomClaims []CustomClaim
}
102 changes: 94 additions & 8 deletions backend/internal/service/custom_claim_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,37 @@ func isReservedClaim(key string) bool {
return ok
}

func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// idType is the type of the id used to identify the user or user group
type idType string

const (
UserID idType = "user_id"
UserGroupID idType = "user_group_id"
)

// UpdateCustomClaimsForUser updates the custom claims for a user
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserID, userID, claims)
}

// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
}

// updateCustomClaims updates the custom claims for a user or user group
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// Check for duplicate keys in the claims slice
seenKeys := make(map[string]bool)
for _, claim := range claims {
if seenKeys[claim.Key] {
return nil, &common.DuplicateClaimError{Key: claim.Key}
}
seenKeys[claim.Key] = true
}

var existingClaims []model.CustomClaim
err := s.db.Where("user_id = ?", userID).Find(&existingClaims).Error
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
if err != nil {
return nil, err
}
Expand All @@ -72,20 +100,27 @@ func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto.
if isReservedClaim(claim.Key) {
return nil, &common.ReservedClaimError{Key: claim.Key}
}
customClaim := model.CustomClaim{
Key: claim.Key,
Value: claim.Value,
}

if idType == UserID {
customClaim.UserID = &value
} else if idType == UserGroupID {
customClaim.UserGroupID = &value
}

// Update the claim if it already exists or create a new one
err = s.db.Where("user_id = ? AND key = ?", userID, claim.Key).Assign(&model.CustomClaim{
Key: claim.Key,
Value: claim.Value,
UserID: userID,
}).FirstOrCreate(&model.CustomClaim{}).Error
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
if err != nil {
return nil, err
}
}

// Get the updated claims
var updatedClaims []model.CustomClaim
err = s.db.Where("user_id = ?", userID).Find(&updatedClaims).Error
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
if err != nil {
return nil, err
}
Expand All @@ -99,6 +134,57 @@ func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.Cust
return customClaims, err
}

func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
return customClaims, err
}

// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
// prioritizing the user's claims over user group claims with the same key.
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
// Get the custom claims of the user
customClaims, err := s.GetCustomClaimsForUser(userID)
if err != nil {
return nil, err
}

// Store user's claims in a map to prioritize and prevent duplicates
claimsMap := make(map[string]model.CustomClaim)
for _, claim := range customClaims {
claimsMap[claim.Key] = claim
}

// Get all user groups of the user
var userGroupsOfUser []model.UserGroup
err = s.db.Preload("CustomClaims").
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
Where("user_groups_users.user_id = ?", userID).
Find(&userGroupsOfUser).Error
if err != nil {
return nil, err
}

// Add only non-duplicate custom claims from user groups
for _, userGroup := range userGroupsOfUser {
for _, groupClaim := range userGroup.CustomClaims {
// Only add claim if it does not exist in the user's claims
if _, exists := claimsMap[groupClaim.Key]; !exists {
claimsMap[groupClaim.Key] = groupClaim
}
}
}

// Convert the claimsMap back to a slice
finalClaims := make([]model.CustomClaim, 0, len(claimsMap))
for _, claim := range claimsMap {
finalClaims = append(finalClaims, claim)
}

return finalClaims, nil
}

// GetSuggestions returns a list of custom claim keys that have been used before
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
var customClaimsKeys []string

Expand Down
2 changes: 1 addition & 1 deletion backend/internal/service/oidc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
}

// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUser(userID)
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions backend/internal/service/user_group_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func NewUserGroupService(db *gorm.DB) *UserGroupService {
}

func (s *UserGroupService) List(name string, page int, pageSize int) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Model(&model.UserGroup{})
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})

if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%")
Expand All @@ -29,7 +29,7 @@ func (s *UserGroupService) List(name string, page int, pageSize int) (groups []m
}

func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("Users").First(&group).Error
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
return group, err
}

Expand Down
12 changes: 0 additions & 12 deletions backend/migrations/20241024064959_custom_claims.up.sql

This file was deleted.

15 changes: 15 additions & 0 deletions backend/migrations/20241028064959_custom_claims.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
CREATE TABLE custom_claims
(
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
key TEXT NOT NULL,
value TEXT NOT NULL,

user_id TEXT,
user_group_id TEXT,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE,

CONSTRAINT custom_claims_unique UNIQUE (key, user_id, user_group_id),
CHECK (user_id IS NOT NULL OR user_group_id IS NOT NULL)
);
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,43 @@
let {
value = $bindable(''),
placeholder,
suggestionLimit = 5,
suggestions
}: { value: string; placeholder: string; suggestions: string[] } = $props();
}: {
value: string;
placeholder: string;
suggestionLimit?: number;
suggestions: string[];
} = $props();
let suggestionResult: string[] = $state(suggestions);
let filteredSuggestions: string[] = $state(suggestions.slice(0, suggestionLimit));
let selectedIndex = $state(-1);
let isInputFocused = $state(false);
function handleSuggestionClick(suggestion: (typeof suggestions)[0]) {
value = suggestion;
suggestionResult = [];
filteredSuggestions = [];
}
function handleOnInput() {
suggestionResult = suggestions.filter((s) => s.includes(value.toLowerCase()));
filteredSuggestions = suggestions
.filter((s) => s.includes(value.toLowerCase()))
.slice(0, suggestionLimit);
}
function handleKeydown(e: KeyboardEvent) {
if (!isOpen) return;
switch (e.key) {
case 'ArrowDown':
selectedIndex = Math.min(selectedIndex + 1, suggestionResult.length - 1);
selectedIndex = Math.min(selectedIndex + 1, filteredSuggestions.length - 1);
break;
case 'ArrowUp':
selectedIndex = Math.max(selectedIndex - 1, -1);
break;
case 'Enter':
if (selectedIndex >= 0) {
handleSuggestionClick(suggestionResult[selectedIndex]);
handleSuggestionClick(filteredSuggestions[selectedIndex]);
}
break;
case 'Escape':
Expand All @@ -42,18 +50,18 @@
}
}
let isOpen = $derived(suggestionResult.length > 0 && isInputFocused);
let isOpen = $derived(filteredSuggestions.length > 0 && isInputFocused);
$effect(() => {
// Reset selection when suggestions change
if (suggestionResult) {
if (filteredSuggestions) {
selectedIndex = -1;
}
});
</script>

<div
class="w-full grid"
class="grid w-full"
role="combobox"
onkeydown={handleKeydown}
aria-controls="suggestion-list"
Expand All @@ -74,9 +82,9 @@
closeOnOutsideClick={false}
closeOnEscape={false}
>
<Popover.Trigger tabindex={-1} class="w-full h-0" aria-hidden />
<Popover.Trigger tabindex={-1} class="h-0 w-full" aria-hidden />
<Popover.Content class="p-0" sideOffset={5} sameWidth>
{#each suggestionResult as suggestion, index}
{#each filteredSuggestions as suggestion, index}
<div
role="button"
tabindex="0"
Expand Down
Loading

0 comments on commit cb27e8c

Please sign in to comment.