Skip to content

Commit

Permalink
feat: add user groups
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 committed Oct 2, 2024
1 parent 7a54d3a commit 24c948e
Show file tree
Hide file tree
Showing 40 changed files with 1,142 additions and 37 deletions.
2 changes: 2 additions & 0 deletions backend/internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
userService := service.NewUserService(db, jwtService)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)

r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
Expand All @@ -57,6 +58,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)

// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {
Expand Down
1 change: 1 addition & 0 deletions backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ var (
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials")
ErrNameAlreadyInUse = errors.New("name is already in use")
)
162 changes: 162 additions & 0 deletions backend/internal/controller/user_group_controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package controller

import (
"errors"
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
)

func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) {
ugc := UserGroupController{
UserGroupService: userGroupService,
}

group.GET("/user-groups", jwtAuthMiddleware.Add(true), ugc.list)
group.GET("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.get)
group.POST("/user-groups", jwtAuthMiddleware.Add(true), ugc.create)
group.PUT("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.update)
group.DELETE("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.delete)
group.PUT("/user-groups/:id/users", jwtAuthMiddleware.Add(true), ugc.updateUsers)
}

type UserGroupController struct {
UserGroupService *service.UserGroupService
}

func (ugc *UserGroupController) list(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
searchTerm := c.Query("search")

groups, pagination, err := ugc.UserGroupService.List(searchTerm, page, pageSize)
if err != nil {
utils.ControllerError(c, err)
return
}

var groupsDto = make([]dto.UserGroupDtoWithUserCount, len(groups))
for i, group := range groups {
var groupDto dto.UserGroupDtoWithUserCount
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
if err != nil {
utils.ControllerError(c, err)
return
}
groupsDto[i] = groupDto
}

c.JSON(http.StatusOK, gin.H{
"data": groupsDto,
"pagination": pagination,
})
}

func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id"))
if err != nil {
utils.ControllerError(c, err)
return
}

var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}

c.JSON(http.StatusOK, groupDto)
}

func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}

group, err := ugc.UserGroupService.Create(input)
if err != nil {
if errors.Is(err, common.ErrNameAlreadyInUse) {
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}

var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}

c.JSON(http.StatusCreated, groupDto)
}

func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}

group, err := ugc.UserGroupService.Update(c.Param("id"), input)
if err != nil {
if errors.Is(err, common.ErrNameAlreadyInUse) {
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}

var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}

c.JSON(http.StatusOK, groupDto)
}

func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
utils.ControllerError(c, err)
return
}

c.Status(http.StatusNoContent)
}

func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}

group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input)
if err != nil {
utils.ControllerError(c, err)
return
}

var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}

c.JSON(http.StatusOK, groupDto)
}
22 changes: 22 additions & 0 deletions backend/internal/dto/dto_mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,37 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Handle direct assignment for simple types
if sourceField.Type() == destField.Type() {
destField.Set(sourceField)

} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
// Direct assignment for slices of primitive types or non-struct elements
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())

for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}

destField.Set(newSlice)

} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
// Recursively map slices of structs
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())

for j := 0; j < sourceField.Len(); j++ {
// Get the element from both source and destination slice
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()

// Recursively map the struct elements
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}

// Set the mapped element in the new slice
newSlice.Index(j).Set(destElem)
}

destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
Expand Down
32 changes: 32 additions & 0 deletions backend/internal/dto/user_group_dto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package dto

import "time"

type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
Users []UserDto `json:"users"`
CreatedAt time.Time `json:"createdAt"`
}

type UserGroupDtoWithUserCount struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
UserCount int64 `json:"userCount"`
CreatedAt time.Time `json:"createdAt"`
}

type UserGroupCreateDto struct {
FriendlyName string `json:"friendlyName" binding:"required,min=3,max=30"`
Name string `json:"name" binding:"required,min=3,max=30,userGroupName"`
}

type UserGroupUpdateUsersDto struct {
UserIDs []string `json:"userIds" binding:"required"`
}

type AssignUserToGroupDto struct {
UserID string `json:"userId" binding:"required"`
}
13 changes: 13 additions & 0 deletions backend/internal/dto/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
return matched
}

var validateUserGroupName validator.Func = func(fl validator.FieldLevel) bool {
// [a-z0-9_] : The group name can only contain lowercase letters, numbers, and underscores
regex := "^[a-z0-9_]+$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
}

func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
Expand All @@ -39,4 +46,10 @@ func init() {
log.Fatalf("Failed to register custom validation: %v", err)
}
}

if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("userGroupName", validateUserGroupName); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
}
1 change: 1 addition & 0 deletions backend/internal/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type User struct {
LastName string
IsAdmin bool

UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
Credentials []WebauthnCredential
}

Expand Down
8 changes: 8 additions & 0 deletions backend/internal/model/user_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package model

type UserGroup struct {
Base
FriendlyName string
Name string `gorm:"unique"`
Users []User `gorm:"many2many:user_groups_users;"`
}
10 changes: 8 additions & 2 deletions backend/internal/service/oidc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,21 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {

func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
return nil, err
}

user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope

userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
}

claims := map[string]interface{}{
"sub": user.ID,
"sub": user.ID,
"groups": userGroups,
}

if strings.Contains(scope, "email") {
Expand Down
Loading

0 comments on commit 24c948e

Please sign in to comment.