From 465b512b6b2e9dbb2ae06d85d1dae83408a0a734 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Fri, 25 Oct 2024 17:02:16 +0200 Subject: [PATCH 1/4] add first version of custom claims --- .../internal/bootstrap/router_bootstrap.go | 5 +- backend/internal/common/errors.go | 149 ++++++++++++++++-- .../controller/app_config_controller.go | 27 ++-- .../controller/audit_log_controller.go | 5 +- .../controller/custom_claim_controller.go | 53 +++++++ .../internal/controller/oidc_controller.go | 73 +++------ .../internal/controller/test_controller.go | 7 +- .../internal/controller/user_controller.go | 57 +++---- .../controller/user_group_controller.go | 41 ++--- .../controller/webauthn_controller.go | 38 ++--- .../controller/well_known_controller.go | 3 +- backend/internal/dto/custom_claim_dto.go | 8 + backend/internal/dto/user_dto.go | 15 +- .../error_handler.go} | 74 ++++++--- .../internal/middleware/file_size_limit.go | 6 +- backend/internal/middleware/jwt_auth.go | 10 +- backend/internal/middleware/rate_limit.go | 5 +- backend/internal/model/custom_claim.go | 10 ++ backend/internal/model/user.go | 5 +- .../internal/service/app_config_service.go | 2 +- .../internal/service/custom_claim_service.go | 111 +++++++++++++ backend/internal/service/oidc_service.go | 47 ++++-- .../internal/service/user_group_service.go | 4 +- backend/internal/service/user_service.go | 12 +- .../20241024064959_custom_claims.down.sql | 1 + .../20241024064959_custom_claims.up.sql | 12 ++ frontend/package-lock.json | 4 +- frontend/src/lib/components/form-input.svelte | 10 +- .../src/lib/components/ui/popover/index.ts | 17 ++ .../ui/popover/popover-content.svelte | 22 +++ .../src/lib/services/custom-claim-service.ts | 15 ++ frontend/src/lib/types/custom-claim.type.ts | 4 + frontend/src/lib/types/user.type.ts | 5 +- .../settings/admin/users/[id]/+page.svelte | 32 +++- .../users/[id]/auto-complete-input.svelte | 97 ++++++++++++ .../users/[id]/custom-claim-input.svelte | 78 +++++++++ 36 files changed, 807 insertions(+), 257 deletions(-) create mode 100644 backend/internal/controller/custom_claim_controller.go create mode 100644 backend/internal/dto/custom_claim_dto.go rename backend/internal/{utils/controller_error_util.go => middleware/error_handler.go} (51%) create mode 100644 backend/internal/model/custom_claim.go create mode 100644 backend/internal/service/custom_claim_service.go create mode 100644 backend/migrations/20241024064959_custom_claims.down.sql create mode 100644 backend/migrations/20241024064959_custom_claims.up.sql create mode 100644 frontend/src/lib/components/ui/popover/index.ts create mode 100644 frontend/src/lib/components/ui/popover/popover-content.svelte create mode 100644 frontend/src/lib/services/custom-claim-service.ts create mode 100644 frontend/src/lib/types/custom-claim.type.ts create mode 100644 frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte create mode 100644 frontend/src/routes/settings/admin/users/[id]/custom-claim-input.svelte diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 57d1fda..1a07a39 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -39,11 +39,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { jwtService := service.NewJwtService(appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) userService := service.NewUserService(db, jwtService) - oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService) + customClaimService := service.NewCustomClaimService(db) + oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) testService := service.NewTestService(db, appConfigService) userGroupService := service.NewUserGroupService(db) r.Use(middleware.NewCorsMiddleware().Add()) + r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60)) r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false)) @@ -59,6 +61,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService) controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware) controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService) + controller.NewCustomClaimController(apiGroup, jwtAuthMiddleware, customClaimService) // Add test controller in non-production environments if common.EnvConfig.AppEnv != "production" { diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index d05ac13..6cb79a5 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -1,19 +1,136 @@ package common -import "errors" - -var ( - ErrUsernameTaken = errors.New("username is already taken") - ErrEmailTaken = errors.New("email is already taken") - ErrSetupAlreadyCompleted = errors.New("setup already completed") - ErrTokenInvalidOrExpired = errors.New("token is invalid or expired") - ErrOidcMissingAuthorization = errors.New("missing authorization") - ErrOidcGrantTypeNotSupported = errors.New("grant type not supported") - ErrOidcMissingClientCredentials = errors.New("client id or secret not provided") - ErrOidcClientSecretInvalid = errors.New("invalid client secret") - ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code") - ErrOidcInvalidCallbackURL = errors.New("invalid callback URL") - ErrFileTypeNotSupported = errors.New("file type not supported") - ErrInvalidCredentials = errors.New("no user found with provided credentials") - ErrNameAlreadyInUse = errors.New("name is already in use") +import ( + "fmt" + "net/http" ) + +type AppError interface { + error + HttpStatusCode() int +} + +// Custom error types for various conditions + +type AlreadyInUseError struct { + Property string +} + +func (e *AlreadyInUseError) Error() string { + return fmt.Sprintf("%s is already in use", e.Property) +} +func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } + +type SetupAlreadyCompletedError struct{} + +func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" } +func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 } + +type TokenInvalidOrExpiredError struct{} + +func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" } +func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 } + +type OidcMissingAuthorizationError struct{} + +func (e *OidcMissingAuthorizationError) Error() string { return "missing authorization" } +func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.StatusForbidden } + +type OidcGrantTypeNotSupportedError struct{} + +func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" } +func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 } + +type OidcMissingClientCredentialsError struct{} + +func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" } +func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 } + +type OidcClientSecretInvalidError struct{} + +func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" } +func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 } + +type OidcInvalidAuthorizationCodeError struct{} + +func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" } +func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 } + +type OidcInvalidCallbackURLError struct{} + +func (e *OidcInvalidCallbackURLError) Error() string { return "invalid callback URL" } +func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 } + +type FileTypeNotSupportedError struct{} + +func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" } +func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 } + +type InvalidCredentialsError struct{} + +func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" } +func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 } + +type FileTooLargeError struct { + MaxSize string +} + +func (e *FileTooLargeError) Error() string { + return fmt.Sprintf("The file can't be larger than %s", e.MaxSize) +} +func (e *FileTooLargeError) HttpStatusCode() int { return http.StatusRequestEntityTooLarge } + +type NotSignedInError struct{} + +func (e *NotSignedInError) Error() string { return "You are not signed in" } +func (e *NotSignedInError) HttpStatusCode() int { return http.StatusUnauthorized } + +type MissingPermissionError struct{} + +func (e *MissingPermissionError) Error() string { + return "You don't have permission to perform this action" +} +func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbidden } + +type TooManyRequestsError struct{} + +func (e *TooManyRequestsError) Error() string { + return "Too many requests. Please wait a while before trying again." +} +func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests } + +type ClientIdOrSecretNotProvidedError struct{} + +func (e *ClientIdOrSecretNotProvidedError) Error() string { + return "Client id and secret not provided" +} + +func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest } + +type WrongFileTypeError struct { + ExpectedFileType string +} + +func (e *WrongFileTypeError) Error() string { + return fmt.Sprintf("File must be of type %s", e.ExpectedFileType) +} + +func (e *WrongFileTypeError) HttpStatusCode() int { return http.StatusBadRequest } + +type MissingSessionIdError struct{} + +func (e *MissingSessionIdError) Error() string { + return "Missing session id" +} + +func (e *MissingSessionIdError) HttpStatusCode() int { return http.StatusBadRequest } + +type ReservedClaimError struct { + Key string +} + +func (e *ReservedClaimError) Error() string { + return fmt.Sprintf("Claim %s is reserved and can't be used", e.Key) +} + +func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest } diff --git a/backend/internal/controller/app_config_controller.go b/backend/internal/controller/app_config_controller.go index c6ee106..5bed631 100644 --- a/backend/internal/controller/app_config_controller.go +++ b/backend/internal/controller/app_config_controller.go @@ -1,7 +1,6 @@ package controller import ( - "errors" "fmt" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" @@ -39,13 +38,13 @@ type AppConfigController struct { func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListAppConfig(false) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.PublicAppConfigVariableDto if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -55,13 +54,13 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListAppConfig(true) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.AppConfigVariableDto if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -71,19 +70,19 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) { var input dto.AppConfigUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var configVariablesDto []dto.AppConfigVariableDto if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -136,13 +135,13 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } fileType := utils.GetFileExtension(file.Filename) if fileType != "ico" { - utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico") + c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"}) return } acc.updateImage(c, "favicon", "ico") @@ -164,17 +163,13 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) if err != nil { - if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } diff --git a/backend/internal/controller/audit_log_controller.go b/backend/internal/controller/audit_log_controller.go index e6e1616..a032645 100644 --- a/backend/internal/controller/audit_log_controller.go +++ b/backend/internal/controller/audit_log_controller.go @@ -8,7 +8,6 @@ import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" ) func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, jwtAuthMiddleware *middleware.JwtAuthMiddleware) { @@ -31,7 +30,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { // Fetch audit logs for the user logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -39,7 +38,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { var logsDtos []dto.AuditLogDto err = dto.MapStructList(logs, &logsDtos) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/custom_claim_controller.go b/backend/internal/controller/custom_claim_controller.go new file mode 100644 index 0000000..57ed4f8 --- /dev/null +++ b/backend/internal/controller/custom_claim_controller.go @@ -0,0 +1,53 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/dto" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/service" + "net/http" +) + +func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) { + wkc := &CustomClaimController{customClaimService: customClaimService} + group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler) + group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.updateUserCustomClaimsHandler) +} + +type CustomClaimController struct { + customClaimService *service.CustomClaimService +} + +func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { + claims, err := ccc.customClaimService.GetSuggestions() + if err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, claims) +} + +func (ccc *CustomClaimController) updateUserCustomClaimsHandler(c *gin.Context) { + var input []dto.CustomClaimCreateDto + + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + userId := c.Param("userId") + claims, err := ccc.customClaimService.UpdateUserCustomClaims(userId, input) + if err != nil { + c.Error(err) + return + } + + var customClaimsDto []dto.CustomClaimDto + if err := dto.MapStructList(claims, &customClaimsDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, customClaimsDto) +} diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 8982c8a..49934cb 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -1,13 +1,11 @@ package controller import ( - "errors" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "strconv" "strings" @@ -42,19 +40,13 @@ type OidcController struct { func (oc *OidcController) authorizeHandler(c *gin.Context) { var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrOidcMissingAuthorization) { - utils.CustomControllerError(c, http.StatusForbidden, err.Error()) - } else if errors.Is(err, common.ErrOidcInvalidCallbackURL) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -69,17 +61,13 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrOidcInvalidCallbackURL) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -95,7 +83,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { var input dto.OidcIdTokenDto if err := c.ShouldBind(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -107,21 +95,14 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { var ok bool clientID, clientSecret, ok = c.Request.BasicAuth() if !ok { - utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided") + c.Error(&common.ClientIdOrSecretNotProvidedError{}) return } } idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret) if err != nil { - if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || - errors.Is(err, common.ErrOidcMissingClientCredentials) || - errors.Is(err, common.ErrOidcClientSecretInvalid) || - errors.Is(err, common.ErrOidcInvalidAuthorizationCode) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -132,14 +113,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { token := strings.Split(c.GetHeader("Authorization"), " ")[1] jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token) if err != nil { - utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) + c.Error(err) return } userID := jwtClaims.Subject clientId := jwtClaims.Audience[0] claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -150,7 +131,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") client, err := oc.oidcService.GetClient(clientId) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -171,7 +152,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { } } - utils.ControllerError(c, err) + c.Error(err) } func (oc *OidcController) listClientsHandler(c *gin.Context) { @@ -181,13 +162,13 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientsDto []dto.OidcClientDto if err := dto.MapStructList(clients, &clientsDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -200,19 +181,19 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { func (oc *OidcController) createClientHandler(c *gin.Context) { var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientDto dto.OidcClientDto if err := dto.MapStruct(client, &clientDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -222,7 +203,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { func (oc *OidcController) deleteClientHandler(c *gin.Context) { err := oc.oidcService.DeleteClient(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -232,19 +213,19 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) { func (oc *OidcController) updateClientHandler(c *gin.Context) { var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } client, err := oc.oidcService.UpdateClient(c.Param("id"), input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var clientDto dto.OidcClientDto if err := dto.MapStruct(client, &clientDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -254,7 +235,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { func (oc *OidcController) createClientSecretHandler(c *gin.Context) { secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -264,7 +245,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) { imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -275,17 +256,13 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) if err != nil { - if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } @@ -295,7 +272,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { err := oc.oidcService.DeleteClientLogo(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go index 7fb3081..d61f893 100644 --- a/backend/internal/controller/test_controller.go +++ b/backend/internal/controller/test_controller.go @@ -3,7 +3,6 @@ package controller import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) @@ -19,17 +18,17 @@ type TestController struct { func (tc *TestController) resetAndSeedHandler(c *gin.Context) { if err := tc.TestService.ResetDatabase(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } if err := tc.TestService.ResetApplicationImages(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } if err := tc.TestService.SeedDatabase(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index 2d9760c..d68daf4 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -1,13 +1,10 @@ package controller import ( - "errors" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/time/rate" "net/http" "strconv" @@ -43,13 +40,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var usersDto []dto.UserDto if err := dto.MapStructList(users, &usersDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -62,13 +59,13 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -78,13 +75,13 @@ func (uc *UserController) getUserHandler(c *gin.Context) { func (uc *UserController) getCurrentUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.GetString("userID")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -93,7 +90,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) { func (uc *UserController) deleteUserHandler(c *gin.Context) { if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -103,23 +100,19 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) { func (uc *UserController) createUserHandler(c *gin.Context) { var input dto.UserCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } user, err := uc.UserService.CreateUser(input) if err != nil { - if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -137,13 +130,13 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { var input dto.OneTimeAccessTokenCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -153,17 +146,13 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) if err != nil { - if errors.Is(err, common.ErrTokenInvalidOrExpired) { - utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -174,17 +163,13 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.SetupInitialAdmin() if err != nil { - if errors.Is(err, common.ErrSetupAlreadyCompleted) { - utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -195,7 +180,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { var input dto.UserCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -208,17 +193,13 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser) if err != nil { - if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/user_group_controller.go b/backend/internal/controller/user_group_controller.go index 0012f80..e7fcbb3 100644 --- a/backend/internal/controller/user_group_controller.go +++ b/backend/internal/controller/user_group_controller.go @@ -1,16 +1,13 @@ package controller import ( - "errors" "net/http" "strconv" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" ) func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) { @@ -37,7 +34,7 @@ func (ugc *UserGroupController) list(c *gin.Context) { groups, pagination, err := ugc.UserGroupService.List(searchTerm, page, pageSize) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -45,12 +42,12 @@ func (ugc *UserGroupController) list(c *gin.Context) { for i, group := range groups { var groupDto dto.UserGroupDtoWithUserCount if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } groupsDto[i] = groupDto @@ -65,13 +62,13 @@ func (ugc *UserGroupController) list(c *gin.Context) { func (ugc *UserGroupController) get(c *gin.Context) { group, err := ugc.UserGroupService.Get(c.Param("id")) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -81,23 +78,19 @@ func (ugc *UserGroupController) get(c *gin.Context) { func (ugc *UserGroupController) create(c *gin.Context) { var input dto.UserGroupCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.Create(input) if err != nil { - if errors.Is(err, common.ErrNameAlreadyInUse) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -107,23 +100,19 @@ func (ugc *UserGroupController) create(c *gin.Context) { func (ugc *UserGroupController) update(c *gin.Context) { var input dto.UserGroupCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.Update(c.Param("id"), input) if err != nil { - if errors.Is(err, common.ErrNameAlreadyInUse) { - utils.CustomControllerError(c, http.StatusConflict, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -132,7 +121,7 @@ func (ugc *UserGroupController) update(c *gin.Context) { func (ugc *UserGroupController) delete(c *gin.Context) { if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -142,19 +131,19 @@ func (ugc *UserGroupController) delete(c *gin.Context) { func (ugc *UserGroupController) updateUsers(c *gin.Context) { var input dto.UserGroupUpdateUsersDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var groupDto dto.UserGroupDtoWithUsers if err := dto.MapStruct(group, &groupDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go index e0c6bc1..bff2c29 100644 --- a/backend/internal/controller/webauthn_controller.go +++ b/backend/internal/controller/webauthn_controller.go @@ -1,17 +1,15 @@ package controller import ( - "errors" "github.com/go-webauthn/webauthn/protocol" + "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "net/http" "time" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/time/rate" ) @@ -38,7 +36,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { userID := c.GetString("userID") options, err := wc.webAuthnService.BeginRegistration(userID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -49,20 +47,20 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") + c.Error(&common.MissingSessionIdError{}) return } userID := c.GetString("userID") credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDto dto.WebauthnCredentialDto if err := dto.MapStruct(credential, &credentialDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -72,7 +70,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { options, err := wc.webAuthnService.BeginLogin() if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -83,13 +81,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") + c.Error(&common.MissingSessionIdError{}) return } credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -97,17 +95,13 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { user, token, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) if err != nil { - if errors.Is(err, common.ErrInvalidCredentials) { - utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) - } else { - utils.ControllerError(c, err) - } + c.Error(err) return } var userDto dto.UserDto if err := dto.MapStruct(user, &userDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -119,13 +113,13 @@ func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { userID := c.GetString("userID") credentials, err := wc.webAuthnService.ListCredentials(userID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDtos []dto.WebauthnCredentialDto if err := dto.MapStructList(credentials, &credentialDtos); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -138,7 +132,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { err := wc.webAuthnService.DeleteCredential(userID, credentialID) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } @@ -151,19 +145,19 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { var input dto.WebauthnCredentialUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } var credentialDto dto.WebauthnCredentialDto if err := dto.MapStruct(credential, &credentialDto); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index c2b5e2c..b4ec512 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -4,7 +4,6 @@ import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) @@ -21,7 +20,7 @@ type WellKnownController struct { func (wkc *WellKnownController) jwksHandler(c *gin.Context) { jwk, err := wkc.jwtService.GetJWK() if err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/dto/custom_claim_dto.go b/backend/internal/dto/custom_claim_dto.go new file mode 100644 index 0000000..f80f9a6 --- /dev/null +++ b/backend/internal/dto/custom_claim_dto.go @@ -0,0 +1,8 @@ +package dto + +type CustomClaimDto struct { + Key string `json:"key" binding:"required,max=20"` + Value string `json:"value" binding:"required,max=10000"` +} + +type CustomClaimCreateDto = CustomClaimDto diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 1f34001..04521e2 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -3,12 +3,13 @@ package dto import "time" type UserDto struct { - ID string `json:"id"` - Username string `json:"username"` - Email string `json:"email" ` - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - IsAdmin bool `json:"isAdmin"` + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email" ` + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + IsAdmin bool `json:"isAdmin"` + CustomClaims []CustomClaimDto `json:"customClaims"` } type UserCreateDto struct { @@ -16,7 +17,7 @@ type UserCreateDto struct { Email string `json:"email" binding:"required,email"` FirstName string `json:"firstName" binding:"required,min=3,max=30"` LastName string `json:"lastName" binding:"required,min=3,max=30"` - IsAdmin bool `json:"isAdmin"` + IsAdmin bool `json:"isAdmin" binding:"required"` } type OneTimeAccessTokenCreateDto struct { diff --git a/backend/internal/utils/controller_error_util.go b/backend/internal/middleware/error_handler.go similarity index 51% rename from backend/internal/utils/controller_error_util.go rename to backend/internal/middleware/error_handler.go index 7d62ce7..e02361a 100644 --- a/backend/internal/utils/controller_error_util.go +++ b/backend/internal/middleware/error_handler.go @@ -1,37 +1,69 @@ -package utils +package middleware import ( "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" + "github.com/stonith404/pocket-id/backend/internal/common" "gorm.io/gorm" "log" "net/http" "strings" ) -import ( - "fmt" -) +type ErrorHandlerMiddleware struct{} -func ControllerError(c *gin.Context, err error) { - // Check for record not found errors - if errors.Is(err, gorm.ErrRecordNotFound) { - CustomControllerError(c, http.StatusNotFound, "Record not found") - return - } +func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware { + return &ErrorHandlerMiddleware{} +} + +func (m *ErrorHandlerMiddleware) Add() gin.HandlerFunc { + return func(c *gin.Context) { + c.Next() + for _, err := range c.Errors { + + // Check for record not found errors + if errors.Is(err, gorm.ErrRecordNotFound) { + errorResponse(c, http.StatusNotFound, "Record not found") + return + } + + // Check for validation errors + var validationErrors validator.ValidationErrors + if errors.As(err, &validationErrors) { + message := handleValidationError(validationErrors) + errorResponse(c, http.StatusBadRequest, message) + return + } + + // Check for slice validation errors + var sliceValidationErrors binding.SliceValidationError + if errors.As(err, &sliceValidationErrors) { + if errors.As(sliceValidationErrors[0], &validationErrors) { + message := handleValidationError(validationErrors) + errorResponse(c, http.StatusBadRequest, message) + return + } + } - // Check for validation errors - var validationErrors validator.ValidationErrors - if errors.As(err, &validationErrors) { - message := handleValidationError(validationErrors) - CustomControllerError(c, http.StatusBadRequest, message) - return + var appErr common.AppError + if errors.As(err, &appErr) { + errorResponse(c, appErr.HttpStatusCode(), appErr.Error()) + return + } + log.Println(err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) + } } +} - log.Println(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) +func errorResponse(c *gin.Context, statusCode int, message string) { + // Capitalize the first letter of the message + message = strings.ToUpper(message[:1]) + message[1:] + c.JSON(statusCode, gin.H{"error": message}) } func handleValidationError(validationErrors validator.ValidationErrors) string { @@ -67,9 +99,3 @@ func handleValidationError(validationErrors validator.ValidationErrors) string { return combinedErrors } - -func CustomControllerError(c *gin.Context, statusCode int, message string) { - // Capitalize the first letter of the message - message = strings.ToUpper(message[:1]) + message[1:] - c.JSON(statusCode, gin.H{"error": message}) -} diff --git a/backend/internal/middleware/file_size_limit.go b/backend/internal/middleware/file_size_limit.go index 7503acb..32c7363 100644 --- a/backend/internal/middleware/file_size_limit.go +++ b/backend/internal/middleware/file_size_limit.go @@ -3,7 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/common" "net/http" ) @@ -17,8 +17,8 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc { return func(c *gin.Context) { c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) if err := c.Request.ParseMultipartForm(maxSize); err != nil { - utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize))) - c.Abort() + err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)} + c.Error(err) return } c.Next() diff --git a/backend/internal/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go index 9416d5a..36be4cc 100644 --- a/backend/internal/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -2,9 +2,8 @@ package middleware import ( "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/service" - "github.com/stonith404/pocket-id/backend/internal/utils" - "net/http" "strings" ) @@ -29,8 +28,7 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { c.Next() return } else { - utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") - c.Abort() + c.Error(&common.NotSignedInError{}) return } } @@ -40,14 +38,14 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { c.Next() return } else if err != nil { - utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") + c.Error(&common.NotSignedInError{}) c.Abort() return } // Check if the user is an admin if adminOnly && !claims.IsAdmin { - utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource") + c.Error(&common.MissingPermissionError{}) c.Abort() return } diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go index 494ee06..f9686a6 100644 --- a/backend/internal/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -2,8 +2,6 @@ package middleware import ( "github.com/stonith404/pocket-id/backend/internal/common" - "github.com/stonith404/pocket-id/backend/internal/utils" - "net/http" "sync" "time" @@ -33,8 +31,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { limiter := getLimiter(ip, limit, burst) if !limiter.Allow() { - utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.") - c.Abort() + c.Error(&common.TooManyRequestsError{}) return } diff --git a/backend/internal/model/custom_claim.go b/backend/internal/model/custom_claim.go new file mode 100644 index 0000000..cdce3bc --- /dev/null +++ b/backend/internal/model/custom_claim.go @@ -0,0 +1,10 @@ +package model + +type CustomClaim struct { + Base + + Key string + Value string + + UserID string +} diff --git a/backend/internal/model/user.go b/backend/internal/model/user.go index 8cb6f0b..137838c 100644 --- a/backend/internal/model/user.go +++ b/backend/internal/model/user.go @@ -15,8 +15,9 @@ type User struct { LastName string IsAdmin bool - UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` - Credentials []WebauthnCredential + CustomClaims []CustomClaim + UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` + Credentials []WebauthnCredential } func (u User) WebAuthnID() []byte { return []byte(u.ID) } diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index b28e951..a5bdc3b 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -153,7 +153,7 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image fileType := utils.GetFileExtension(uploadedFile.Filename) mimeType := utils.GetImageMimeType(fileType) if mimeType == "" { - return common.ErrFileTypeNotSupported + return &common.FileTypeNotSupportedError{} } // Delete the old image if it has a different file type diff --git a/backend/internal/service/custom_claim_service.go b/backend/internal/service/custom_claim_service.go new file mode 100644 index 0000000..c4752e2 --- /dev/null +++ b/backend/internal/service/custom_claim_service.go @@ -0,0 +1,111 @@ +package service + +import ( + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" + "github.com/stonith404/pocket-id/backend/internal/model" + "gorm.io/gorm" +) + +// Reserved claims +var reservedClaims = map[string]struct{}{ + "given_name": {}, + "family_name": {}, + "name": {}, + "email": {}, + "preferred_username": {}, + "groups": {}, + "sub": {}, + "iss": {}, + "aud": {}, + "exp": {}, + "iat": {}, + "auth_time": {}, + "nonce": {}, + "acr": {}, + "amr": {}, + "azp": {}, + "nbf": {}, + "jti": {}, +} + +type CustomClaimService struct { + db *gorm.DB +} + +func NewCustomClaimService(db *gorm.DB) *CustomClaimService { + return &CustomClaimService{db: db} +} + +// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username +func isReservedClaim(key string) bool { + _, ok := reservedClaims[key] + return ok +} + +func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + var existingClaims []model.CustomClaim + err := s.db.Where("user_id = ?", userID).Find(&existingClaims).Error + if err != nil { + return nil, err + } + + // Delete claims that are not in the new list + for _, existingClaim := range existingClaims { + found := false + for _, claim := range claims { + if claim.Key == existingClaim.Key { + found = true + break + } + } + if !found { + err = s.db.Delete(&existingClaim).Error + if err != nil { + return nil, err + } + } + } + + // Add or update claims + for _, claim := range claims { + if isReservedClaim(claim.Key) { + return nil, &common.ReservedClaimError{Key: claim.Key} + } + // Update the claim if it already exists or create a new one + err = s.db.Where("user_id = ? AND key = ?", userID, claim.Key).Assign(&model.CustomClaim{ + Key: claim.Key, + Value: claim.Value, + UserID: userID, + }).FirstOrCreate(&model.CustomClaim{}).Error + if err != nil { + return nil, err + } + } + + // Get the updated claims + var updatedClaims []model.CustomClaim + err = s.db.Where("user_id = ?", userID).Find(&updatedClaims).Error + if err != nil { + return nil, err + } + + return updatedClaims, nil +} + +func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) { + var customClaims []model.CustomClaim + err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error + return customClaims, err +} + +func (s *CustomClaimService) GetSuggestions() ([]string, error) { + var customClaimsKeys []string + + err := s.db.Model(&model.CustomClaim{}). + Group("key"). + Order("COUNT(*) DESC"). + Pluck("key", &customClaimsKeys).Error + + return customClaimsKeys, err +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 64fbb5b..70be994 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -18,18 +18,20 @@ import ( ) type OidcService struct { - db *gorm.DB - jwtService *JwtService - appConfigService *AppConfigService - auditLogService *AuditLogService + db *gorm.DB + jwtService *JwtService + appConfigService *AppConfigService + auditLogService *AuditLogService + customClaimService *CustomClaimService } -func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService) *OidcService { +func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService { return &OidcService{ - db: db, - jwtService: jwtService, - appConfigService: appConfigService, - auditLogService: auditLogService, + db: db, + jwtService: jwtService, + appConfigService: appConfigService, + auditLogService: auditLogService, + customClaimService: customClaimService, } } @@ -38,7 +40,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID) if userAuthorizedOIDCClient.Scope != input.Scope { - return "", "", common.ErrOidcMissingAuthorization + return "", "", &common.OidcMissingAuthorizationError{} } callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL) @@ -93,11 +95,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) { if grantType != "authorization_code" { - return "", "", common.ErrOidcGrantTypeNotSupported + return "", "", &common.OidcGrantTypeNotSupportedError{} } if clientID == "" || clientSecret == "" { - return "", "", common.ErrOidcMissingClientCredentials + return "", "", &common.OidcMissingClientCredentialsError{} } var client model.OidcClient @@ -107,17 +109,17 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) if err != nil { - return "", "", common.ErrOidcClientSecretInvalid + return "", "", &common.OidcClientSecretInvalidError{} } var authorizationCodeMetaData model.OidcAuthorizationCode err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error if err != nil { - return "", "", common.ErrOidcInvalidAuthorizationCode + return "", "", &common.OidcInvalidAuthorizationCodeError{} } if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { - return "", "", common.ErrOidcInvalidAuthorizationCode + return "", "", &common.OidcInvalidAuthorizationCodeError{} } userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) @@ -249,7 +251,7 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error { fileType := utils.GetFileExtension(file.Filename) if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { - return common.ErrFileTypeNotSupported + return &common.FileTypeNotSupportedError{} } imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType) @@ -333,9 +335,20 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma } if strings.Contains(scope, "profile") { + // Add profile claims for k, v := range profileClaims { claims[k] = v } + + // Add custom claims + customClaims, err := s.customClaimService.GetCustomClaimsForUser(userID) + if err != nil { + return nil, err + } + + for _, customClaim := range customClaims { + claims[customClaim.Key] = customClaim.Value + } } if strings.Contains(scope, "email") { claims["email"] = user.Email @@ -374,5 +387,5 @@ func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackU return inputCallbackURL, nil } - return "", common.ErrOidcInvalidCallbackURL + return "", &common.OidcInvalidCallbackURLError{} } diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index 6dbd9ad..e164eb9 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -50,7 +50,7 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use if err := s.db.Preload("Users").Create(&group).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.UserGroup{}, common.ErrNameAlreadyInUse + return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } return model.UserGroup{}, err } @@ -68,7 +68,7 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto) (grou if err := s.db.Preload("Users").Save(&group).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.UserGroup{}, common.ErrNameAlreadyInUse + return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } return model.UserGroup{}, err } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0c94a3b..e00f5fe 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -35,7 +35,7 @@ func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]mo func (s *UserService) GetUser(userID string) (model.User, error) { var user model.User - err := s.db.Where("id = ?", userID).First(&user).Error + err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error return user, err } @@ -111,7 +111,7 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, stri var oneTimeAccessToken model.OneTimeAccessToken if err := s.db.Where("token = ? AND expires_at > ?", token, time.Now().Unix()).Preload("User").First(&oneTimeAccessToken).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return model.User{}, "", common.ErrTokenInvalidOrExpired + return model.User{}, "", &common.TokenInvalidOrExpiredError{} } return model.User{}, "", err } @@ -133,7 +133,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { return model.User{}, "", err } if userCount > 1 { - return model.User{}, "", common.ErrSetupAlreadyCompleted + return model.User{}, "", &common.SetupAlreadyCompletedError{} } user := model.User{ @@ -149,7 +149,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { } if len(user.Credentials) > 0 { - return model.User{}, "", common.ErrSetupAlreadyCompleted + return model.User{}, "", &common.SetupAlreadyCompletedError{} } token, err := s.jwtService.GenerateAccessToken(user) @@ -163,11 +163,11 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { func (s *UserService) checkDuplicatedFields(user model.User) error { var existingUser model.User if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { - return common.ErrEmailTaken + return &common.AlreadyInUseError{Property: "email"} } if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { - return common.ErrUsernameTaken + return &common.AlreadyInUseError{Property: "username"} } return nil diff --git a/backend/migrations/20241024064959_custom_claims.down.sql b/backend/migrations/20241024064959_custom_claims.down.sql new file mode 100644 index 0000000..113dcb5 --- /dev/null +++ b/backend/migrations/20241024064959_custom_claims.down.sql @@ -0,0 +1 @@ +DROP TABLE custom_claims; \ No newline at end of file diff --git a/backend/migrations/20241024064959_custom_claims.up.sql b/backend/migrations/20241024064959_custom_claims.up.sql new file mode 100644 index 0000000..68451dc --- /dev/null +++ b/backend/migrations/20241024064959_custom_claims.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE custom_claims +( + id TEXT NOT NULL PRIMARY KEY, + created_at DATETIME, + key TEXT NOT NULL, + value TEXT NOT NULL, + + user_id TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE, + + CONSTRAINT unique_key_user UNIQUE (key, user_id) +); \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6feb302..df274b5 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "pocket-id-frontend", - "version": "0.9.0", + "version": "0.10.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pocket-id-frontend", - "version": "0.9.0", + "version": "0.10.0", "dependencies": { "@simplewebauthn/browser": "^10.0.0", "axios": "^1.7.7", diff --git a/frontend/src/lib/components/form-input.svelte b/frontend/src/lib/components/form-input.svelte index a1021ad..7aeb578 100644 --- a/frontend/src/lib/components/form-input.svelte +++ b/frontend/src/lib/components/form-input.svelte @@ -16,7 +16,7 @@ ...restProps }: HTMLAttributes & { input?: FormInput; - label: string; + label?: string; description?: string; disabled?: boolean; type?: 'text' | 'password' | 'email' | 'number' | 'checkbox'; @@ -24,15 +24,17 @@ children?: Snippet; } = $props(); - const id = label.toLowerCase().replace(/ /g, '-'); + const id = label?.toLowerCase().replace(/ /g, '-');
- + {#if label} + + {/if} {#if description}

{description}

{/if} -
+
{#if children} {@render children()} {:else if input} diff --git a/frontend/src/lib/components/ui/popover/index.ts b/frontend/src/lib/components/ui/popover/index.ts new file mode 100644 index 0000000..63aecf9 --- /dev/null +++ b/frontend/src/lib/components/ui/popover/index.ts @@ -0,0 +1,17 @@ +import { Popover as PopoverPrimitive } from "bits-ui"; +import Content from "./popover-content.svelte"; +const Root = PopoverPrimitive.Root; +const Trigger = PopoverPrimitive.Trigger; +const Close = PopoverPrimitive.Close; + +export { + Root, + Content, + Trigger, + Close, + // + Root as Popover, + Content as PopoverContent, + Trigger as PopoverTrigger, + Close as PopoverClose, +}; diff --git a/frontend/src/lib/components/ui/popover/popover-content.svelte b/frontend/src/lib/components/ui/popover/popover-content.svelte new file mode 100644 index 0000000..5bad4b7 --- /dev/null +++ b/frontend/src/lib/components/ui/popover/popover-content.svelte @@ -0,0 +1,22 @@ + + + + + diff --git a/frontend/src/lib/services/custom-claim-service.ts b/frontend/src/lib/services/custom-claim-service.ts new file mode 100644 index 0000000..36b476f --- /dev/null +++ b/frontend/src/lib/services/custom-claim-service.ts @@ -0,0 +1,15 @@ +import type { CustomClaim } from '$lib/types/custom-claim.type'; +import type { User } from 'lucide-svelte'; +import APIService from './api-service'; + +export default class CustomClaimService extends APIService { + async getSuggestions() { + const res = await this.api.get('/custom-claims/suggestions'); + return res.data as string[]; + } + + async updateUserCustomClaims(userId: string, claims: CustomClaim[]) { + const res = await this.api.put(`/custom-claims/user/${userId}`, claims); + return res.data as User; + } +} diff --git a/frontend/src/lib/types/custom-claim.type.ts b/frontend/src/lib/types/custom-claim.type.ts new file mode 100644 index 0000000..24a3659 --- /dev/null +++ b/frontend/src/lib/types/custom-claim.type.ts @@ -0,0 +1,4 @@ +export type CustomClaim = { + key: string; + value: string; +}; diff --git a/frontend/src/lib/types/user.type.ts b/frontend/src/lib/types/user.type.ts index f22368e..1157ecf 100644 --- a/frontend/src/lib/types/user.type.ts +++ b/frontend/src/lib/types/user.type.ts @@ -1,3 +1,5 @@ +import type { CustomClaim } from './custom-claim.type'; + export type User = { id: string; username: string; @@ -5,6 +7,7 @@ export type User = { firstName: string; lastName: string; isAdmin: boolean; + customClaims: CustomClaim[]; }; -export type UserCreate = Omit; +export type UserCreate = Omit; diff --git a/frontend/src/routes/settings/admin/users/[id]/+page.svelte b/frontend/src/routes/settings/admin/users/[id]/+page.svelte index fc99a66..fde9d41 100644 --- a/frontend/src/routes/settings/admin/users/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/users/[id]/+page.svelte @@ -1,16 +1,20 @@ @@ -37,10 +50,25 @@
- {user.firstName} {user.lastName} + General Details - + + + + Custom Claims + + Custom claims are key-value pairs that can be used to store additional information about a + user. These claims will be included in the ID token if the scope "profile" is requested. + + + + +
+ +
+
+
diff --git a/frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte b/frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte new file mode 100644 index 0000000..2ec313e --- /dev/null +++ b/frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte @@ -0,0 +1,97 @@ + + +
+ (isInputFocused = true)} + onblur={() => (isInputFocused = false)} + /> + {}} + closeOnOutsideClick={false} + closeOnEscape={false} + > + +
diff --git a/frontend/src/routes/settings/admin/users/[id]/custom-claim-input.svelte b/frontend/src/routes/settings/admin/users/[id]/custom-claim-input.svelte new file mode 100644 index 0000000..e1d0c52 --- /dev/null +++ b/frontend/src/routes/settings/admin/users/[id]/custom-claim-input.svelte @@ -0,0 +1,78 @@ + + +
+ +
+ {#each customClaims as _, i} +
+ + + +
+ {/each} +
+
+ {#if error} +

{error}

+ {/if} + {#if customClaims.length < limit} + + {/if} +
From 8b4e84ba7eb5bdb78664b745d62674cca5f78271 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Sun, 27 Oct 2024 11:56:26 +0100 Subject: [PATCH 2/4] remove required admin binding from isAdmin --- backend/internal/dto/user_dto.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 04521e2..19af71d 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -17,7 +17,7 @@ type UserCreateDto struct { Email string `json:"email" binding:"required,email"` FirstName string `json:"firstName" binding:"required,min=3,max=30"` LastName string `json:"lastName" binding:"required,min=3,max=30"` - IsAdmin bool `json:"isAdmin" binding:"required"` + IsAdmin bool `json:"isAdmin"` } type OneTimeAccessTokenCreateDto struct { From cb27e8c76dbba6f7232b7be53c5b325966726daf Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Mon, 28 Oct 2024 14:08:54 +0100 Subject: [PATCH 3/4] add custom claims to user groups --- backend/internal/common/errors.go | 10 ++ .../controller/custom_claim_controller.go | 31 +++++- .../internal/controller/test_controller.go | 2 +- backend/internal/dto/user_group_dto.go | 22 ++-- backend/internal/model/custom_claim.go | 3 +- backend/internal/model/user_group.go | 1 + .../internal/service/custom_claim_service.go | 102 ++++++++++++++++-- backend/internal/service/oidc_service.go | 2 +- .../internal/service/user_group_service.go | 4 +- .../20241024064959_custom_claims.up.sql | 12 --- ... => 20241028064959_custom_claims.down.sql} | 0 .../20241028064959_custom_claims.up.sql | 15 +++ .../components}/auto-complete-input.svelte | 30 ++++-- .../components/custom-claims-input.svelte} | 0 .../src/lib/services/custom-claim-service.ts | 8 +- frontend/src/lib/types/user-group.type.ts | 2 + .../admin/user-groups/[id]/+page.svelte | 31 +++++- .../settings/admin/users/[id]/+page.svelte | 4 +- 18 files changed, 225 insertions(+), 54 deletions(-) delete mode 100644 backend/migrations/20241024064959_custom_claims.up.sql rename backend/migrations/{20241024064959_custom_claims.down.sql => 20241028064959_custom_claims.down.sql} (100%) create mode 100644 backend/migrations/20241028064959_custom_claims.up.sql rename frontend/src/{routes/settings/admin/users/[id] => lib/components}/auto-complete-input.svelte (71%) rename frontend/src/{routes/settings/admin/users/[id]/custom-claim-input.svelte => lib/components/custom-claims-input.svelte} (100%) diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 6cb79a5..0056ce7 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -134,3 +134,13 @@ func (e *ReservedClaimError) Error() string { } func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest } + +type DuplicateClaimError struct { + Key string +} + +func (e *DuplicateClaimError) Error() string { + return fmt.Sprintf("Claim %s is already defined", e.Key) +} + +func (e *DuplicateClaimError) HttpStatusCode() int { return http.StatusBadRequest } diff --git a/backend/internal/controller/custom_claim_controller.go b/backend/internal/controller/custom_claim_controller.go index 57ed4f8..ca28261 100644 --- a/backend/internal/controller/custom_claim_controller.go +++ b/backend/internal/controller/custom_claim_controller.go @@ -11,7 +11,8 @@ import ( func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) { wkc := &CustomClaimController{customClaimService: customClaimService} group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler) - group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.updateUserCustomClaimsHandler) + group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler) + group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler) } type CustomClaimController struct { @@ -28,7 +29,7 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { c.JSON(http.StatusOK, claims) } -func (ccc *CustomClaimController) updateUserCustomClaimsHandler(c *gin.Context) { +func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) { var input []dto.CustomClaimCreateDto if err := c.ShouldBindJSON(&input); err != nil { @@ -37,7 +38,31 @@ func (ccc *CustomClaimController) updateUserCustomClaimsHandler(c *gin.Context) } userId := c.Param("userId") - claims, err := ccc.customClaimService.UpdateUserCustomClaims(userId, input) + claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input) + if err != nil { + c.Error(err) + return + } + + var customClaimsDto []dto.CustomClaimDto + if err := dto.MapStructList(claims, &customClaimsDto); err != nil { + c.Error(err) + return + } + + c.JSON(http.StatusOK, customClaimsDto) +} + +func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) { + var input []dto.CustomClaimCreateDto + + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + userId := c.Param("userGroupId") + claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input) if err != nil { c.Error(err) return diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go index 55fa852..a613e4c 100644 --- a/backend/internal/controller/test_controller.go +++ b/backend/internal/controller/test_controller.go @@ -33,7 +33,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) { } if err := tc.TestService.ResetAppConfig(); err != nil { - utils.ControllerError(c, err) + c.Error(err) return } diff --git a/backend/internal/dto/user_group_dto.go b/backend/internal/dto/user_group_dto.go index 424c61c..daef04b 100644 --- a/backend/internal/dto/user_group_dto.go +++ b/backend/internal/dto/user_group_dto.go @@ -3,19 +3,21 @@ package dto import "time" type UserGroupDtoWithUsers struct { - ID string `json:"id"` - FriendlyName string `json:"friendlyName"` - Name string `json:"name"` - Users []UserDto `json:"users"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + FriendlyName string `json:"friendlyName"` + Name string `json:"name"` + CustomClaims []CustomClaimDto `json:"customClaims"` + Users []UserDto `json:"users"` + CreatedAt time.Time `json:"createdAt"` } type UserGroupDtoWithUserCount struct { - ID string `json:"id"` - FriendlyName string `json:"friendlyName"` - Name string `json:"name"` - UserCount int64 `json:"userCount"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + FriendlyName string `json:"friendlyName"` + Name string `json:"name"` + CustomClaims []CustomClaimDto `json:"customClaims"` + UserCount int64 `json:"userCount"` + CreatedAt time.Time `json:"createdAt"` } type UserGroupCreateDto struct { diff --git a/backend/internal/model/custom_claim.go b/backend/internal/model/custom_claim.go index cdce3bc..c47f934 100644 --- a/backend/internal/model/custom_claim.go +++ b/backend/internal/model/custom_claim.go @@ -6,5 +6,6 @@ type CustomClaim struct { Key string Value string - UserID string + UserID *string + UserGroupID *string } diff --git a/backend/internal/model/user_group.go b/backend/internal/model/user_group.go index 8559016..15648d4 100644 --- a/backend/internal/model/user_group.go +++ b/backend/internal/model/user_group.go @@ -5,4 +5,5 @@ type UserGroup struct { FriendlyName string Name string `gorm:"unique"` Users []User `gorm:"many2many:user_groups_users;"` + CustomClaims []CustomClaim } diff --git a/backend/internal/service/custom_claim_service.go b/backend/internal/service/custom_claim_service.go index c4752e2..5521999 100644 --- a/backend/internal/service/custom_claim_service.go +++ b/backend/internal/service/custom_claim_service.go @@ -43,9 +43,37 @@ func isReservedClaim(key string) bool { return ok } -func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { +// idType is the type of the id used to identify the user or user group +type idType string + +const ( + UserID idType = "user_id" + UserGroupID idType = "user_group_id" +) + +// UpdateCustomClaimsForUser updates the custom claims for a user +func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(UserID, userID, claims) +} + +// UpdateCustomClaimsForUserGroup updates the custom claims for a user group +func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(UserGroupID, userGroupID, claims) +} + +// updateCustomClaims updates the custom claims for a user or user group +func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + // Check for duplicate keys in the claims slice + seenKeys := make(map[string]bool) + for _, claim := range claims { + if seenKeys[claim.Key] { + return nil, &common.DuplicateClaimError{Key: claim.Key} + } + seenKeys[claim.Key] = true + } + var existingClaims []model.CustomClaim - err := s.db.Where("user_id = ?", userID).Find(&existingClaims).Error + err := s.db.Where(string(idType), value).Find(&existingClaims).Error if err != nil { return nil, err } @@ -72,12 +100,19 @@ func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto. if isReservedClaim(claim.Key) { return nil, &common.ReservedClaimError{Key: claim.Key} } + customClaim := model.CustomClaim{ + Key: claim.Key, + Value: claim.Value, + } + + if idType == UserID { + customClaim.UserID = &value + } else if idType == UserGroupID { + customClaim.UserGroupID = &value + } + // Update the claim if it already exists or create a new one - err = s.db.Where("user_id = ? AND key = ?", userID, claim.Key).Assign(&model.CustomClaim{ - Key: claim.Key, - Value: claim.Value, - UserID: userID, - }).FirstOrCreate(&model.CustomClaim{}).Error + err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error if err != nil { return nil, err } @@ -85,7 +120,7 @@ func (s *CustomClaimService) UpdateUserCustomClaims(userID string, claims []dto. // Get the updated claims var updatedClaims []model.CustomClaim - err = s.db.Where("user_id = ?", userID).Find(&updatedClaims).Error + err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error if err != nil { return nil, err } @@ -99,6 +134,57 @@ func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.Cust return customClaims, err } +func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) { + var customClaims []model.CustomClaim + err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error + return customClaims, err +} + +// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of, +// prioritizing the user's claims over user group claims with the same key. +func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) { + // Get the custom claims of the user + customClaims, err := s.GetCustomClaimsForUser(userID) + if err != nil { + return nil, err + } + + // Store user's claims in a map to prioritize and prevent duplicates + claimsMap := make(map[string]model.CustomClaim) + for _, claim := range customClaims { + claimsMap[claim.Key] = claim + } + + // Get all user groups of the user + var userGroupsOfUser []model.UserGroup + err = s.db.Preload("CustomClaims"). + Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id"). + Where("user_groups_users.user_id = ?", userID). + Find(&userGroupsOfUser).Error + if err != nil { + return nil, err + } + + // Add only non-duplicate custom claims from user groups + for _, userGroup := range userGroupsOfUser { + for _, groupClaim := range userGroup.CustomClaims { + // Only add claim if it does not exist in the user's claims + if _, exists := claimsMap[groupClaim.Key]; !exists { + claimsMap[groupClaim.Key] = groupClaim + } + } + } + + // Convert the claimsMap back to a slice + finalClaims := make([]model.CustomClaim, 0, len(claimsMap)) + for _, claim := range claimsMap { + finalClaims = append(finalClaims, claim) + } + + return finalClaims, nil +} + +// GetSuggestions returns a list of custom claim keys that have been used before func (s *CustomClaimService) GetSuggestions() ([]string, error) { var customClaimsKeys []string diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 2f7033b..9036406 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -342,7 +342,7 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma } // Add custom claims - customClaims, err := s.customClaimService.GetCustomClaimsForUser(userID) + customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID) if err != nil { return nil, err } diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index e164eb9..065f9e7 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -18,7 +18,7 @@ func NewUserGroupService(db *gorm.DB) *UserGroupService { } func (s *UserGroupService) List(name string, page int, pageSize int) (groups []model.UserGroup, response utils.PaginationResponse, err error) { - query := s.db.Model(&model.UserGroup{}) + query := s.db.Preload("CustomClaims").Model(&model.UserGroup{}) if name != "" { query = query.Where("name LIKE ?", "%"+name+"%") @@ -29,7 +29,7 @@ func (s *UserGroupService) List(name string, page int, pageSize int) (groups []m } func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) { - err = s.db.Where("id = ?", id).Preload("Users").First(&group).Error + err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error return group, err } diff --git a/backend/migrations/20241024064959_custom_claims.up.sql b/backend/migrations/20241024064959_custom_claims.up.sql deleted file mode 100644 index 68451dc..0000000 --- a/backend/migrations/20241024064959_custom_claims.up.sql +++ /dev/null @@ -1,12 +0,0 @@ -CREATE TABLE custom_claims -( - id TEXT NOT NULL PRIMARY KEY, - created_at DATETIME, - key TEXT NOT NULL, - value TEXT NOT NULL, - - user_id TEXT NOT NULL, - FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE, - - CONSTRAINT unique_key_user UNIQUE (key, user_id) -); \ No newline at end of file diff --git a/backend/migrations/20241024064959_custom_claims.down.sql b/backend/migrations/20241028064959_custom_claims.down.sql similarity index 100% rename from backend/migrations/20241024064959_custom_claims.down.sql rename to backend/migrations/20241028064959_custom_claims.down.sql diff --git a/backend/migrations/20241028064959_custom_claims.up.sql b/backend/migrations/20241028064959_custom_claims.up.sql new file mode 100644 index 0000000..9d6acc6 --- /dev/null +++ b/backend/migrations/20241028064959_custom_claims.up.sql @@ -0,0 +1,15 @@ +CREATE TABLE custom_claims +( + id TEXT NOT NULL PRIMARY KEY, + created_at DATETIME, + key TEXT NOT NULL, + value TEXT NOT NULL, + + user_id TEXT, + user_group_id TEXT, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE, + FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE, + + CONSTRAINT custom_claims_unique UNIQUE (key, user_id, user_group_id), + CHECK (user_id IS NOT NULL OR user_group_id IS NOT NULL) +); \ No newline at end of file diff --git a/frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte b/frontend/src/lib/components/auto-complete-input.svelte similarity index 71% rename from frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte rename to frontend/src/lib/components/auto-complete-input.svelte index 2ec313e..da5e3c5 100644 --- a/frontend/src/routes/settings/admin/users/[id]/auto-complete-input.svelte +++ b/frontend/src/lib/components/auto-complete-input.svelte @@ -5,35 +5,43 @@ let { value = $bindable(''), placeholder, + suggestionLimit = 5, suggestions - }: { value: string; placeholder: string; suggestions: string[] } = $props(); + }: { + value: string; + placeholder: string; + suggestionLimit?: number; + suggestions: string[]; + } = $props(); - let suggestionResult: string[] = $state(suggestions); + let filteredSuggestions: string[] = $state(suggestions.slice(0, suggestionLimit)); let selectedIndex = $state(-1); let isInputFocused = $state(false); function handleSuggestionClick(suggestion: (typeof suggestions)[0]) { value = suggestion; - suggestionResult = []; + filteredSuggestions = []; } function handleOnInput() { - suggestionResult = suggestions.filter((s) => s.includes(value.toLowerCase())); + filteredSuggestions = suggestions + .filter((s) => s.includes(value.toLowerCase())) + .slice(0, suggestionLimit); } function handleKeydown(e: KeyboardEvent) { if (!isOpen) return; switch (e.key) { case 'ArrowDown': - selectedIndex = Math.min(selectedIndex + 1, suggestionResult.length - 1); + selectedIndex = Math.min(selectedIndex + 1, filteredSuggestions.length - 1); break; case 'ArrowUp': selectedIndex = Math.max(selectedIndex - 1, -1); break; case 'Enter': if (selectedIndex >= 0) { - handleSuggestionClick(suggestionResult[selectedIndex]); + handleSuggestionClick(filteredSuggestions[selectedIndex]); } break; case 'Escape': @@ -42,18 +50,18 @@ } } - let isOpen = $derived(suggestionResult.length > 0 && isInputFocused); + let isOpen = $derived(filteredSuggestions.length > 0 && isInputFocused); $effect(() => { // Reset selection when suggestions change - if (suggestionResult) { + if (filteredSuggestions) { selectedIndex = -1; } });
-
+ + + + Custom Claims + + Custom claims are key-value pairs that can be used to store additional information about a + user. These claims will be included in the ID token if the scope "profile" is requested. + Custom claims defined on the user will be prioritized if there are conflicts. + + + + +
+ +
+
+
diff --git a/frontend/src/routes/settings/admin/users/[id]/+page.svelte b/frontend/src/routes/settings/admin/users/[id]/+page.svelte index fde9d41..657ae91 100644 --- a/frontend/src/routes/settings/admin/users/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/users/[id]/+page.svelte @@ -7,8 +7,8 @@ import { axiosErrorToast } from '$lib/utils/error-util'; import { LucideChevronLeft } from 'lucide-svelte'; import { toast } from 'svelte-sonner'; + import CustomClaimsInput from '../../../../../lib/components/custom-claims-input.svelte'; import UserForm from '../user-form.svelte'; - import CustomClaimsInput from './custom-claim-input.svelte'; let { data } = $props(); let user = $state(data); @@ -50,7 +50,7 @@
- General Details + General From 0abf232ebc4ff415adc1704ee263e2f50d699823 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Mon, 28 Oct 2024 14:37:00 +0100 Subject: [PATCH 4/4] add tests --- .../lib/components/custom-claims-input.svelte | 7 +-- frontend/tests/account-settings.spec.ts | 4 +- frontend/tests/user-group.spec.ts | 36 +++++++++++++ frontend/tests/user-settings.spec.ts | 50 ++++++++++++++++--- 4 files changed, 83 insertions(+), 14 deletions(-) diff --git a/frontend/src/lib/components/custom-claims-input.svelte b/frontend/src/lib/components/custom-claims-input.svelte index e1d0c52..253d56d 100644 --- a/frontend/src/lib/components/custom-claims-input.svelte +++ b/frontend/src/lib/components/custom-claims-input.svelte @@ -45,14 +45,11 @@ suggestions={filteredSuggestions} bind:value={customClaims[i].key} /> - +