From c056089c6043a825aaaaecf0c57454892a108f1d Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Mon, 28 Oct 2024 18:11:54 +0100 Subject: [PATCH] feat: custom claims (#53) --- .../internal/bootstrap/router_bootstrap.go | 5 +- backend/internal/common/errors.go | 159 ++++++++++++-- .../controller/app_config_controller.go | 27 +-- .../controller/audit_log_controller.go | 5 +- .../controller/custom_claim_controller.go | 78 +++++++ .../internal/controller/oidc_controller.go | 73 +++---- .../internal/controller/test_controller.go | 9 +- .../internal/controller/user_controller.go | 57 ++--- .../controller/user_group_controller.go | 41 ++-- .../controller/webauthn_controller.go | 38 ++-- .../controller/well_known_controller.go | 3 +- backend/internal/dto/custom_claim_dto.go | 8 + backend/internal/dto/user_dto.go | 13 +- backend/internal/dto/user_group_dto.go | 22 +- .../error_handler.go} | 74 ++++--- .../internal/middleware/file_size_limit.go | 6 +- backend/internal/middleware/jwt_auth.go | 10 +- backend/internal/middleware/rate_limit.go | 5 +- backend/internal/model/custom_claim.go | 11 + backend/internal/model/user.go | 5 +- backend/internal/model/user_group.go | 1 + .../internal/service/app_config_service.go | 2 +- .../internal/service/custom_claim_service.go | 197 ++++++++++++++++++ backend/internal/service/oidc_service.go | 47 +++-- .../internal/service/user_group_service.go | 8 +- backend/internal/service/user_service.go | 12 +- .../20241028064959_custom_claims.down.sql | 1 + .../20241028064959_custom_claims.up.sql | 15 ++ frontend/package-lock.json | 4 +- .../lib/components/auto-complete-input.svelte | 105 ++++++++++ .../lib/components/custom-claims-input.svelte | 75 +++++++ frontend/src/lib/components/form-input.svelte | 10 +- .../src/lib/components/ui/popover/index.ts | 17 ++ .../ui/popover/popover-content.svelte | 22 ++ .../src/lib/services/custom-claim-service.ts | 19 ++ frontend/src/lib/types/custom-claim.type.ts | 4 + frontend/src/lib/types/user-group.type.ts | 2 + frontend/src/lib/types/user.type.ts | 5 +- .../admin/user-groups/[id]/+page.svelte | 31 ++- .../settings/admin/users/[id]/+page.svelte | 32 ++- frontend/tests/account-settings.spec.ts | 4 +- frontend/tests/user-group.spec.ts | 36 ++++ frontend/tests/user-settings.spec.ts | 50 ++++- 43 files changed, 1069 insertions(+), 279 deletions(-) create mode 100644 backend/internal/controller/custom_claim_controller.go create mode 100644 backend/internal/dto/custom_claim_dto.go rename backend/internal/{utils/controller_error_util.go => middleware/error_handler.go} (51%) create mode 100644 backend/internal/model/custom_claim.go create mode 100644 backend/internal/service/custom_claim_service.go create mode 100644 backend/migrations/20241028064959_custom_claims.down.sql create mode 100644 backend/migrations/20241028064959_custom_claims.up.sql create mode 100644 frontend/src/lib/components/auto-complete-input.svelte create mode 100644 frontend/src/lib/components/custom-claims-input.svelte create mode 100644 frontend/src/lib/components/ui/popover/index.ts create mode 100644 frontend/src/lib/components/ui/popover/popover-content.svelte create mode 100644 frontend/src/lib/services/custom-claim-service.ts create mode 100644 frontend/src/lib/types/custom-claim.type.ts diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 57d1fda..1a07a39 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -39,11 +39,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { jwtService := service.NewJwtService(appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) userService := service.NewUserService(db, jwtService) - oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService) + customClaimService := service.NewCustomClaimService(db) + oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) testService := service.NewTestService(db, appConfigService) userGroupService := service.NewUserGroupService(db) r.Use(middleware.NewCorsMiddleware().Add()) + r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60)) r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false)) @@ -59,6 +61,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService) controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware) controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService) + controller.NewCustomClaimController(apiGroup, jwtAuthMiddleware, customClaimService) // Add test controller in non-production environments if common.EnvConfig.AppEnv != "production" { diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index d05ac13..0056ce7 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -1,19 +1,146 @@ package common -import "errors" - -var ( - ErrUsernameTaken = errors.New("username is already taken") - ErrEmailTaken = errors.New("email is already taken") - ErrSetupAlreadyCompleted = errors.New("setup already completed") - ErrTokenInvalidOrExpired = errors.New("token is invalid or expired") - ErrOidcMissingAuthorization = errors.New("missing authorization") - ErrOidcGrantTypeNotSupported = errors.New("grant type not supported") - ErrOidcMissingClientCredentials = errors.New("client id or secret not provided") - ErrOidcClientSecretInvalid = errors.New("invalid client secret") - ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code") - ErrOidcInvalidCallbackURL = errors.New("invalid callback URL") - ErrFileTypeNotSupported = errors.New("file type not supported") - ErrInvalidCredentials = errors.New("no user found with provided credentials") - ErrNameAlreadyInUse = errors.New("name is already in use") +import ( + "fmt" + "net/http" ) + +type AppError interface { + error + HttpStatusCode() int +} + +// Custom error types for various conditions + +type AlreadyInUseError struct { + Property string +} + +func (e *AlreadyInUseError) Error() string { + return fmt.Sprintf("%s is already in use", e.Property) +} +func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } + +type SetupAlreadyCompletedError struct{} + +func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" } +func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 } + +type TokenInvalidOrExpiredError struct{} + +func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" } +func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 } + +type OidcMissingAuthorizationError struct{} + +func (e *OidcMissingAuthorizationError) Error() string { return "missing authorization" } +func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.StatusForbidden } + +type OidcGrantTypeNotSupportedError struct{} + +func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" } +func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 } + +type OidcMissingClientCredentialsError struct{} + +func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" } +func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 } + +type OidcClientSecretInvalidError struct{} + +func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" } +func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 } + +type OidcInvalidAuthorizationCodeError struct{} + +func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" } +func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 } + +type OidcInvalidCallbackURLError struct{} + +func (e *OidcInvalidCallbackURLError) Error() string { return "invalid callback URL" } +func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 } + +type FileTypeNotSupportedError struct{} + +func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" } +func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 } + +type InvalidCredentialsError struct{} + +func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" } +func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 } + +type FileTooLargeError struct { + MaxSize string +} + +func (e *FileTooLargeError) Error() string { + return fmt.Sprintf("The file can't be larger than %s", e.MaxSize) +} +func (e *FileTooLargeError) HttpStatusCode() int { return http.StatusRequestEntityTooLarge } + +type NotSignedInError struct{} + +func (e *NotSignedInError) Error() string { return "You are not signed in" } +func (e *NotSignedInError) HttpStatusCode() int { return http.StatusUnauthorized } + +type MissingPermissionError struct{} + +func (e *MissingPermissionError) Error() string { + return "You don't have permission to perform this action" +} +func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbidden } + +type TooManyRequestsError struct{} + +func (e *TooManyRequestsError) Error() string { + return "Too many requests. Please wait a while before trying again." +} +func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests } + +type ClientIdOrSecretNotProvidedError struct{} + +func (e *ClientIdOrSecretNotProvidedError) Error() string { + return "Client id and secret not provided" +} + +func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest } + +type WrongFileTypeError struct { + ExpectedFileType string +} + +func (e *WrongFileTypeError) Error() string { + return fmt.Sprintf("File must be of type %s", e.ExpectedFileType) +} + +func (e *WrongFileTypeError) HttpStatusCode() int { return http.StatusBadRequest } + +type MissingSessionIdError struct{} + +func (e *MissingSessionIdError) Error() string { + return "Missing session id" +} + +func (e *MissingSessionIdError) HttpStatusCode() int { return http.StatusBadRequest } + +type ReservedClaimError struct { + Key string +} + +func (e *ReservedClaimError) Error() string { + return fmt.Sprintf("Claim %s is reserved and can't be used", e.Key) +} + +func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest } + +type DuplicateClaimError struct { + Key string +} + +func (e *DuplicateClaimError) Error() string { + return fmt.Sprintf("Claim %s is already defined", e.Key) +} + +func (e *DuplicateClaimError) HttpStatusCode() int { return http.StatusBadRequest } diff --git a/backend/internal/controller/app_config_controller.go b/backend/internal/controller/app_config_controller.go index c6ee106..5bed631 100644 --- a/backend/internal/controller/app_config_controller.go +++ b/backend/internal/controller/app_config_controller.go @@ -1,7 +1,6 @@ package controller import ( - "errors" "fmt" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" @@ -39,13 +38,13 @@ type AppConfigController struct { func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListAppConfig(false) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.PublicAppConfigVariableDto if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -55,13 +54,13 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListAppConfig(true) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.AppConfigVariableDto if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -71,19 +70,19 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) { var input dto.AppConfigUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.AppConfigVariableDto if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -136,13 +135,13 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } fileType := utils.GetFileExtension(file.Filename) if fileType != "ico" { - utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico") + c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"}) return } acc.updateImage(c, "favicon", "ico") @@ -164,17 +163,13 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) if err != nil { - if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } diff --git a/backend/internal/controller/audit_log_controller.go b/backend/internal/controller/audit_log_controller.go index e6e1616..a032645 100644 --- a/backend/internal/controller/audit_log_controller.go +++ b/backend/internal/controller/audit_log_controller.go @@ -8,7 +8,6 @@ import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" ) func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, jwtAuthMiddleware *middleware.JwtAuthMiddleware) { @@ -31,7 +30,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { // Fetch audit logs for the user logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -39,7 +38,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { var logsDtos []dto.AuditLogDto err = dto.MapStructList(logs, &logsDtos) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/custom_claim_controller.go b/backend/internal/controller/custom_claim_controller.go new file mode 100644 index 0000000..ca28261 --- /dev/null +++ b/backend/internal/controller/custom_claim_controller.go @@ -0,0 +1,78 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/dto" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/service" + "net/http" +) + +func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) { + wkc := &CustomClaimController{customClaimService: customClaimService} + group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler) + group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler) + group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler) +} + +type CustomClaimController struct { + customClaimService *service.CustomClaimService +} + +func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { + claims, err := ccc.customClaimService.GetSuggestions() + if err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, claims) +} + +func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) { + var input []dto.CustomClaimCreateDto + + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + userId := c.Param("userId") + claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input) + if err != nil { + c.Error(err) + return + } + + var customClaimsDto []dto.CustomClaimDto + if err := dto.MapStructList(claims, &customClaimsDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, customClaimsDto) +} + +func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) { + var input []dto.CustomClaimCreateDto + + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + userId := c.Param("userGroupId") + claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input) + if err != nil { + c.Error(err) + return + } + + var customClaimsDto []dto.CustomClaimDto + if err := dto.MapStructList(claims, &customClaimsDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, customClaimsDto) +} diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 8982c8a..49934cb 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -1,13 +1,11 @@ package controller import ( - "errors" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "strconv" "strings" @@ -42,19 +40,13 @@ type OidcController struct { func (oc *OidcController) authorizeHandler(c *gin.Context) { var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrOidcMissingAuthorization) { - utils.CustomControllerError(c, http.StatusForbidden, err.Error()) - } else if errors.Is(err, common.ErrOidcInvalidCallbackURL) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -69,17 +61,13 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrOidcInvalidCallbackURL) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -95,7 +83,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { var input dto.OidcIdTokenDto if err := c.ShouldBind(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -107,21 +95,14 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { var ok bool clientID, clientSecret, ok = c.Request.BasicAuth() if !ok { - utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided") + c.Error(&common.ClientIdOrSecretNotProvidedError{}) return } } idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret) if err != nil { - if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || - errors.Is(err, common.ErrOidcMissingClientCredentials) || - errors.Is(err, common.ErrOidcClientSecretInvalid) || - errors.Is(err, common.ErrOidcInvalidAuthorizationCode) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -132,14 +113,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { token := strings.Split(c.GetHeader("Authorization"), " ")[1] jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token) if err != nil { - utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) + c.Error(err) return } userID := jwtClaims.Subject clientId := jwtClaims.Audience[0] claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -150,7 +131,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") client, err := oc.oidcService.GetClient(clientId) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -171,7 +152,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { } } - utils.ControllerError(c, err) + c.Error(err) } func (oc *OidcController) listClientsHandler(c *gin.Context) { @@ -181,13 +162,13 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientsDto []dto.OidcClientDto if err := dto.MapStructList(clients, &clientsDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -200,19 +181,19 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { func (oc *OidcController) createClientHandler(c *gin.Context) { var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientDto dto.OidcClientDto if err := dto.MapStruct(client, &clientDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -222,7 +203,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { func (oc *OidcController) deleteClientHandler(c *gin.Context) { err := oc.oidcService.DeleteClient(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -232,19 +213,19 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) { func (oc *OidcController) updateClientHandler(c *gin.Context) { var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } client, err := oc.oidcService.UpdateClient(c.Param("id"), input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientDto dto.OidcClientDto if err := dto.MapStruct(client, &clientDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -254,7 +235,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { func (oc *OidcController) createClientSecretHandler(c *gin.Context) { secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -264,7 +245,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) { imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -275,17 +256,13 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) if err != nil { - if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -295,7 +272,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { err := oc.oidcService.DeleteClientLogo(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go index e526d84..a613e4c 100644 --- a/backend/internal/controller/test_controller.go +++ b/backend/internal/controller/test_controller.go @@ -3,7 +3,6 @@ package controller import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) @@ -19,22 +18,22 @@ type TestController struct { func (tc *TestController) resetAndSeedHandler(c *gin.Context) { if err := tc.TestService.ResetDatabase(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } if err := tc.TestService.ResetApplicationImages(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } if err := tc.TestService.SeedDatabase(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } if err := tc.TestService.ResetAppConfig(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index 2d9760c..d68daf4 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -1,13 +1,10 @@ package controller import ( - "errors" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/time/rate" "net/http" "strconv" @@ -43,13 +40,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var usersDto []dto.UserDto if err := dto.MapStructList(users, &usersDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -62,13 +59,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -78,13 +75,13 @@ func (uc *UserController) getUserHandler(c *gin.Context) { func (uc *UserController) getCurrentUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.GetString("userID")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -93,7 +90,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) { func (uc *UserController) deleteUserHandler(c *gin.Context) { if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -103,23 +100,19 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) { func (uc *UserController) createUserHandler(c *gin.Context) { var input dto.UserCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } user, err := uc.UserService.CreateUser(input) if err != nil { - if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -137,13 +130,13 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { var input dto.OneTimeAccessTokenCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -153,17 +146,13 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) if err != nil { - if errors.Is(err, common.ErrTokenInvalidOrExpired) { - utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -174,17 +163,13 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.SetupInitialAdmin() if err != nil { - if errors.Is(err, common.ErrSetupAlreadyCompleted) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -195,7 +180,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { var input dto.UserCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -208,17 +193,13 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser) if err != nil { - if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/user_group_controller.go b/backend/internal/controller/user_group_controller.go index 0012f80..e7fcbb3 100644 --- a/backend/internal/controller/user_group_controller.go +++ b/backend/internal/controller/user_group_controller.go @@ -1,16 +1,13 @@ package controller import ( - "errors" "net/http" "strconv" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" ) func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) { @@ -37,7 +34,7 @@ func (ugc *UserGroupController) list(c *gin.Context) { groups, pagination, err := ugc.UserGroupService.List(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -45,12 +42,12 @@ func (ugc *UserGroupController) list(c *gin.Context) { for i, group := range groups { var groupDto dto.UserGroupDtoWithUserCount if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } groupsDto[i] = groupDto @@ -65,13 +62,13 @@ func (ugc *UserGroupController) list(c *gin.Context) { func (ugc *UserGroupController) get(c *gin.Context) { group, err := ugc.UserGroupService.Get(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -81,23 +78,19 @@ func (ugc *UserGroupController) get(c *gin.Context) { func (ugc *UserGroupController) create(c *gin.Context) { var input dto.UserGroupCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.Create(input) if err != nil { - if errors.Is(err, common.ErrNameAlreadyInUse) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -107,23 +100,19 @@ func (ugc *UserGroupController) create(c *gin.Context) { func (ugc *UserGroupController) update(c *gin.Context) { var input dto.UserGroupCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.Update(c.Param("id"), input) if err != nil { - if errors.Is(err, common.ErrNameAlreadyInUse) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -132,7 +121,7 @@ func (ugc *UserGroupController) update(c *gin.Context) { func (ugc *UserGroupController) delete(c *gin.Context) { if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -142,19 +131,19 @@ func (ugc *UserGroupController) delete(c *gin.Context) { func (ugc *UserGroupController) updateUsers(c *gin.Context) { var input dto.UserGroupUpdateUsersDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go index e0c6bc1..bff2c29 100644 --- a/backend/internal/controller/webauthn_controller.go +++ b/backend/internal/controller/webauthn_controller.go @@ -1,17 +1,15 @@ package controller import ( - "errors" "github.com/go-webauthn/webauthn/protocol" + "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "net/http" "time" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/time/rate" ) @@ -38,7 +36,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { userID := c.GetString("userID") options, err := wc.webAuthnService.BeginRegistration(userID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -49,20 +47,20 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") + c.Error(&common.MissingSessionIdError{}) return } userID := c.GetString("userID") credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDto dto.WebauthnCredentialDto if err := dto.MapStruct(credential, &credentialDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -72,7 +70,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { options, err := wc.webAuthnService.BeginLogin() if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -83,13 +81,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") + c.Error(&common.MissingSessionIdError{}) return } credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -97,17 +95,13 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { user, token, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrInvalidCredentials) { - utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -119,13 +113,13 @@ func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { userID := c.GetString("userID") credentials, err := wc.webAuthnService.ListCredentials(userID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDtos []dto.WebauthnCredentialDto if err := dto.MapStructList(credentials, &credentialDtos); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -138,7 +132,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { err := wc.webAuthnService.DeleteCredential(userID, credentialID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -151,19 +145,19 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { var input dto.WebauthnCredentialUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDto dto.WebauthnCredentialDto if err := dto.MapStruct(credential, &credentialDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index 32104cd..4e0242a 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -4,7 +4,6 @@ import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) @@ -21,7 +20,7 @@ type WellKnownController struct { func (wkc *WellKnownController) jwksHandler(c *gin.Context) { jwk, err := wkc.jwtService.GetJWK() if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/dto/custom_claim_dto.go b/backend/internal/dto/custom_claim_dto.go new file mode 100644 index 0000000..f80f9a6 --- /dev/null +++ b/backend/internal/dto/custom_claim_dto.go @@ -0,0 +1,8 @@ +package dto + +type CustomClaimDto struct { + Key string `json:"key" binding:"required,max=20"` + Value string `json:"value" binding:"required,max=10000"` +} + +type CustomClaimCreateDto = CustomClaimDto diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 1f34001..19af71d 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -3,12 +3,13 @@ package dto import "time" type UserDto struct { - ID string `json:"id"` - Username string `json:"username"` - Email string `json:"email" ` - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - IsAdmin bool `json:"isAdmin"` + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email" ` + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + IsAdmin bool `json:"isAdmin"` + CustomClaims []CustomClaimDto `json:"customClaims"` } type UserCreateDto struct { diff --git a/backend/internal/dto/user_group_dto.go b/backend/internal/dto/user_group_dto.go index 424c61c..daef04b 100644 --- a/backend/internal/dto/user_group_dto.go +++ b/backend/internal/dto/user_group_dto.go @@ -3,19 +3,21 @@ package dto import "time" type UserGroupDtoWithUsers struct { - ID string `json:"id"` - FriendlyName string `json:"friendlyName"` - Name string `json:"name"` - Users []UserDto `json:"users"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + FriendlyName string `json:"friendlyName"` + Name string `json:"name"` + CustomClaims []CustomClaimDto `json:"customClaims"` + Users []UserDto `json:"users"` + CreatedAt time.Time `json:"createdAt"` } type UserGroupDtoWithUserCount struct { - ID string `json:"id"` - FriendlyName string `json:"friendlyName"` - Name string `json:"name"` - UserCount int64 `json:"userCount"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + FriendlyName string `json:"friendlyName"` + Name string `json:"name"` + CustomClaims []CustomClaimDto `json:"customClaims"` + UserCount int64 `json:"userCount"` + CreatedAt time.Time `json:"createdAt"` } type UserGroupCreateDto struct { diff --git a/backend/internal/utils/controller_error_util.go b/backend/internal/middleware/error_handler.go similarity index 51% rename from backend/internal/utils/controller_error_util.go rename to backend/internal/middleware/error_handler.go index 7d62ce7..e02361a 100644 --- a/backend/internal/utils/controller_error_util.go +++ b/backend/internal/middleware/error_handler.go @@ -1,37 +1,69 @@ -package utils +package middleware import ( "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" + "github.com/stonith404/pocket-id/backend/internal/common" "gorm.io/gorm" "log" "net/http" "strings" ) -import ( - "fmt" -) +type ErrorHandlerMiddleware struct{} -func ControllerError(c *gin.Context, err error) { - // Check for record not found errors - if errors.Is(err, gorm.ErrRecordNotFound) { - CustomControllerError(c, http.StatusNotFound, "Record not found") - return - } +func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware { + return &ErrorHandlerMiddleware{} +} + +func (m *ErrorHandlerMiddleware) Add() gin.HandlerFunc { + return func(c *gin.Context) { + c.Next() + for _, err := range c.Errors { + + // Check for record not found errors + if errors.Is(err, gorm.ErrRecordNotFound) { + errorResponse(c, http.StatusNotFound, "Record not found") + return + } + + // Check for validation errors + var validationErrors validator.ValidationErrors + if errors.As(err, &validationErrors) { + message := handleValidationError(validationErrors) + errorResponse(c, http.StatusBadRequest, message) + return + } + + // Check for slice validation errors + var sliceValidationErrors binding.SliceValidationError + if errors.As(err, &sliceValidationErrors) { + if errors.As(sliceValidationErrors[0], &validationErrors) { + message := handleValidationError(validationErrors) + errorResponse(c, http.StatusBadRequest, message) + return + } + } - // Check for validation errors - var validationErrors validator.ValidationErrors - if errors.As(err, &validationErrors) { - message := handleValidationError(validationErrors) - CustomControllerError(c, http.StatusBadRequest, message) - return + var appErr common.AppError + if errors.As(err, &appErr) { + errorResponse(c, appErr.HttpStatusCode(), appErr.Error()) + return + } + log.Println(err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) + } } +} - log.Println(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) +func errorResponse(c *gin.Context, statusCode int, message string) { + // Capitalize the first letter of the message + message = strings.ToUpper(message[:1]) + message[1:] + c.JSON(statusCode, gin.H{"error": message}) } func handleValidationError(validationErrors validator.ValidationErrors) string { @@ -67,9 +99,3 @@ func handleValidationError(validationErrors validator.ValidationErrors) string { return combinedErrors } - -func CustomControllerError(c *gin.Context, statusCode int, message string) { - // Capitalize the first letter of the message - message = strings.ToUpper(message[:1]) + message[1:] - c.JSON(statusCode, gin.H{"error": message}) -} diff --git a/backend/internal/middleware/file_size_limit.go b/backend/internal/middleware/file_size_limit.go index 7503acb..32c7363 100644 --- a/backend/internal/middleware/file_size_limit.go +++ b/backend/internal/middleware/file_size_limit.go @@ -3,7 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/common" "net/http" ) @@ -17,8 +17,8 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc { return func(c *gin.Context) { c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) if err := c.Request.ParseMultipartForm(maxSize); err != nil { - utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize))) - c.Abort() + err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)} + c.Error(err) return } c.Next() diff --git a/backend/internal/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go index 9416d5a..36be4cc 100644 --- a/backend/internal/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -2,9 +2,8 @@ package middleware import ( "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" - "net/http" "strings" ) @@ -29,8 +28,7 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { c.Next() return } else { - utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") - c.Abort() + c.Error(&common.NotSignedInError{}) return } } @@ -40,14 +38,14 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { c.Next() return } else if err != nil { - utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") + c.Error(&common.NotSignedInError{}) c.Abort() return } // Check if the user is an admin if adminOnly && !claims.IsAdmin { - utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource") + c.Error(&common.MissingPermissionError{}) c.Abort() return } diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go index 494ee06..f9686a6 100644 --- a/backend/internal/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -2,8 +2,6 @@ package middleware import ( "github.com/stonith404/pocket-id/backend/internal/common" - "github.com/stonith404/pocket-id/backend/internal/utils" - "net/http" "sync" "time" @@ -33,8 +31,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { limiter := getLimiter(ip, limit, burst) if !limiter.Allow() { - utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.") - c.Abort() + c.Error(&common.TooManyRequestsError{}) return } diff --git a/backend/internal/model/custom_claim.go b/backend/internal/model/custom_claim.go new file mode 100644 index 0000000..c47f934 --- /dev/null +++ b/backend/internal/model/custom_claim.go @@ -0,0 +1,11 @@ +package model + +type CustomClaim struct { + Base + + Key string + Value string + + UserID *string + UserGroupID *string +} diff --git a/backend/internal/model/user.go b/backend/internal/model/user.go index 8cb6f0b..137838c 100644 --- a/backend/internal/model/user.go +++ b/backend/internal/model/user.go @@ -15,8 +15,9 @@ type User struct { LastName string IsAdmin bool - UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` - Credentials []WebauthnCredential + CustomClaims []CustomClaim + UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` + Credentials []WebauthnCredential } func (u User) WebAuthnID() []byte { return []byte(u.ID) } diff --git a/backend/internal/model/user_group.go b/backend/internal/model/user_group.go index 8559016..15648d4 100644 --- a/backend/internal/model/user_group.go +++ b/backend/internal/model/user_group.go @@ -5,4 +5,5 @@ type UserGroup struct { FriendlyName string Name string `gorm:"unique"` Users []User `gorm:"many2many:user_groups_users;"` + CustomClaims []CustomClaim } diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index 61b17fe..08f07fd 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -165,7 +165,7 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image fileType := utils.GetFileExtension(uploadedFile.Filename) mimeType := utils.GetImageMimeType(fileType) if mimeType == "" { - return common.ErrFileTypeNotSupported + return &common.FileTypeNotSupportedError{} } // Delete the old image if it has a different file type diff --git a/backend/internal/service/custom_claim_service.go b/backend/internal/service/custom_claim_service.go new file mode 100644 index 0000000..5521999 --- /dev/null +++ b/backend/internal/service/custom_claim_service.go @@ -0,0 +1,197 @@ +package service + +import ( + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" + "github.com/stonith404/pocket-id/backend/internal/model" + "gorm.io/gorm" +) + +// Reserved claims +var reservedClaims = map[string]struct{}{ + "given_name": {}, + "family_name": {}, + "name": {}, + "email": {}, + "preferred_username": {}, + "groups": {}, + "sub": {}, + "iss": {}, + "aud": {}, + "exp": {}, + "iat": {}, + "auth_time": {}, + "nonce": {}, + "acr": {}, + "amr": {}, + "azp": {}, + "nbf": {}, + "jti": {}, +} + +type CustomClaimService struct { + db *gorm.DB +} + +func NewCustomClaimService(db *gorm.DB) *CustomClaimService { + return &CustomClaimService{db: db} +} + +// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username +func isReservedClaim(key string) bool { + _, ok := reservedClaims[key] + return ok +} + +// idType is the type of the id used to identify the user or user group +type idType string + +const ( + UserID idType = "user_id" + UserGroupID idType = "user_group_id" +) + +// UpdateCustomClaimsForUser updates the custom claims for a user +func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(UserID, userID, claims) +} + +// UpdateCustomClaimsForUserGroup updates the custom claims for a user group +func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(UserGroupID, userGroupID, claims) +} + +// updateCustomClaims updates the custom claims for a user or user group +func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + // Check for duplicate keys in the claims slice + seenKeys := make(map[string]bool) + for _, claim := range claims { + if seenKeys[claim.Key] { + return nil, &common.DuplicateClaimError{Key: claim.Key} + } + seenKeys[claim.Key] = true + } + + var existingClaims []model.CustomClaim + err := s.db.Where(string(idType), value).Find(&existingClaims).Error + if err != nil { + return nil, err + } + + // Delete claims that are not in the new list + for _, existingClaim := range existingClaims { + found := false + for _, claim := range claims { + if claim.Key == existingClaim.Key { + found = true + break + } + } + if !found { + err = s.db.Delete(&existingClaim).Error + if err != nil { + return nil, err + } + } + } + + // Add or update claims + for _, claim := range claims { + if isReservedClaim(claim.Key) { + return nil, &common.ReservedClaimError{Key: claim.Key} + } + customClaim := model.CustomClaim{ + Key: claim.Key, + Value: claim.Value, + } + + if idType == UserID { + customClaim.UserID = &value + } else if idType == UserGroupID { + customClaim.UserGroupID = &value + } + + // Update the claim if it already exists or create a new one + err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error + if err != nil { + return nil, err + } + } + + // Get the updated claims + var updatedClaims []model.CustomClaim + err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error + if err != nil { + return nil, err + } + + return updatedClaims, nil +} + +func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) { + var customClaims []model.CustomClaim + err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error + return customClaims, err +} + +func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) { + var customClaims []model.CustomClaim + err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error + return customClaims, err +} + +// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of, +// prioritizing the user's claims over user group claims with the same key. +func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) { + // Get the custom claims of the user + customClaims, err := s.GetCustomClaimsForUser(userID) + if err != nil { + return nil, err + } + + // Store user's claims in a map to prioritize and prevent duplicates + claimsMap := make(map[string]model.CustomClaim) + for _, claim := range customClaims { + claimsMap[claim.Key] = claim + } + + // Get all user groups of the user + var userGroupsOfUser []model.UserGroup + err = s.db.Preload("CustomClaims"). + Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id"). + Where("user_groups_users.user_id = ?", userID). + Find(&userGroupsOfUser).Error + if err != nil { + return nil, err + } + + // Add only non-duplicate custom claims from user groups + for _, userGroup := range userGroupsOfUser { + for _, groupClaim := range userGroup.CustomClaims { + // Only add claim if it does not exist in the user's claims + if _, exists := claimsMap[groupClaim.Key]; !exists { + claimsMap[groupClaim.Key] = groupClaim + } + } + } + + // Convert the claimsMap back to a slice + finalClaims := make([]model.CustomClaim, 0, len(claimsMap)) + for _, claim := range claimsMap { + finalClaims = append(finalClaims, claim) + } + + return finalClaims, nil +} + +// GetSuggestions returns a list of custom claim keys that have been used before +func (s *CustomClaimService) GetSuggestions() ([]string, error) { + var customClaimsKeys []string + + err := s.db.Model(&model.CustomClaim{}). + Group("key"). + Order("COUNT(*) DESC"). + Pluck("key", &customClaimsKeys).Error + + return customClaimsKeys, err +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 9d24bcd..9036406 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -18,18 +18,20 @@ import ( ) type OidcService struct { - db *gorm.DB - jwtService *JwtService - appConfigService *AppConfigService - auditLogService *AuditLogService + db *gorm.DB + jwtService *JwtService + appConfigService *AppConfigService + auditLogService *AuditLogService + customClaimService *CustomClaimService } -func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService) *OidcService { +func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService { return &OidcService{ - db: db, - jwtService: jwtService, - appConfigService: appConfigService, - auditLogService: auditLogService, + db: db, + jwtService: jwtService, + appConfigService: appConfigService, + auditLogService: auditLogService, + customClaimService: customClaimService, } } @@ -38,7 +40,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID) if userAuthorizedOIDCClient.Scope != input.Scope { - return "", "", common.ErrOidcMissingAuthorization + return "", "", &common.OidcMissingAuthorizationError{} } callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL) @@ -93,11 +95,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) { if grantType != "authorization_code" { - return "", "", common.ErrOidcGrantTypeNotSupported + return "", "", &common.OidcGrantTypeNotSupportedError{} } if clientID == "" || clientSecret == "" { - return "", "", common.ErrOidcMissingClientCredentials + return "", "", &common.OidcMissingClientCredentialsError{} } var client model.OidcClient @@ -107,17 +109,17 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) if err != nil { - return "", "", common.ErrOidcClientSecretInvalid + return "", "", &common.OidcClientSecretInvalidError{} } var authorizationCodeMetaData model.OidcAuthorizationCode err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error if err != nil { - return "", "", common.ErrOidcInvalidAuthorizationCode + return "", "", &common.OidcInvalidAuthorizationCodeError{} } if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { - return "", "", common.ErrOidcInvalidAuthorizationCode + return "", "", &common.OidcInvalidAuthorizationCodeError{} } userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) @@ -249,7 +251,7 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error { fileType := utils.GetFileExtension(file.Filename) if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { - return common.ErrFileTypeNotSupported + return &common.FileTypeNotSupportedError{} } imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType) @@ -334,9 +336,20 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma } if strings.Contains(scope, "profile") { + // Add profile claims for k, v := range profileClaims { claims[k] = v } + + // Add custom claims + customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID) + if err != nil { + return nil, err + } + + for _, customClaim := range customClaims { + claims[customClaim.Key] = customClaim.Value + } } if strings.Contains(scope, "email") { claims["email"] = user.Email @@ -375,5 +388,5 @@ func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackU return inputCallbackURL, nil } - return "", common.ErrOidcInvalidCallbackURL + return "", &common.OidcInvalidCallbackURLError{} } diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index 6dbd9ad..065f9e7 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -18,7 +18,7 @@ func NewUserGroupService(db *gorm.DB) *UserGroupService { } func (s *UserGroupService) List(name string, page int, pageSize int) (groups []model.UserGroup, response utils.PaginationResponse, err error) { - query := s.db.Model(&model.UserGroup{}) + query := s.db.Preload("CustomClaims").Model(&model.UserGroup{}) if name != "" { query = query.Where("name LIKE ?", "%"+name+"%") @@ -29,7 +29,7 @@ func (s *UserGroupService) List(name string, page int, pageSize int) (groups []m } func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) { - err = s.db.Where("id = ?", id).Preload("Users").First(&group).Error + err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error return group, err } @@ -50,7 +50,7 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use if err := s.db.Preload("Users").Create(&group).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.UserGroup{}, common.ErrNameAlreadyInUse + return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } return model.UserGroup{}, err } @@ -68,7 +68,7 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto) (grou if err := s.db.Preload("Users").Save(&group).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.UserGroup{}, common.ErrNameAlreadyInUse + return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } return model.UserGroup{}, err } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0c94a3b..e00f5fe 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -35,7 +35,7 @@ func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]mo func (s *UserService) GetUser(userID string) (model.User, error) { var user model.User - err := s.db.Where("id = ?", userID).First(&user).Error + err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error return user, err } @@ -111,7 +111,7 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, stri var oneTimeAccessToken model.OneTimeAccessToken if err := s.db.Where("token = ? AND expires_at > ?", token, time.Now().Unix()).Preload("User").First(&oneTimeAccessToken).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return model.User{}, "", common.ErrTokenInvalidOrExpired + return model.User{}, "", &common.TokenInvalidOrExpiredError{} } return model.User{}, "", err } @@ -133,7 +133,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { return model.User{}, "", err } if userCount > 1 { - return model.User{}, "", common.ErrSetupAlreadyCompleted + return model.User{}, "", &common.SetupAlreadyCompletedError{} } user := model.User{ @@ -149,7 +149,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { } if len(user.Credentials) > 0 { - return model.User{}, "", common.ErrSetupAlreadyCompleted + return model.User{}, "", &common.SetupAlreadyCompletedError{} } token, err := s.jwtService.GenerateAccessToken(user) @@ -163,11 +163,11 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { func (s *UserService) checkDuplicatedFields(user model.User) error { var existingUser model.User if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { - return common.ErrEmailTaken + return &common.AlreadyInUseError{Property: "email"} } if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { - return common.ErrUsernameTaken + return &common.AlreadyInUseError{Property: "username"} } return nil diff --git a/backend/migrations/20241028064959_custom_claims.down.sql b/backend/migrations/20241028064959_custom_claims.down.sql new file mode 100644 index 0000000..113dcb5 --- /dev/null +++ b/backend/migrations/20241028064959_custom_claims.down.sql @@ -0,0 +1 @@ +DROP TABLE custom_claims; \ No newline at end of file diff --git a/backend/migrations/20241028064959_custom_claims.up.sql b/backend/migrations/20241028064959_custom_claims.up.sql new file mode 100644 index 0000000..9d6acc6 --- /dev/null +++ b/backend/migrations/20241028064959_custom_claims.up.sql @@ -0,0 +1,15 @@ +CREATE TABLE custom_claims +( + id TEXT NOT NULL PRIMARY KEY, + created_at DATETIME, + key TEXT NOT NULL, + value TEXT NOT NULL, + + user_id TEXT, + user_group_id TEXT, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE, + FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE, + + CONSTRAINT custom_claims_unique UNIQUE (key, user_id, user_group_id), + CHECK (user_id IS NOT NULL OR user_group_id IS NOT NULL) +); \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6feb302..df274b5 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "pocket-id-frontend", - "version": "0.9.0", + "version": "0.10.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pocket-id-frontend", - "version": "0.9.0", + "version": "0.10.0", "dependencies": { "@simplewebauthn/browser": "^10.0.0", "axios": "^1.7.7", diff --git a/frontend/src/lib/components/auto-complete-input.svelte b/frontend/src/lib/components/auto-complete-input.svelte new file mode 100644 index 0000000..da5e3c5 --- /dev/null +++ b/frontend/src/lib/components/auto-complete-input.svelte @@ -0,0 +1,105 @@ + + +
+ (isInputFocused = true)} + onblur={() => (isInputFocused = false)} + /> + {}} + closeOnOutsideClick={false} + closeOnEscape={false} + > + +
diff --git a/frontend/src/lib/components/custom-claims-input.svelte b/frontend/src/lib/components/custom-claims-input.svelte new file mode 100644 index 0000000..253d56d --- /dev/null +++ b/frontend/src/lib/components/custom-claims-input.svelte @@ -0,0 +1,75 @@ + + +
+ +
+ {#each customClaims as _, i} +
+ + + +
+ {/each} +
+
+ {#if error} +

{error}

+ {/if} + {#if customClaims.length < limit} + + {/if} +
diff --git a/frontend/src/lib/components/form-input.svelte b/frontend/src/lib/components/form-input.svelte index a1021ad..7aeb578 100644 --- a/frontend/src/lib/components/form-input.svelte +++ b/frontend/src/lib/components/form-input.svelte @@ -16,7 +16,7 @@ ...restProps }: HTMLAttributes & { input?: FormInput; - label: string; + label?: string; description?: string; disabled?: boolean; type?: 'text' | 'password' | 'email' | 'number' | 'checkbox'; @@ -24,15 +24,17 @@ children?: Snippet; } = $props(); - const id = label.toLowerCase().replace(/ /g, '-'); + const id = label?.toLowerCase().replace(/ /g, '-');
- + {#if label} + + {/if} {#if description}

{description}

{/if} -
+
{#if children} {@render children()} {:else if input} diff --git a/frontend/src/lib/components/ui/popover/index.ts b/frontend/src/lib/components/ui/popover/index.ts new file mode 100644 index 0000000..63aecf9 --- /dev/null +++ b/frontend/src/lib/components/ui/popover/index.ts @@ -0,0 +1,17 @@ +import { Popover as PopoverPrimitive } from "bits-ui"; +import Content from "./popover-content.svelte"; +const Root = PopoverPrimitive.Root; +const Trigger = PopoverPrimitive.Trigger; +const Close = PopoverPrimitive.Close; + +export { + Root, + Content, + Trigger, + Close, + // + Root as Popover, + Content as PopoverContent, + Trigger as PopoverTrigger, + Close as PopoverClose, +}; diff --git a/frontend/src/lib/components/ui/popover/popover-content.svelte b/frontend/src/lib/components/ui/popover/popover-content.svelte new file mode 100644 index 0000000..5bad4b7 --- /dev/null +++ b/frontend/src/lib/components/ui/popover/popover-content.svelte @@ -0,0 +1,22 @@ + + + + + diff --git a/frontend/src/lib/services/custom-claim-service.ts b/frontend/src/lib/services/custom-claim-service.ts new file mode 100644 index 0000000..87f58f4 --- /dev/null +++ b/frontend/src/lib/services/custom-claim-service.ts @@ -0,0 +1,19 @@ +import type { CustomClaim } from '$lib/types/custom-claim.type'; +import APIService from './api-service'; + +export default class CustomClaimService extends APIService { + async getSuggestions() { + const res = await this.api.get('/custom-claims/suggestions'); + return res.data as string[]; + } + + async updateUserCustomClaims(userId: string, claims: CustomClaim[]) { + const res = await this.api.put(`/custom-claims/user/${userId}`, claims); + return res.data as CustomClaim[]; + } + + async updateUserGroupCustomClaims(userGroupId: string, claims: CustomClaim[]) { + const res = await this.api.put(`/custom-claims/user-group/${userGroupId}`, claims); + return res.data as CustomClaim[]; + } +} diff --git a/frontend/src/lib/types/custom-claim.type.ts b/frontend/src/lib/types/custom-claim.type.ts new file mode 100644 index 0000000..24a3659 --- /dev/null +++ b/frontend/src/lib/types/custom-claim.type.ts @@ -0,0 +1,4 @@ +export type CustomClaim = { + key: string; + value: string; +}; diff --git a/frontend/src/lib/types/user-group.type.ts b/frontend/src/lib/types/user-group.type.ts index b86f4af..da02635 100644 --- a/frontend/src/lib/types/user-group.type.ts +++ b/frontend/src/lib/types/user-group.type.ts @@ -1,3 +1,4 @@ +import type { CustomClaim } from './custom-claim.type'; import type { User } from './user.type'; export type UserGroup = { @@ -5,6 +6,7 @@ export type UserGroup = { friendlyName: string; name: string; createdAt: string; + customClaims: CustomClaim[]; }; export type UserGroupWithUsers = UserGroup & { diff --git a/frontend/src/lib/types/user.type.ts b/frontend/src/lib/types/user.type.ts index f22368e..1157ecf 100644 --- a/frontend/src/lib/types/user.type.ts +++ b/frontend/src/lib/types/user.type.ts @@ -1,3 +1,5 @@ +import type { CustomClaim } from './custom-claim.type'; + export type User = { id: string; username: string; @@ -5,6 +7,7 @@ export type User = { firstName: string; lastName: string; isAdmin: boolean; + customClaims: CustomClaim[]; }; -export type UserCreate = Omit; +export type UserCreate = Omit; diff --git a/frontend/src/routes/settings/admin/user-groups/[id]/+page.svelte b/frontend/src/routes/settings/admin/user-groups/[id]/+page.svelte index 0d4fb4e..a2abcb6 100644 --- a/frontend/src/routes/settings/admin/user-groups/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/user-groups/[id]/+page.svelte @@ -1,6 +1,8 @@ @@ -53,7 +65,7 @@
- Meta data + General @@ -76,3 +88,20 @@
+ + + + Custom Claims + + Custom claims are key-value pairs that can be used to store additional information about a + user. These claims will be included in the ID token if the scope "profile" is requested. + Custom claims defined on the user will be prioritized if there are conflicts. + + + + +
+ +
+
+
diff --git a/frontend/src/routes/settings/admin/users/[id]/+page.svelte b/frontend/src/routes/settings/admin/users/[id]/+page.svelte index fc99a66..657ae91 100644 --- a/frontend/src/routes/settings/admin/users/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/users/[id]/+page.svelte @@ -1,16 +1,20 @@ @@ -37,10 +50,25 @@
- {user.firstName} {user.lastName} + General - + + + + Custom Claims + + Custom claims are key-value pairs that can be used to store additional information about a + user. These claims will be included in the ID token if the scope "profile" is requested. + + + + +
+ +
+
+
diff --git a/frontend/tests/account-settings.spec.ts b/frontend/tests/account-settings.spec.ts index 96dcb1c..eea7551 100644 --- a/frontend/tests/account-settings.spec.ts +++ b/frontend/tests/account-settings.spec.ts @@ -24,7 +24,7 @@ test('Update account details fails with already taken email', async ({ page }) = await page.getByRole('button', { name: 'Save' }).click(); - await expect(page.getByRole('status')).toHaveText('Email is already taken'); + await expect(page.getByRole('status')).toHaveText('Email is already in use'); }); test('Update account details fails with already taken username', async ({ page }) => { @@ -34,7 +34,7 @@ test('Update account details fails with already taken username', async ({ page } await page.getByRole('button', { name: 'Save' }).click(); - await expect(page.getByRole('status')).toHaveText('Username is already taken'); + await expect(page.getByRole('status')).toHaveText('Username is already in use'); }); test('Add passkey to an account', async ({ page }) => { diff --git a/frontend/tests/user-group.spec.ts b/frontend/tests/user-group.spec.ts index bfbbd4c..b07a1db 100644 --- a/frontend/tests/user-group.spec.ts +++ b/frontend/tests/user-group.spec.ts @@ -73,3 +73,39 @@ test('Delete user group', async ({ page }) => { await expect(page.getByRole('status')).toHaveText('User group deleted successfully'); await expect(page.getByRole('row', { name: group.name })).not.toBeVisible(); }); + +test('Update user group custom claims', async ({ page }) => { + await page.goto(`/settings/admin/user-groups/${userGroups.designers.id}`); + + // Add two custom claims + await page.getByRole('button', { name: 'Add custom claim' }).click(); + + await page.getByPlaceholder('Key').fill('custom_claim_1'); + await page.getByPlaceholder('Value').fill('custom_claim_1_value'); + + await page.getByRole('button', { name: 'Add another' }).click(); + await page.getByPlaceholder('Key').nth(1).fill('custom_claim_2'); + await page.getByPlaceholder('Value').nth(1).fill('custom_claim_2_value'); + + await page.getByRole('button', { name: 'Save' }).nth(2).click(); + + await expect(page.getByRole('status')).toHaveText('Custom claims updated successfully'); + + await page.reload(); + + // Check if custom claims are saved + await expect(page.getByPlaceholder('Key').first()).toHaveValue('custom_claim_1'); + await expect(page.getByPlaceholder('Value').first()).toHaveValue('custom_claim_1_value'); + await expect(page.getByPlaceholder('Key').nth(1)).toHaveValue('custom_claim_2'); + await expect(page.getByPlaceholder('Value').nth(1)).toHaveValue('custom_claim_2_value'); + + // Remove one custom claim + await page.getByLabel('Remove custom claim').first().click(); + await page.getByRole('button', { name: 'Save' }).nth(2).click(); + + await page.reload(); + + // Check if custom claim is removed + await expect(page.getByPlaceholder('Key').first()).toHaveValue('custom_claim_2'); + await expect(page.getByPlaceholder('Value').first()).toHaveValue('custom_claim_2_value'); +}); diff --git a/frontend/tests/user-settings.spec.ts b/frontend/tests/user-settings.spec.ts index 0047373..7c67190 100644 --- a/frontend/tests/user-settings.spec.ts +++ b/frontend/tests/user-settings.spec.ts @@ -32,7 +32,7 @@ test('Create user fails with already taken email', async ({ page }) => { await page.getByLabel('Username').fill(user.username); await page.getByRole('button', { name: 'Save' }).click(); - await expect(page.getByRole('status')).toHaveText('Email is already taken'); + await expect(page.getByRole('status')).toHaveText('Email is already in use'); }); test('Create user fails with already taken username', async ({ page }) => { @@ -47,7 +47,7 @@ test('Create user fails with already taken username', async ({ page }) => { await page.getByLabel('Username').fill(users.tim.username); await page.getByRole('button', { name: 'Save' }).click(); - await expect(page.getByRole('status')).toHaveText('Username is already taken'); + await expect(page.getByRole('status')).toHaveText('Username is already in use'); }); test('Create one time access token', async ({ page }) => { @@ -95,7 +95,7 @@ test('Update user', async ({ page }) => { await page.getByLabel('Last name').fill('Apple'); await page.getByLabel('Email').fill('crack.apple@test.com'); await page.getByLabel('Username').fill('crack'); - await page.getByRole('button', { name: 'Save' }).click(); + await page.getByRole('button', { name: 'Save' }).first().click(); await expect(page.getByRole('status')).toHaveText('User updated successfully'); }); @@ -112,9 +112,9 @@ test('Update user fails with already taken email', async ({ page }) => { await page.getByRole('menuitem', { name: 'Edit' }).click(); await page.getByLabel('Email').fill(users.tim.email); - await page.getByRole('button', { name: 'Save' }).click(); + await page.getByRole('button', { name: 'Save' }).first().click(); - await expect(page.getByRole('status')).toHaveText('Email is already taken'); + await expect(page.getByRole('status')).toHaveText('Email is already in use'); }); test('Update user fails with already taken username', async ({ page }) => { @@ -129,7 +129,43 @@ test('Update user fails with already taken username', async ({ page }) => { await page.getByRole('menuitem', { name: 'Edit' }).click(); await page.getByLabel('Username').fill(users.tim.username); - await page.getByRole('button', { name: 'Save' }).click(); + await page.getByRole('button', { name: 'Save' }).first().click(); + + await expect(page.getByRole('status')).toHaveText('Username is already in use'); +}); + +test('Update user custom claims', async ({ page }) => { + await page.goto(`/settings/admin/users/${users.craig.id}`); + + // Add two custom claims + await page.getByRole('button', { name: 'Add custom claim' }).click(); + + await page.getByPlaceholder('Key').fill('custom_claim_1'); + await page.getByPlaceholder('Value').fill('custom_claim_1_value'); + + await page.getByRole('button', { name: 'Add another' }).click(); + await page.getByPlaceholder('Key').nth(1).fill('custom_claim_2'); + await page.getByPlaceholder('Value').nth(1).fill('custom_claim_2_value'); + + await page.getByRole('button', { name: 'Save' }).nth(1).click(); + + await expect(page.getByRole('status')).toHaveText('Custom claims updated successfully'); + + await page.reload(); + + // Check if custom claims are saved + await expect(page.getByPlaceholder('Key').first()).toHaveValue('custom_claim_1'); + await expect(page.getByPlaceholder('Value').first()).toHaveValue('custom_claim_1_value'); + await expect(page.getByPlaceholder('Key').nth(1)).toHaveValue('custom_claim_2'); + await expect(page.getByPlaceholder('Value').nth(1)).toHaveValue('custom_claim_2_value'); + + // Remove one custom claim + await page.getByLabel('Remove custom claim').first().click(); + await page.getByRole('button', { name: 'Save' }).nth(1).click(); + + await page.reload(); - await expect(page.getByRole('status')).toHaveText('Username is already taken'); + // Check if custom claim is removed + await expect(page.getByPlaceholder('Key').first()).toHaveValue('custom_claim_2'); + await expect(page.getByPlaceholder('Value').first()).toHaveValue('custom_claim_2_value'); });