Skip to content

Commit

Permalink
feat: custom claims (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 authored Oct 28, 2024
1 parent 3350398 commit c056089
Show file tree
Hide file tree
Showing 43 changed files with 1,069 additions and 279 deletions.
5 changes: 4 additions & 1 deletion backend/internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
userService := service.NewUserService(db, jwtService)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService)
customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)

r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))

Expand All @@ -59,6 +61,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)
controller.NewCustomClaimController(apiGroup, jwtAuthMiddleware, customClaimService)

// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {
Expand Down
159 changes: 143 additions & 16 deletions backend/internal/common/errors.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,146 @@
package common

import "errors"

var (
ErrUsernameTaken = errors.New("username is already taken")
ErrEmailTaken = errors.New("email is already taken")
ErrSetupAlreadyCompleted = errors.New("setup already completed")
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
ErrOidcMissingAuthorization = errors.New("missing authorization")
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials")
ErrNameAlreadyInUse = errors.New("name is already in use")
import (
"fmt"
"net/http"
)

type AppError interface {
error
HttpStatusCode() int
}

// Custom error types for various conditions

type AlreadyInUseError struct {
Property string
}

func (e *AlreadyInUseError) Error() string {
return fmt.Sprintf("%s is already in use", e.Property)
}
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }

type SetupAlreadyCompletedError struct{}

func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 }

type TokenInvalidOrExpiredError struct{}

func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" }
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 }

type OidcMissingAuthorizationError struct{}

func (e *OidcMissingAuthorizationError) Error() string { return "missing authorization" }
func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.StatusForbidden }

type OidcGrantTypeNotSupportedError struct{}

func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" }
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 }

type OidcMissingClientCredentialsError struct{}

func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" }
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 }

type OidcClientSecretInvalidError struct{}

func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }

type OidcInvalidAuthorizationCodeError struct{}

func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 }

type OidcInvalidCallbackURLError struct{}

func (e *OidcInvalidCallbackURLError) Error() string { return "invalid callback URL" }
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 }

type FileTypeNotSupportedError struct{}

func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }

type InvalidCredentialsError struct{}

func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" }
func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 }

type FileTooLargeError struct {
MaxSize string
}

func (e *FileTooLargeError) Error() string {
return fmt.Sprintf("The file can't be larger than %s", e.MaxSize)
}
func (e *FileTooLargeError) HttpStatusCode() int { return http.StatusRequestEntityTooLarge }

type NotSignedInError struct{}

func (e *NotSignedInError) Error() string { return "You are not signed in" }
func (e *NotSignedInError) HttpStatusCode() int { return http.StatusUnauthorized }

type MissingPermissionError struct{}

func (e *MissingPermissionError) Error() string {
return "You don't have permission to perform this action"
}
func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbidden }

type TooManyRequestsError struct{}

func (e *TooManyRequestsError) Error() string {
return "Too many requests. Please wait a while before trying again."
}
func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests }

type ClientIdOrSecretNotProvidedError struct{}

func (e *ClientIdOrSecretNotProvidedError) Error() string {
return "Client id and secret not provided"
}

func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest }

type WrongFileTypeError struct {
ExpectedFileType string
}

func (e *WrongFileTypeError) Error() string {
return fmt.Sprintf("File must be of type %s", e.ExpectedFileType)
}

func (e *WrongFileTypeError) HttpStatusCode() int { return http.StatusBadRequest }

type MissingSessionIdError struct{}

func (e *MissingSessionIdError) Error() string {
return "Missing session id"
}

func (e *MissingSessionIdError) HttpStatusCode() int { return http.StatusBadRequest }

type ReservedClaimError struct {
Key string
}

func (e *ReservedClaimError) Error() string {
return fmt.Sprintf("Claim %s is reserved and can't be used", e.Key)
}

func (e *ReservedClaimError) HttpStatusCode() int { return http.StatusBadRequest }

type DuplicateClaimError struct {
Key string
}

func (e *DuplicateClaimError) Error() string {
return fmt.Sprintf("Claim %s is already defined", e.Key)
}

func (e *DuplicateClaimError) HttpStatusCode() int { return http.StatusBadRequest }
27 changes: 11 additions & 16 deletions backend/internal/controller/app_config_controller.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package controller

import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
Expand Down Expand Up @@ -39,13 +38,13 @@ type AppConfigController struct {
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(false)
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

Expand All @@ -55,13 +54,13 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(true)
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

Expand All @@ -71,19 +70,19 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input)
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

Expand Down Expand Up @@ -136,13 +135,13 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" {
utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
return
}
acc.updateImage(c, "favicon", "ico")
Expand All @@ -164,17 +163,13 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file")
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.ControllerError(c, err)
}
c.Error(err)
return
}

Expand Down
5 changes: 2 additions & 3 deletions backend/internal/controller/audit_log_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
)

func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, jwtAuthMiddleware *middleware.JwtAuthMiddleware) {
Expand All @@ -31,15 +30,15 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
// Fetch audit logs for the user
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, page, pageSize)
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

// Map the audit logs to DTOs
var logsDtos []dto.AuditLogDto
err = dto.MapStructList(logs, &logsDtos)
if err != nil {
utils.ControllerError(c, err)
c.Error(err)
return
}

Expand Down
78 changes: 78 additions & 0 deletions backend/internal/controller/custom_claim_controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package controller

import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
"net/http"
)

func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) {
wkc := &CustomClaimController{customClaimService: customClaimService}
group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler)
group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler)
group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler)
}

type CustomClaimController struct {
customClaimService *service.CustomClaimService
}

func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions()
if err != nil {
c.Error(err)
return
}

c.JSON(http.StatusOK, claims)
}

func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto

if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
if err != nil {
c.Error(err)
return
}

var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err)
return
}

c.JSON(http.StatusOK, customClaimsDto)
}

func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto

if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

userId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input)
if err != nil {
c.Error(err)
return
}

var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err)
return
}

c.JSON(http.StatusOK, customClaimsDto)
}
Loading

0 comments on commit c056089

Please sign in to comment.