diff --git a/backend/.mockery.private.yml b/backend/.mockery.private.yml index f7a9c90f2..fc563382e 100644 --- a/backend/.mockery.private.yml +++ b/backend/.mockery.private.yml @@ -51,6 +51,13 @@ packages: dir: internal/authn structname: '{{.InterfaceName}}Mock' pkgname: authn + + github.com/asgardeo/thunder/internal/role: + config: + all: true + dir: internal/role + structname: '{{.InterfaceName}}Mock' + pkgname: role filename: "{{.InterfaceName}}_mock_test.go" github.com/asgardeo/thunder/internal/flow/flowexec: diff --git a/backend/.mockery.public.yml b/backend/.mockery.public.yml index 33027341b..10d7a85f2 100644 --- a/backend/.mockery.public.yml +++ b/backend/.mockery.public.yml @@ -268,3 +268,11 @@ packages: structname: '{{.InterfaceName}}Mock' pkgname: userschemamock filename: "{{.InterfaceName}}_mock.go" + + github.com/asgardeo/thunder/internal/group: + config: + all: true + dir: tests/mocks/groupmock + structname: '{{.InterfaceName}}Mock' + pkgname: groupmock + filename: "{{.InterfaceName}}_mock.go" diff --git a/backend/cmd/server/servicemanager.go b/backend/cmd/server/servicemanager.go index a798fddcc..3605694a7 100644 --- a/backend/cmd/server/servicemanager.go +++ b/backend/cmd/server/servicemanager.go @@ -31,6 +31,7 @@ import ( "github.com/asgardeo/thunder/internal/notification" "github.com/asgardeo/thunder/internal/oauth" "github.com/asgardeo/thunder/internal/ou" + "github.com/asgardeo/thunder/internal/role" "github.com/asgardeo/thunder/internal/system/jwt" "github.com/asgardeo/thunder/internal/system/log" "github.com/asgardeo/thunder/internal/system/services" @@ -51,7 +52,8 @@ func registerServices(mux *http.ServeMux) { ouService := ou.Initialize(mux) userSchemaService := userschema.Initialize(mux) userService := user.Initialize(mux, ouService, userSchemaService) - _ = group.Initialize(mux, ouService, userService) + groupService := group.Initialize(mux, ouService, userService) + _ = role.Initialize(mux, userService, groupService, ouService) _ = idp.Initialize(mux) _ = notification.Initialize(mux, jwtService) diff --git a/backend/dbscripts/thunderdb/postgres.sql b/backend/dbscripts/thunderdb/postgres.sql index 74ba68d41..be767f05d 100644 --- a/backend/dbscripts/thunderdb/postgres.sql +++ b/backend/dbscripts/thunderdb/postgres.sql @@ -55,6 +55,50 @@ CREATE TABLE GROUP_MEMBER_REFERENCE ( FOREIGN KEY (GROUP_ID) REFERENCES "GROUP" (GROUP_ID) ON DELETE CASCADE ); +-- Table to store Roles +CREATE TABLE "ROLE" ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) UNIQUE NOT NULL, + OU_ID VARCHAR(36) NOT NULL, + NAME VARCHAR(50) NOT NULL, + DESCRIPTION VARCHAR(255), + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + UPDATED_AT TIMESTAMPTZ DEFAULT NOW(), + CONSTRAINT unique_role_ou_name UNIQUE (OU_ID, NAME) +); + +-- Table to store Role permissions +CREATE TABLE ROLE_PERMISSION ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) NOT NULL, + PERMISSION VARCHAR(100) NOT NULL, + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_permission UNIQUE (ROLE_ID, PERMISSION) +); + +-- Table to store Role assignments (to users and groups) +CREATE TABLE ROLE_ASSIGNMENT ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) NOT NULL, + ASSIGNEE_TYPE VARCHAR(5) NOT NULL CHECK (ASSIGNEE_TYPE IN ('user', 'group')), + ASSIGNEE_ID VARCHAR(36) NOT NULL, + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + UPDATED_AT TIMESTAMPTZ DEFAULT NOW(), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_assignment UNIQUE (ROLE_ID, ASSIGNEE_TYPE, ASSIGNEE_ID) +); + +-- Indexes for authorization queries + +-- Index for finding all roles assigned to a specific assignee +CREATE INDEX idx_role_assignment_assignee +ON ROLE_ASSIGNMENT (ASSIGNEE_ID, ASSIGNEE_TYPE); + +-- Index for finding all permissions for a specific role +CREATE INDEX idx_role_permission_role +ON ROLE_PERMISSION (ROLE_ID); + -- Table to store basic service provider (app) details. CREATE TABLE SP_APP ( ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, diff --git a/backend/dbscripts/thunderdb/sqlite.sql b/backend/dbscripts/thunderdb/sqlite.sql index 17aaec3be..51c83fcdb 100644 --- a/backend/dbscripts/thunderdb/sqlite.sql +++ b/backend/dbscripts/thunderdb/sqlite.sql @@ -55,6 +55,50 @@ CREATE TABLE GROUP_MEMBER_REFERENCE ( FOREIGN KEY (GROUP_ID) REFERENCES "GROUP" (GROUP_ID) ON DELETE CASCADE ); +-- Table to store Roles +CREATE TABLE "ROLE" ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) UNIQUE NOT NULL, + OU_ID VARCHAR(36) NOT NULL, + NAME VARCHAR(50) NOT NULL, + DESCRIPTION VARCHAR(255), + CREATED_AT TEXT DEFAULT (datetime('now')), + UPDATED_AT TEXT DEFAULT (datetime('now')), + CONSTRAINT unique_role_ou_name UNIQUE (OU_ID, NAME) +); + +-- Table to store Role permissions +CREATE TABLE ROLE_PERMISSION ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) NOT NULL, + PERMISSION VARCHAR(100) NOT NULL, + CREATED_AT TEXT DEFAULT (datetime('now')), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_permission UNIQUE (ROLE_ID, PERMISSION) +); + +-- Table to store Role assignments (to users and groups) +CREATE TABLE ROLE_ASSIGNMENT ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) NOT NULL, + ASSIGNEE_TYPE VARCHAR(5) NOT NULL CHECK (ASSIGNEE_TYPE IN ('user', 'group')), + ASSIGNEE_ID VARCHAR(36) NOT NULL, + CREATED_AT TEXT DEFAULT (datetime('now')), + UPDATED_AT TEXT DEFAULT (datetime('now')), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_assignment UNIQUE (ROLE_ID, ASSIGNEE_TYPE, ASSIGNEE_ID) +); + +-- Indexes for authorization queries + +-- Index for finding all roles assigned to a specific assignee +CREATE INDEX idx_role_assignment_assignee +ON ROLE_ASSIGNMENT (ASSIGNEE_ID, ASSIGNEE_TYPE); + +-- Index for finding all permissions for a specific role +CREATE INDEX idx_role_permission_role +ON ROLE_PERMISSION (ROLE_ID); + -- Table to store basic service provider (app) details. CREATE TABLE SP_APP ( ID INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/backend/internal/group/service.go b/backend/internal/group/service.go index 41f19b6b8..73d119aa1 100644 --- a/backend/internal/group/service.go +++ b/backend/internal/group/service.go @@ -44,6 +44,7 @@ type GroupServiceInterface interface { UpdateGroup(groupID string, request UpdateGroupRequest) (*Group, *serviceerror.ServiceError) DeleteGroup(groupID string) *serviceerror.ServiceError GetGroupMembers(groupID string, limit, offset int) (*MemberListResponse, *serviceerror.ServiceError) + ValidateGroupIDs(groupIDs []string) *serviceerror.ServiceError } // groupService is the default implementation of the GroupServiceInterface. @@ -183,7 +184,7 @@ func (gs *groupService) CreateGroup(request CreateGroupRequest) (*Group, *servic return nil, err } - if err := gs.validateGroupIDs(groupIDs); err != nil { + if err := gs.ValidateGroupIDs(groupIDs); err != nil { return nil, err } @@ -318,7 +319,7 @@ func (gs *groupService) UpdateGroup( return nil, err } - if err := gs.validateGroupIDs(groupIDs); err != nil { + if err := gs.ValidateGroupIDs(groupIDs); err != nil { return nil, err } @@ -514,8 +515,8 @@ func (gs *groupService) validateUserIDs(userIDs []string) *serviceerror.ServiceE return nil } -// validateGroupIDs validates that all provided group IDs exist. -func (gs *groupService) validateGroupIDs(groupIDs []string) *serviceerror.ServiceError { +// ValidateGroupIDs validates that all provided group IDs exist. +func (gs *groupService) ValidateGroupIDs(groupIDs []string) *serviceerror.ServiceError { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) invalidGroupIDs, err := gs.groupStore.ValidateGroupIDs(groupIDs) diff --git a/backend/internal/role/RoleServiceInterface_mock_test.go b/backend/internal/role/RoleServiceInterface_mock_test.go new file mode 100644 index 000000000..c7d8d88da --- /dev/null +++ b/backend/internal/role/RoleServiceInterface_mock_test.go @@ -0,0 +1,634 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package role + +import ( + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + mock "github.com/stretchr/testify/mock" +) + +// NewRoleServiceInterfaceMock creates a new instance of RoleServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRoleServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *RoleServiceInterfaceMock { + mock := &RoleServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// RoleServiceInterfaceMock is an autogenerated mock type for the RoleServiceInterface type +type RoleServiceInterfaceMock struct { + mock.Mock +} + +type RoleServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *RoleServiceInterfaceMock) EXPECT() *RoleServiceInterfaceMock_Expecter { + return &RoleServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// AddAssignments provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) AddAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError { + ret := _mock.Called(id, assignments) + + if len(ret) == 0 { + panic("no return value specified for AddAssignments") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, []RoleAssignment) *serviceerror.ServiceError); ok { + r0 = returnFunc(id, assignments) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// RoleServiceInterfaceMock_AddAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddAssignments' +type RoleServiceInterfaceMock_AddAssignments_Call struct { + *mock.Call +} + +// AddAssignments is a helper method to define mock.On call +// - id string +// - assignments []RoleAssignment +func (_e *RoleServiceInterfaceMock_Expecter) AddAssignments(id interface{}, assignments interface{}) *RoleServiceInterfaceMock_AddAssignments_Call { + return &RoleServiceInterfaceMock_AddAssignments_Call{Call: _e.mock.On("AddAssignments", id, assignments)} +} + +func (_c *RoleServiceInterfaceMock_AddAssignments_Call) Run(run func(id string, assignments []RoleAssignment)) *RoleServiceInterfaceMock_AddAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []RoleAssignment + if args[1] != nil { + arg1 = args[1].([]RoleAssignment) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_AddAssignments_Call) Return(serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_AddAssignments_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_AddAssignments_Call) RunAndReturn(run func(id string, assignments []RoleAssignment) *serviceerror.ServiceError) *RoleServiceInterfaceMock_AddAssignments_Call { + _c.Call.Return(run) + return _c +} + +// CreateRole provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) CreateRole(role RoleCreationDetail) (*RoleWithPermissionsAndAssignments, *serviceerror.ServiceError) { + ret := _mock.Called(role) + + if len(ret) == 0 { + panic("no return value specified for CreateRole") + } + + var r0 *RoleWithPermissionsAndAssignments + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(RoleCreationDetail) (*RoleWithPermissionsAndAssignments, *serviceerror.ServiceError)); ok { + return returnFunc(role) + } + if returnFunc, ok := ret.Get(0).(func(RoleCreationDetail) *RoleWithPermissionsAndAssignments); ok { + r0 = returnFunc(role) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*RoleWithPermissionsAndAssignments) + } + } + if returnFunc, ok := ret.Get(1).(func(RoleCreationDetail) *serviceerror.ServiceError); ok { + r1 = returnFunc(role) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_CreateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRole' +type RoleServiceInterfaceMock_CreateRole_Call struct { + *mock.Call +} + +// CreateRole is a helper method to define mock.On call +// - role RoleCreationDetail +func (_e *RoleServiceInterfaceMock_Expecter) CreateRole(role interface{}) *RoleServiceInterfaceMock_CreateRole_Call { + return &RoleServiceInterfaceMock_CreateRole_Call{Call: _e.mock.On("CreateRole", role)} +} + +func (_c *RoleServiceInterfaceMock_CreateRole_Call) Run(run func(role RoleCreationDetail)) *RoleServiceInterfaceMock_CreateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 RoleCreationDetail + if args[0] != nil { + arg0 = args[0].(RoleCreationDetail) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_CreateRole_Call) Return(roleWithPermissionsAndAssignments *RoleWithPermissionsAndAssignments, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_CreateRole_Call { + _c.Call.Return(roleWithPermissionsAndAssignments, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_CreateRole_Call) RunAndReturn(run func(role RoleCreationDetail) (*RoleWithPermissionsAndAssignments, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_CreateRole_Call { + _c.Call.Return(run) + return _c +} + +// DeleteRole provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) DeleteRole(id string) *serviceerror.ServiceError { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for DeleteRole") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) *serviceerror.ServiceError); ok { + r0 = returnFunc(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// RoleServiceInterfaceMock_DeleteRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteRole' +type RoleServiceInterfaceMock_DeleteRole_Call struct { + *mock.Call +} + +// DeleteRole is a helper method to define mock.On call +// - id string +func (_e *RoleServiceInterfaceMock_Expecter) DeleteRole(id interface{}) *RoleServiceInterfaceMock_DeleteRole_Call { + return &RoleServiceInterfaceMock_DeleteRole_Call{Call: _e.mock.On("DeleteRole", id)} +} + +func (_c *RoleServiceInterfaceMock_DeleteRole_Call) Run(run func(id string)) *RoleServiceInterfaceMock_DeleteRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_DeleteRole_Call) Return(serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_DeleteRole_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_DeleteRole_Call) RunAndReturn(run func(id string) *serviceerror.ServiceError) *RoleServiceInterfaceMock_DeleteRole_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthorizedPermissions provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) GetAuthorizedPermissions(userID string, groups []string, requestedPermissions []string) ([]string, *serviceerror.ServiceError) { + ret := _mock.Called(userID, groups, requestedPermissions) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizedPermissions") + } + + var r0 []string + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, []string, []string) ([]string, *serviceerror.ServiceError)); ok { + return returnFunc(userID, groups, requestedPermissions) + } + if returnFunc, ok := ret.Get(0).(func(string, []string, []string) []string); ok { + r0 = returnFunc(userID, groups, requestedPermissions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(string, []string, []string) *serviceerror.ServiceError); ok { + r1 = returnFunc(userID, groups, requestedPermissions) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_GetAuthorizedPermissions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizedPermissions' +type RoleServiceInterfaceMock_GetAuthorizedPermissions_Call struct { + *mock.Call +} + +// GetAuthorizedPermissions is a helper method to define mock.On call +// - userID string +// - groups []string +// - requestedPermissions []string +func (_e *RoleServiceInterfaceMock_Expecter) GetAuthorizedPermissions(userID interface{}, groups interface{}, requestedPermissions interface{}) *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call { + return &RoleServiceInterfaceMock_GetAuthorizedPermissions_Call{Call: _e.mock.On("GetAuthorizedPermissions", userID, groups, requestedPermissions)} +} + +func (_c *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call) Run(run func(userID string, groups []string, requestedPermissions []string)) *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call) Return(strings []string, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Return(strings, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call) RunAndReturn(run func(userID string, groups []string, requestedPermissions []string) ([]string, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleAssignments provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) GetRoleAssignments(id string, limit int, offset int, includeDisplay bool) (*AssignmentList, *serviceerror.ServiceError) { + ret := _mock.Called(id, limit, offset, includeDisplay) + + if len(ret) == 0 { + panic("no return value specified for GetRoleAssignments") + } + + var r0 *AssignmentList + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, int, int, bool) (*AssignmentList, *serviceerror.ServiceError)); ok { + return returnFunc(id, limit, offset, includeDisplay) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int, bool) *AssignmentList); ok { + r0 = returnFunc(id, limit, offset, includeDisplay) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*AssignmentList) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int, bool) *serviceerror.ServiceError); ok { + r1 = returnFunc(id, limit, offset, includeDisplay) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_GetRoleAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleAssignments' +type RoleServiceInterfaceMock_GetRoleAssignments_Call struct { + *mock.Call +} + +// GetRoleAssignments is a helper method to define mock.On call +// - id string +// - limit int +// - offset int +// - includeDisplay bool +func (_e *RoleServiceInterfaceMock_Expecter) GetRoleAssignments(id interface{}, limit interface{}, offset interface{}, includeDisplay interface{}) *RoleServiceInterfaceMock_GetRoleAssignments_Call { + return &RoleServiceInterfaceMock_GetRoleAssignments_Call{Call: _e.mock.On("GetRoleAssignments", id, limit, offset, includeDisplay)} +} + +func (_c *RoleServiceInterfaceMock_GetRoleAssignments_Call) Run(run func(id string, limit int, offset int, includeDisplay bool)) *RoleServiceInterfaceMock_GetRoleAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleAssignments_Call) Return(assignmentList *AssignmentList, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_GetRoleAssignments_Call { + _c.Call.Return(assignmentList, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleAssignments_Call) RunAndReturn(run func(id string, limit int, offset int, includeDisplay bool) (*AssignmentList, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_GetRoleAssignments_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleList provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) GetRoleList(limit int, offset int) (*RoleList, *serviceerror.ServiceError) { + ret := _mock.Called(limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetRoleList") + } + + var r0 *RoleList + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(int, int) (*RoleList, *serviceerror.ServiceError)); ok { + return returnFunc(limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(int, int) *RoleList); ok { + r0 = returnFunc(limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*RoleList) + } + } + if returnFunc, ok := ret.Get(1).(func(int, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(limit, offset) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_GetRoleList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleList' +type RoleServiceInterfaceMock_GetRoleList_Call struct { + *mock.Call +} + +// GetRoleList is a helper method to define mock.On call +// - limit int +// - offset int +func (_e *RoleServiceInterfaceMock_Expecter) GetRoleList(limit interface{}, offset interface{}) *RoleServiceInterfaceMock_GetRoleList_Call { + return &RoleServiceInterfaceMock_GetRoleList_Call{Call: _e.mock.On("GetRoleList", limit, offset)} +} + +func (_c *RoleServiceInterfaceMock_GetRoleList_Call) Run(run func(limit int, offset int)) *RoleServiceInterfaceMock_GetRoleList_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleList_Call) Return(roleList *RoleList, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_GetRoleList_Call { + _c.Call.Return(roleList, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleList_Call) RunAndReturn(run func(limit int, offset int) (*RoleList, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_GetRoleList_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleWithPermissions provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) GetRoleWithPermissions(id string) (*RoleWithPermissions, *serviceerror.ServiceError) { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for GetRoleWithPermissions") + } + + var r0 *RoleWithPermissions + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) (*RoleWithPermissions, *serviceerror.ServiceError)); ok { + return returnFunc(id) + } + if returnFunc, ok := ret.Get(0).(func(string) *RoleWithPermissions); ok { + r0 = returnFunc(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*RoleWithPermissions) + } + } + if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { + r1 = returnFunc(id) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_GetRoleWithPermissions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleWithPermissions' +type RoleServiceInterfaceMock_GetRoleWithPermissions_Call struct { + *mock.Call +} + +// GetRoleWithPermissions is a helper method to define mock.On call +// - id string +func (_e *RoleServiceInterfaceMock_Expecter) GetRoleWithPermissions(id interface{}) *RoleServiceInterfaceMock_GetRoleWithPermissions_Call { + return &RoleServiceInterfaceMock_GetRoleWithPermissions_Call{Call: _e.mock.On("GetRoleWithPermissions", id)} +} + +func (_c *RoleServiceInterfaceMock_GetRoleWithPermissions_Call) Run(run func(id string)) *RoleServiceInterfaceMock_GetRoleWithPermissions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleWithPermissions_Call) Return(roleWithPermissions *RoleWithPermissions, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_GetRoleWithPermissions_Call { + _c.Call.Return(roleWithPermissions, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_GetRoleWithPermissions_Call) RunAndReturn(run func(id string) (*RoleWithPermissions, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_GetRoleWithPermissions_Call { + _c.Call.Return(run) + return _c +} + +// RemoveAssignments provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) RemoveAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError { + ret := _mock.Called(id, assignments) + + if len(ret) == 0 { + panic("no return value specified for RemoveAssignments") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, []RoleAssignment) *serviceerror.ServiceError); ok { + r0 = returnFunc(id, assignments) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// RoleServiceInterfaceMock_RemoveAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveAssignments' +type RoleServiceInterfaceMock_RemoveAssignments_Call struct { + *mock.Call +} + +// RemoveAssignments is a helper method to define mock.On call +// - id string +// - assignments []RoleAssignment +func (_e *RoleServiceInterfaceMock_Expecter) RemoveAssignments(id interface{}, assignments interface{}) *RoleServiceInterfaceMock_RemoveAssignments_Call { + return &RoleServiceInterfaceMock_RemoveAssignments_Call{Call: _e.mock.On("RemoveAssignments", id, assignments)} +} + +func (_c *RoleServiceInterfaceMock_RemoveAssignments_Call) Run(run func(id string, assignments []RoleAssignment)) *RoleServiceInterfaceMock_RemoveAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []RoleAssignment + if args[1] != nil { + arg1 = args[1].([]RoleAssignment) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_RemoveAssignments_Call) Return(serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_RemoveAssignments_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_RemoveAssignments_Call) RunAndReturn(run func(id string, assignments []RoleAssignment) *serviceerror.ServiceError) *RoleServiceInterfaceMock_RemoveAssignments_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRoleWithPermissions provides a mock function for the type RoleServiceInterfaceMock +func (_mock *RoleServiceInterfaceMock) UpdateRoleWithPermissions(id string, role RoleUpdateDetail) (*RoleWithPermissions, *serviceerror.ServiceError) { + ret := _mock.Called(id, role) + + if len(ret) == 0 { + panic("no return value specified for UpdateRoleWithPermissions") + } + + var r0 *RoleWithPermissions + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, RoleUpdateDetail) (*RoleWithPermissions, *serviceerror.ServiceError)); ok { + return returnFunc(id, role) + } + if returnFunc, ok := ret.Get(0).(func(string, RoleUpdateDetail) *RoleWithPermissions); ok { + r0 = returnFunc(id, role) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*RoleWithPermissions) + } + } + if returnFunc, ok := ret.Get(1).(func(string, RoleUpdateDetail) *serviceerror.ServiceError); ok { + r1 = returnFunc(id, role) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRoleWithPermissions' +type RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call struct { + *mock.Call +} + +// UpdateRoleWithPermissions is a helper method to define mock.On call +// - id string +// - role RoleUpdateDetail +func (_e *RoleServiceInterfaceMock_Expecter) UpdateRoleWithPermissions(id interface{}, role interface{}) *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call { + return &RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call{Call: _e.mock.On("UpdateRoleWithPermissions", id, role)} +} + +func (_c *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call) Run(run func(id string, role RoleUpdateDetail)) *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 RoleUpdateDetail + if args[1] != nil { + arg1 = args[1].(RoleUpdateDetail) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call) Return(roleWithPermissions *RoleWithPermissions, serviceError *serviceerror.ServiceError) *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call { + _c.Call.Return(roleWithPermissions, serviceError) + return _c +} + +func (_c *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call) RunAndReturn(run func(id string, role RoleUpdateDetail) (*RoleWithPermissions, *serviceerror.ServiceError)) *RoleServiceInterfaceMock_UpdateRoleWithPermissions_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/role/errorconstants.go b/backend/internal/role/errorconstants.go new file mode 100644 index 000000000..32000d364 --- /dev/null +++ b/backend/internal/role/errorconstants.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "errors" + + "github.com/asgardeo/thunder/internal/system/error/serviceerror" +) + +// Client errors for role management operations. +var ( + // ErrorInvalidRequestFormat is the error returned when the request format is invalid. + ErrorInvalidRequestFormat = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1001", + Error: "Invalid request format", + ErrorDescription: "The request body is malformed or contains invalid data", + } + // ErrorMissingRoleID is the error returned when role ID is missing. + ErrorMissingRoleID = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1002", + Error: "Invalid request format", + ErrorDescription: "Role ID is required", + } + // ErrorRoleNotFound is the error returned when a role is not found. + ErrorRoleNotFound = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1003", + Error: "Role not found", + ErrorDescription: "The role with the specified id does not exist", + } + // ErrorRoleNameConflict is the error returned when a role name already exists in the organization unit. + ErrorRoleNameConflict = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1004", + Error: "Role name conflict", + ErrorDescription: "A role with the same name exists under the same organization unit", + } + // ErrorOrganizationUnitNotFound is the error returned when organization unit is not found. + ErrorOrganizationUnitNotFound = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1005", + Error: "Organization unit not found", + ErrorDescription: "Organization unit not found", + } + // ErrorCannotDeleteRole is the error returned when role cannot be deleted. + ErrorCannotDeleteRole = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1006", + Error: "Cannot delete role", + ErrorDescription: "Cannot delete role that is currently assigned to users or groups", + } + // ErrorInvalidAssignmentID is the error returned when assignment ID is invalid. + ErrorInvalidAssignmentID = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1007", + Error: "Invalid assignment ID", + ErrorDescription: "One or more assignment IDs in the request do not exist", + } + // ErrorInvalidLimit is the error returned when limit parameter is invalid. + ErrorInvalidLimit = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1008", + Error: "Invalid limit parameter", + ErrorDescription: "The limit parameter must be a positive integer", + } + // ErrorInvalidOffset is the error returned when offset parameter is invalid. + ErrorInvalidOffset = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1009", + Error: "Invalid offset parameter", + ErrorDescription: "The offset parameter must be a non-negative integer", + } + // ErrorEmptyAssignments is the error returned when assignments list is empty. + ErrorEmptyAssignments = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1010", + Error: "Empty assignments list", + ErrorDescription: "At least one assignment must be provided", + } + // ErrorMissingUserOrGroups is the error returned when both user ID and groups are missing. + ErrorMissingUserOrGroups = serviceerror.ServiceError{ + Type: serviceerror.ClientErrorType, + Code: "ROL-1011", + Error: "Invalid request format", + ErrorDescription: "Either userId or groups must be provided for authorization check", + } +) + +// Server errors for role management operations. +var ( + // ErrorInternalServerError is the error returned when an internal server error occurs. + ErrorInternalServerError = serviceerror.ServiceError{ + Type: serviceerror.ServerErrorType, + Code: "ROL-5000", + Error: "Internal server error", + ErrorDescription: "An unexpected error occurred while processing the request", + } +) + +// Internal error constants for role management operations. +var ( + // ErrRoleNotFound is returned when the role is not found in the system. + ErrRoleNotFound = errors.New("role not found") + + // ErrRoleHasAssignments is returned when attempting to delete a role that has active assignments. + ErrRoleHasAssignments = errors.New("role has active assignments") + + // ErrRoleNameConflict is returned when a role with the same name already exists in the organization unit. + ErrRoleNameConflict = errors.New("role name conflict") +) diff --git a/backend/internal/role/handler.go b/backend/internal/role/handler.go new file mode 100644 index 000000000..1001e987f --- /dev/null +++ b/backend/internal/role/handler.go @@ -0,0 +1,513 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + + serverconst "github.com/asgardeo/thunder/internal/system/constants" + "github.com/asgardeo/thunder/internal/system/error/apierror" + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + "github.com/asgardeo/thunder/internal/system/log" + sysutils "github.com/asgardeo/thunder/internal/system/utils" +) + +const handlerLoggerComponentName = "RoleHandler" + +// roleHandler is the handler for role management operations. +type roleHandler struct { + roleService RoleServiceInterface +} + +// newRoleHandler creates a new instance of roleHandler +func newRoleHandler(roleService RoleServiceInterface) *roleHandler { + return &roleHandler{ + roleService: roleService, + } +} + +// HandleRoleListRequest handles the list roles request. +func (rh *roleHandler) HandleRoleListRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + limit, offset, svcErr := parsePaginationParams(r.URL.Query()) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + roleList, svcErr := rh.roleService.GetRoleList(limit, offset) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Convert service response to HTTP response + roles := make([]RoleSummaryResponse, 0, len(roleList.Roles)) + for _, role := range roleList.Roles { + roles = append(roles, RoleSummaryResponse(role)) + } + + roleListResponse := &RoleListResponse{ + TotalResults: roleList.TotalResults, + StartIndex: roleList.StartIndex, + Count: roleList.Count, + Roles: roles, + Links: toHTTPLinks(roleList.Links), + } + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + + isErr := writeToResponse(w, roleListResponse, logger) + if isErr { + return + } + + logger.Debug("Successfully listed roles with pagination", + log.Int("limit", limit), log.Int("offset", offset), + log.Int("totalResults", roleListResponse.TotalResults), + log.Int("count", roleListResponse.Count)) +} + +// HandleRolePostRequest handles the create role request. +func (rh *roleHandler) HandleRolePostRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + createRequest, err := sysutils.DecodeJSONBody[CreateRoleRequest](r) + if err != nil { + handleError(w, logger, &ErrorInvalidRequestFormat) + return + } + + sanitizedRequest := rh.sanitizeCreateRoleRequest(createRequest) + + // Convert HTTP request to service request + serviceRequest := rh.toRoleCreationDetail(sanitizedRequest) + + serviceRole, svcErr := rh.roleService.CreateRole(serviceRequest) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Convert service response to HTTP response + createdRole := rh.toHTTPCreateRoleResponse(serviceRole) + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusCreated) + + isErr := writeToResponse(w, createdRole, logger) + if isErr { + return + } + + logger.Debug("Successfully created role", log.String("roleId", createdRole.ID)) +} + +// HandleRoleGetRequest handles the get role by id request. +func (rh *roleHandler) HandleRoleGetRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + serviceRole, svcErr := rh.roleService.GetRoleWithPermissions(id) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Convert service response to HTTP response + role := rh.toHTTPRoleResponse(serviceRole) + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + + isErr := writeToResponse(w, role, logger) + if isErr { + return + } + + logger.Debug("Successfully retrieved role", log.String("role id", id)) +} + +// HandleRolePutRequest handles the update role request. +func (rh *roleHandler) HandleRolePutRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + updateRequest, err := sysutils.DecodeJSONBody[UpdateRoleRequest](r) + if err != nil { + handleError(w, logger, &ErrorInvalidRequestFormat) + return + } + + sanitizedRequest := rh.sanitizeUpdateRoleRequest(updateRequest) + + // Convert HTTP request to service request + serviceRequest := RoleUpdateDetail(sanitizedRequest) + + serviceRole, svcErr := rh.roleService.UpdateRoleWithPermissions(id, serviceRequest) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Convert service response to HTTP response + role := rh.toHTTPRoleResponse(serviceRole) + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + + isErr := writeToResponse(w, role, logger) + if isErr { + return + } + + logger.Debug("Successfully updated role", log.String("role id", id)) +} + +// HandleRoleDeleteRequest handles the delete role request. +func (rh *roleHandler) HandleRoleDeleteRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + svcErr := rh.roleService.DeleteRole(id) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + w.WriteHeader(http.StatusNoContent) + logger.Debug("Successfully deleted role", log.String("role id", id)) +} + +// HandleRoleAssignmentsGetRequest handles the get role assignments request. +func (rh *roleHandler) HandleRoleAssignmentsGetRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + limit, offset, svcErr := parsePaginationParams(r.URL.Query()) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Parse include parameter to check if display names should be included + includeDisplay := r.URL.Query().Get("include") == "display" + + serviceResponse, svcErr := rh.roleService.GetRoleAssignments(id, limit, offset, includeDisplay) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + // Convert service response to HTTP response + httpAssignments := make([]AssignmentResponse, len(serviceResponse.Assignments)) + for i, sa := range serviceResponse.Assignments { + httpAssignments[i] = AssignmentResponse(sa) + } + + assignmentListResponse := &AssignmentListResponse{ + TotalResults: serviceResponse.TotalResults, + StartIndex: serviceResponse.StartIndex, + Count: serviceResponse.Count, + Assignments: httpAssignments, + Links: toHTTPLinks(serviceResponse.Links), + } + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + + isErr := writeToResponse(w, assignmentListResponse, logger) + if isErr { + return + } + + logger.Debug("Successfully retrieved role assignments", log.String("role id", id), + log.Int("limit", limit), log.Int("offset", offset), + log.Bool("includeDisplay", includeDisplay), + log.Int("totalResults", assignmentListResponse.TotalResults), + log.Int("count", assignmentListResponse.Count)) +} + +// HandleRoleAddAssignmentsRequest handles the add assignments to role request. +func (rh *roleHandler) HandleRoleAddAssignmentsRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + assignmentsRequest, err := sysutils.DecodeJSONBody[AssignmentsRequest](r) + if err != nil { + handleError(w, logger, &ErrorInvalidRequestFormat) + return + } + + sanitizedRequest := rh.sanitizeAssignmentsRequest(assignmentsRequest) + + // Convert HTTP request to service request + serviceRequest := rh.toRoleAssignments(sanitizedRequest) + + svcErr := rh.roleService.AddAssignments(id, serviceRequest) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + w.WriteHeader(http.StatusNoContent) + logger.Debug("Successfully added assignments to role", log.String("role id", id)) +} + +// HandleRoleRemoveAssignmentsRequest handles the remove assignments from role request. +func (rh *roleHandler) HandleRoleRemoveAssignmentsRequest(w http.ResponseWriter, r *http.Request) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, handlerLoggerComponentName)) + + id := r.PathValue("id") + assignmentsRequest, err := sysutils.DecodeJSONBody[AssignmentsRequest](r) + if err != nil { + handleError(w, logger, &ErrorInvalidRequestFormat) + return + } + + sanitizedRequest := rh.sanitizeAssignmentsRequest(assignmentsRequest) + + // Convert HTTP request to service request + serviceRequest := rh.toRoleAssignments(sanitizedRequest) + + svcErr := rh.roleService.RemoveAssignments(id, serviceRequest) + if svcErr != nil { + handleError(w, logger, svcErr) + return + } + + w.WriteHeader(http.StatusNoContent) + logger.Debug("Successfully removed assignments from role", log.String("role id", id)) +} + +// writeToResponse encodes the response as JSON and writes it to the ResponseWriter. +func writeToResponse(w http.ResponseWriter, response any, logger *log.Logger) bool { + if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Error encoding response", log.Error(err)) + handleEncodingError(w) + return true + } + return false +} + +// handleError handles service errors and returns appropriate HTTP responses. +func handleError(w http.ResponseWriter, logger *log.Logger, + svcErr *serviceerror.ServiceError) { + statusCode := http.StatusInternalServerError + if svcErr.Type == serviceerror.ClientErrorType { + switch svcErr.Code { + case ErrorRoleNotFound.Code: + statusCode = http.StatusNotFound + case ErrorRoleNameConflict.Code: + statusCode = http.StatusConflict + case ErrorOrganizationUnitNotFound.Code, ErrorCannotDeleteRole.Code, + ErrorInvalidRequestFormat.Code, ErrorMissingRoleID.Code, + ErrorInvalidLimit.Code, ErrorInvalidOffset.Code, + ErrorEmptyAssignments.Code, + ErrorInvalidAssignmentID.Code: + statusCode = http.StatusBadRequest + default: + statusCode = http.StatusBadRequest + } + } + + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(statusCode) + + errResp := apierror.ErrorResponse{ + Code: svcErr.Code, + Message: svcErr.Error, + Description: svcErr.ErrorDescription, + } + + if err := json.NewEncoder(w).Encode(errResp); err != nil { + logger.Error("Error encoding error response", log.Error(err)) + handleEncodingError(w) + return + } +} + +// handleEncodingError handles errors that occur during response encoding. +func handleEncodingError(w http.ResponseWriter) { + w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, serviceerror.ErrorEncodingError) +} + +// sanitizeCreateRoleRequest sanitizes the create role request input. +func (rh *roleHandler) sanitizeCreateRoleRequest(request *CreateRoleRequest) CreateRoleRequest { + sanitized := CreateRoleRequest{ + Name: sysutils.SanitizeString(request.Name), + Description: sysutils.SanitizeString(request.Description), + OrganizationUnitID: sysutils.SanitizeString(request.OrganizationUnitID), + } + + if request.Permissions != nil { + sanitized.Permissions = make([]string, len(request.Permissions)) + for i, permission := range request.Permissions { + sanitized.Permissions[i] = sysutils.SanitizeString(permission) + } + } + + if request.Assignments != nil { + sanitized.Assignments = make([]AssignmentRequest, len(request.Assignments)) + for i, assignment := range request.Assignments { + sanitized.Assignments[i] = AssignmentRequest{ + ID: sysutils.SanitizeString(assignment.ID), + Type: assignment.Type, + } + } + } + + return sanitized +} + +// sanitizeUpdateRoleRequest sanitizes the update role request input. +func (rh *roleHandler) sanitizeUpdateRoleRequest(request *UpdateRoleRequest) UpdateRoleRequest { + sanitized := UpdateRoleRequest{ + Name: sysutils.SanitizeString(request.Name), + Description: sysutils.SanitizeString(request.Description), + OrganizationUnitID: sysutils.SanitizeString(request.OrganizationUnitID), + } + + if request.Permissions != nil { + sanitized.Permissions = make([]string, len(request.Permissions)) + for i, permission := range request.Permissions { + sanitized.Permissions[i] = sysutils.SanitizeString(permission) + } + } + + return sanitized +} + +// sanitizeAssignmentsRequest sanitizes the assignments request input. +func (rh *roleHandler) sanitizeAssignmentsRequest(request *AssignmentsRequest) AssignmentsRequest { + sanitized := AssignmentsRequest{} + + if request.Assignments != nil { + sanitized.Assignments = make([]AssignmentRequest, len(request.Assignments)) + for i, assignment := range request.Assignments { + sanitized.Assignments[i] = AssignmentRequest{ + ID: sysutils.SanitizeString(assignment.ID), + Type: assignment.Type, + } + } + } + + return sanitized +} + +// parsePaginationParams parses limit and offset query parameters from the request. +func parsePaginationParams(query url.Values) (int, int, *serviceerror.ServiceError) { + limit := 0 + offset := 0 + + if limitStr := query.Get("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err != nil { + return 0, 0, &ErrorInvalidLimit + } else { + limit = parsedLimit + } + } + + if offsetStr := query.Get("offset"); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err != nil { + return 0, 0, &ErrorInvalidOffset + } else { + offset = parsedOffset + } + } + + if limit == 0 { + limit = serverconst.DefaultPageSize + } + + return limit, offset, nil +} + +// toRoleCreationDetail converts HTTP CreateRoleRequest to service layer RoleCreationDetail. +func (rh *roleHandler) toRoleCreationDetail(req CreateRoleRequest) RoleCreationDetail { + serviceAssignments := make([]RoleAssignment, len(req.Assignments)) + for i, a := range req.Assignments { + serviceAssignments[i] = RoleAssignment(a) + } + + return RoleCreationDetail{ + Name: req.Name, + Description: req.Description, + OrganizationUnitID: req.OrganizationUnitID, + Permissions: req.Permissions, + Assignments: serviceAssignments, + } +} + +// toHTTPRole converts service layer RoleWithPermissions to HTTP Role. +func (rh *roleHandler) toHTTPRoleResponse(role *RoleWithPermissions) *RoleResponse { + return &RoleResponse{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + OrganizationUnitID: role.OrganizationUnitID, + Permissions: role.Permissions, + } +} + +// toHTTPCreateRoleResponse converts service layer RoleDetails to HTTP CreateRoleResponse. +func (rh *roleHandler) toHTTPCreateRoleResponse(role *RoleWithPermissionsAndAssignments) *CreateRoleResponse { + httpAssignments := make([]AssignmentResponse, len(role.Assignments)) + for i, sa := range role.Assignments { + httpAssignments[i] = AssignmentResponse{ + ID: sa.ID, + Type: sa.Type, + } + } + + return &CreateRoleResponse{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + OrganizationUnitID: role.OrganizationUnitID, + Permissions: role.Permissions, + Assignments: httpAssignments, + } +} + +// toHTTPLinks converts service layer Links to HTTP LinkResponse. +func toHTTPLinks(links []Link) []LinkResponse { + httpLinks := make([]LinkResponse, len(links)) + for i, link := range links { + httpLinks[i] = LinkResponse(link) + } + return httpLinks +} + +// toRoleAssignments converts HTTP AssignmentsRequest to service layer RoleAssignments. +func (rh *roleHandler) toRoleAssignments(req AssignmentsRequest) []RoleAssignment { + serviceAssignments := make([]RoleAssignment, len(req.Assignments)) + for i, a := range req.Assignments { + serviceAssignments[i] = RoleAssignment(a) + } + return serviceAssignments +} diff --git a/backend/internal/role/handler_test.go b/backend/internal/role/handler_test.go new file mode 100644 index 000000000..10975f983 --- /dev/null +++ b/backend/internal/role/handler_test.go @@ -0,0 +1,743 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/asgardeo/thunder/internal/system/log" +) + +type RoleHandlerTestSuite struct { + suite.Suite + mockService *RoleServiceInterfaceMock + handler *roleHandler +} + +func TestRoleHandlerTestSuite(t *testing.T) { + suite.Run(t, new(RoleHandlerTestSuite)) +} + +func (suite *RoleHandlerTestSuite) SetupTest() { + suite.mockService = NewRoleServiceInterfaceMock(suite.T()) + suite.handler = newRoleHandler(suite.mockService) +} + +// HandleRoleListRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleListRequest_Success() { + expectedResponse := &RoleList{ + TotalResults: 2, + StartIndex: 1, + Count: 2, + Roles: []Role{ + {ID: "role1", Name: "Admin"}, + {ID: "role2", Name: "User"}, + }, + Links: []Link{}, + } + + suite.mockService.On("GetRoleList", 10, 0).Return(expectedResponse, nil) + + req := httptest.NewRequest(http.MethodGet, "/roles?limit=10&offset=0", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleListRequest(w, req) + + suite.Equal(http.StatusOK, w.Code) + + var response RoleListResponse + err := json.NewDecoder(w.Body).Decode(&response) + suite.NoError(err) + suite.Equal(2, response.TotalResults) + suite.Equal(2, len(response.Roles)) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleListRequest_DefaultPagination() { + expectedResponse := &RoleList{ + TotalResults: 1, + StartIndex: 1, + Count: 1, + Roles: []Role{{ID: "role1", Name: "Admin"}}, + Links: []Link{}, + } + + suite.mockService.On("GetRoleList", 30, 0).Return(expectedResponse, nil) + + req := httptest.NewRequest(http.MethodGet, "/roles", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleListRequest(w, req) + + suite.Equal(http.StatusOK, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleListRequest_ServiceError() { + suite.mockService.On("GetRoleList", 10, 0).Return(nil, &ErrorInvalidLimit) + + req := httptest.NewRequest(http.MethodGet, "/roles?limit=10&offset=0", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleListRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRolePostRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRolePostRequest_Success() { + request := CreateRoleRequest{ + Name: "Test Role", + Description: "Description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + expectedRole := &RoleWithPermissionsAndAssignments{ + ID: "role1", + Name: "Test Role", + Description: "Description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockService.On("CreateRole", mock.AnythingOfType("RoleCreationDetail")).Return(expectedRole, nil) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRolePostRequest(w, req) + + suite.Equal(http.StatusCreated, w.Code) + + var response CreateRoleResponse + err := json.NewDecoder(w.Body).Decode(&response) + suite.NoError(err) + suite.Equal("role1", response.ID) + suite.Equal("Test Role", response.Name) +} + +func (suite *RoleHandlerTestSuite) TestHandleRolePostRequest_InvalidJSON() { + req := httptest.NewRequest(http.MethodPost, "/roles", bytes.NewBufferString("invalid json")) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRolePostRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRolePostRequest_ServiceError() { + request := CreateRoleRequest{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockService.On("CreateRole", mock.AnythingOfType("RoleCreationDetail")). + Return(nil, &ErrorOrganizationUnitNotFound) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRolePostRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRoleGetRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleGetRequest_Success() { + expectedRole := &RoleWithPermissions{ + ID: "role1", + Name: "Admin", + Description: "Admin role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockService.On("GetRoleWithPermissions", "role1").Return(expectedRole, nil) + + req := httptest.NewRequest(http.MethodGet, "/roles/role1", nil) + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleGetRequest(w, req) + + suite.Equal(http.StatusOK, w.Code) + + var response RoleResponse + err := json.NewDecoder(w.Body).Decode(&response) + suite.NoError(err) + suite.Equal("role1", response.ID) + suite.Equal("Admin", response.Name) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleGetRequest_MissingID() { + suite.mockService.On("GetRoleWithPermissions", "").Return(nil, &ErrorMissingRoleID) + + req := httptest.NewRequest(http.MethodGet, "/roles/", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleGetRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleGetRequest_NotFound() { + suite.mockService.On("GetRoleWithPermissions", "nonexistent").Return(nil, &ErrorRoleNotFound) + + req := httptest.NewRequest(http.MethodGet, "/roles/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRoleGetRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +// HandleRolePutRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRolePutRequest_Success() { + request := UpdateRoleRequest{ + Name: "Updated Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + } + + updatedRole := &RoleWithPermissions{ + ID: "role1", + Name: "Updated Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + } + + suite.mockService.On("UpdateRoleWithPermissions", "role1", + mock.AnythingOfType("RoleUpdateDetail")).Return(updatedRole, nil) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPut, "/roles/role1", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRolePutRequest(w, req) + + suite.Equal(http.StatusOK, w.Code) + + var response RoleResponse + err := json.NewDecoder(w.Body).Decode(&response) + suite.NoError(err) + suite.Equal("Updated Role", response.Name) +} + +func (suite *RoleHandlerTestSuite) TestHandleRolePutRequest_InvalidJSON() { + req := httptest.NewRequest(http.MethodPut, "/roles/role1", bytes.NewBufferString("invalid")) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRolePutRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRoleDeleteRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleDeleteRequest_Success() { + suite.mockService.On("DeleteRole", "role1").Return(nil) + + req := httptest.NewRequest(http.MethodDelete, "/roles/role1", nil) + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleDeleteRequest(w, req) + + suite.Equal(http.StatusNoContent, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleDeleteRequest_RoleHasAssignments() { + suite.mockService.On("DeleteRole", "role1").Return(&ErrorCannotDeleteRole) + + req := httptest.NewRequest(http.MethodDelete, "/roles/role1", nil) + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleDeleteRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRoleAssignmentsGetRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleAssignmentsGetRequest_Success() { + expectedResponse := &AssignmentList{ + TotalResults: 2, + StartIndex: 1, + Count: 2, + Assignments: []RoleAssignmentWithDisplay{ + {ID: "user1", Type: AssigneeTypeUser}, + {ID: "group1", Type: AssigneeTypeGroup}, + }, + Links: []Link{}, + } + + suite.mockService.On("GetRoleAssignments", "role1", 10, 0, false).Return(expectedResponse, nil) + + req := httptest.NewRequest(http.MethodGet, "/roles/role1/assignments?limit=10&offset=0", nil) + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAssignmentsGetRequest(w, req) + + suite.Equal(http.StatusOK, w.Code) + + var response AssignmentListResponse + err := json.NewDecoder(w.Body).Decode(&response) + suite.NoError(err) + suite.Equal(2, response.TotalResults) + suite.Equal(2, len(response.Assignments)) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleAssignmentsGetRequest_RoleNotFound() { + suite.mockService.On("GetRoleAssignments", "nonexistent", 30, 0, false).Return(nil, &ErrorRoleNotFound) + + req := httptest.NewRequest(http.MethodGet, "/roles/nonexistent/assignments", nil) + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAssignmentsGetRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +// HandleRoleAddAssignmentsRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleAddAssignmentsRequest_Success() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("AddAssignments", "role1", mock.AnythingOfType("[]role.RoleAssignment")).Return(nil) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/role1/add-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAddAssignmentsRequest(w, req) + + suite.Equal(http.StatusNoContent, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleAddAssignmentsRequest_InvalidJSON() { + req := httptest.NewRequest(http.MethodPost, "/roles/role1/add-assignments", bytes.NewBufferString("invalid")) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAddAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleAddAssignmentsRequest_ServiceError() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "invalid_user", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("AddAssignments", "role1", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorInvalidAssignmentID) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/role1/add-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAddAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRoleRemoveAssignmentsRequest Tests +func (suite *RoleHandlerTestSuite) TestHandleRoleRemoveAssignmentsRequest_Success() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("RemoveAssignments", "role1", mock.AnythingOfType("[]role.RoleAssignment")).Return(nil) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/role1/remove-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleRemoveAssignmentsRequest(w, req) + + suite.Equal(http.StatusNoContent, w.Code) +} + +// ParsePaginationParams Tests +func (suite *RoleHandlerTestSuite) TestParsePaginationParams() { + testCases := []struct { + name string + queryString string + expectedLimit int + expectedOffset int + expectError bool + }{ + { + name: "ValidParams", + queryString: "limit=20&offset=10", + expectedLimit: 20, + expectedOffset: 10, + expectError: false, + }, + { + name: "DefaultLimit", + queryString: "offset=5", + expectedLimit: 30, + expectedOffset: 5, + expectError: false, + }, + { + name: "NoParams", + queryString: "", + expectedLimit: 30, + expectedOffset: 0, + expectError: false, + }, + { + name: "InvalidLimit", + queryString: "limit=abc", + expectedLimit: 0, + expectedOffset: 0, + expectError: true, + }, + { + name: "InvalidOffset", + queryString: "offset=xyz", + expectedLimit: 0, + expectedOffset: 0, + expectError: true, + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + query, _ := url.ParseQuery(tc.queryString) + limit, offset, err := parsePaginationParams(query) + + if tc.expectError { + suite.NotNil(err) + } else { + suite.Nil(err) + suite.Equal(tc.expectedLimit, limit) + suite.Equal(tc.expectedOffset, offset) + } + }) + } +} + +// HandleRolePutRequest additional tests +func (suite *RoleHandlerTestSuite) TestHandleRolePutRequest_MissingID() { + request := UpdateRoleRequest{ + Name: "Updated Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockService.On("UpdateRoleWithPermissions", "", mock.AnythingOfType("RoleUpdateDetail")). + Return(nil, &ErrorMissingRoleID) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPut, "/roles/", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRolePutRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRolePutRequest_RoleNotFound() { + request := UpdateRoleRequest{ + Name: "Updated Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockService.On("UpdateRoleWithPermissions", "nonexistent", mock.AnythingOfType("RoleUpdateDetail")). + Return(nil, &ErrorRoleNotFound) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPut, "/roles/nonexistent", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRolePutRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +// HandleRoleDeleteRequest additional tests +func (suite *RoleHandlerTestSuite) TestHandleRoleDeleteRequest_MissingID() { + suite.mockService.On("DeleteRole", "").Return(&ErrorMissingRoleID) + + req := httptest.NewRequest(http.MethodDelete, "/roles/", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleDeleteRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleDeleteRequest_RoleNotFound() { + suite.mockService.On("DeleteRole", "nonexistent").Return(&ErrorRoleNotFound) + + req := httptest.NewRequest(http.MethodDelete, "/roles/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRoleDeleteRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +// HandleRoleAssignmentsGetRequest additional tests +func (suite *RoleHandlerTestSuite) TestHandleRoleAssignmentsGetRequest_MissingID() { + suite.mockService.On("GetRoleAssignments", "", 30, 0, false).Return(nil, &ErrorMissingRoleID) + + req := httptest.NewRequest(http.MethodGet, "/roles//assignments", nil) + w := httptest.NewRecorder() + + suite.handler.HandleRoleAssignmentsGetRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleAssignmentsGetRequest_InvalidPagination() { + req := httptest.NewRequest(http.MethodGet, "/roles/role1/assignments?limit=invalid", nil) + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAssignmentsGetRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// HandleRoleAddAssignmentsRequest additional tests +func (suite *RoleHandlerTestSuite) TestHandleRoleAddAssignmentsRequest_MissingID() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("AddAssignments", "", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorMissingRoleID) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles//add-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAddAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleAddAssignmentsRequest_RoleNotFound() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("AddAssignments", "nonexistent", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorRoleNotFound) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/nonexistent/add-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRoleAddAssignmentsRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +// HandleRoleRemoveAssignmentsRequest additional tests +func (suite *RoleHandlerTestSuite) TestHandleRoleRemoveAssignmentsRequest_MissingID() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("RemoveAssignments", "", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorMissingRoleID) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles//remove-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + suite.handler.HandleRoleRemoveAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleRemoveAssignmentsRequest_InvalidJSON() { + req := httptest.NewRequest(http.MethodPost, "/roles/role1/remove-assignments", bytes.NewBufferString("invalid")) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleRemoveAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleRemoveAssignmentsRequest_RoleNotFound() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("RemoveAssignments", "nonexistent", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorRoleNotFound) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/nonexistent/remove-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + suite.handler.HandleRoleRemoveAssignmentsRequest(w, req) + + suite.Equal(http.StatusNotFound, w.Code) +} + +func (suite *RoleHandlerTestSuite) TestHandleRoleRemoveAssignmentsRequest_ServiceError() { + request := AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: "user1", Type: AssigneeTypeUser}, + }, + } + + suite.mockService.On("RemoveAssignments", "role1", mock.AnythingOfType("[]role.RoleAssignment")). + Return(&ErrorInvalidAssignmentID) + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/roles/role1/remove-assignments", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.SetPathValue("id", "role1") + w := httptest.NewRecorder() + + suite.handler.HandleRoleRemoveAssignmentsRequest(w, req) + + suite.Equal(http.StatusBadRequest, w.Code) +} + +// Sanitization Tests +func (suite *RoleHandlerTestSuite) TestSanitizeCreateRoleRequest() { + request := &CreateRoleRequest{ + Name: " Test Role ", + Description: " Description ", + OrganizationUnitID: " ou1 ", + Permissions: []string{" perm1 ", " perm2 "}, + Assignments: []AssignmentRequest{ + {ID: " user1 ", Type: AssigneeTypeUser}, + }, + } + + sanitized := suite.handler.sanitizeCreateRoleRequest(request) + + suite.Equal("Test Role", sanitized.Name) + suite.Equal("Description", sanitized.Description) + suite.Equal("ou1", sanitized.OrganizationUnitID) + suite.Equal("perm1", sanitized.Permissions[0]) + suite.Equal("user1", sanitized.Assignments[0].ID) +} + +func (suite *RoleHandlerTestSuite) TestSanitizeUpdateRoleRequest() { + request := &UpdateRoleRequest{ + Name: " Updated Name ", + OrganizationUnitID: " ou2 ", + Permissions: []string{" perm3 "}, + } + + sanitized := suite.handler.sanitizeUpdateRoleRequest(request) + + suite.Equal("Updated Name", sanitized.Name) + suite.Equal("ou2", sanitized.OrganizationUnitID) + suite.Equal("perm3", sanitized.Permissions[0]) +} + +func (suite *RoleHandlerTestSuite) TestSanitizeAssignmentsRequest() { + request := &AssignmentsRequest{ + Assignments: []AssignmentRequest{ + {ID: " group1 ", Type: AssigneeTypeGroup}, + }, + } + + sanitized := suite.handler.sanitizeAssignmentsRequest(request) + + suite.Equal("group1", sanitized.Assignments[0].ID) + suite.Equal(AssigneeTypeGroup, sanitized.Assignments[0].Type) +} + +func (suite *RoleHandlerTestSuite) TestwriteToResponse_Success() { + response := &RoleResponse{ + ID: "role1", + Name: "Role 1", + Description: "A sample role", + OrganizationUnitID: "ou1", + } + + w := httptest.NewRecorder() + isErr := writeToResponse(w, response, log.GetLogger()) + suite.False(isErr) +} + +func (suite *RoleHandlerTestSuite) TestwriteToResponse_Error() { + // Use a function which cannot be marshaled to JSON to cause encoding error + response := func() {} + w := httptest.NewRecorder() + isErr := writeToResponse(w, response, log.GetLogger()) + suite.True(isErr) +} diff --git a/backend/internal/role/init.go b/backend/internal/role/init.go new file mode 100644 index 000000000..88c1aa245 --- /dev/null +++ b/backend/internal/role/init.go @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "net/http" + "strings" + + "github.com/asgardeo/thunder/internal/group" + oupkg "github.com/asgardeo/thunder/internal/ou" + "github.com/asgardeo/thunder/internal/system/middleware" + "github.com/asgardeo/thunder/internal/user" +) + +// Initialize initializes the role service and registers its routes. +func Initialize( + mux *http.ServeMux, + userService user.UserServiceInterface, + groupService group.GroupServiceInterface, + ouService oupkg.OrganizationUnitServiceInterface, +) RoleServiceInterface { + roleStore := newRoleStore() + roleService := newRoleService(roleStore, userService, groupService, ouService) + roleHandler := newRoleHandler(roleService) + registerRoutes(mux, roleHandler) + return roleService +} + +// registerRoutes registers the routes for role management operations. +func registerRoutes(mux *http.ServeMux, roleHandler *roleHandler) { + opts1 := middleware.CORSOptions{ + AllowedMethods: "GET, POST", + AllowedHeaders: "Content-Type, Authorization", + AllowCredentials: true, + } + mux.HandleFunc(middleware.WithCORS("POST /roles", roleHandler.HandleRolePostRequest, opts1)) + mux.HandleFunc(middleware.WithCORS("GET /roles", roleHandler.HandleRoleListRequest, opts1)) + mux.HandleFunc(middleware.WithCORS("OPTIONS /roles", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }, opts1)) + + opts2 := middleware.CORSOptions{ + AllowedMethods: "GET, PUT, DELETE", + AllowedHeaders: "Content-Type, Authorization", + AllowCredentials: true, + } + // Special handling for /roles/{id} and /roles/{id}/assignments + mux.HandleFunc(middleware.WithCORS("GET /roles/", + func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/roles/") + segments := strings.Split(path, "/") + r.SetPathValue("id", segments[0]) + + if len(segments) == 1 { + roleHandler.HandleRoleGetRequest(w, r) + } else if len(segments) == 2 && segments[1] == "assignments" { + roleHandler.HandleRoleAssignmentsGetRequest(w, r) + } else { + http.NotFound(w, r) + } + }, opts2)) + mux.HandleFunc(middleware.WithCORS("PUT /roles/{id}", roleHandler.HandleRolePutRequest, opts2)) + mux.HandleFunc(middleware.WithCORS("DELETE /roles/{id}", roleHandler.HandleRoleDeleteRequest, opts2)) + mux.HandleFunc(middleware.WithCORS("OPTIONS /roles/{id}", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }, opts2)) + + opts3 := middleware.CORSOptions{ + AllowedMethods: "POST", + AllowedHeaders: "Content-Type, Authorization", + AllowCredentials: true, + } + mux.HandleFunc(middleware.WithCORS("POST /roles/{id}/assignments/add", + roleHandler.HandleRoleAddAssignmentsRequest, opts3)) + mux.HandleFunc(middleware.WithCORS("POST /roles/{id}/assignments/remove", + roleHandler.HandleRoleRemoveAssignmentsRequest, opts3)) + mux.HandleFunc(middleware.WithCORS("OPTIONS /roles/{id}/assignments/add", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }, opts3)) + mux.HandleFunc(middleware.WithCORS("OPTIONS /roles/{id}/assignments/remove", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }, opts3)) +} diff --git a/backend/internal/role/model.go b/backend/internal/role/model.go new file mode 100644 index 000000000..0dcab337f --- /dev/null +++ b/backend/internal/role/model.go @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +// AssigneeType represents the type of assignee entity. +type AssigneeType string + +const ( + // AssigneeTypeUser is the type for users. + AssigneeTypeUser AssigneeType = "user" + // AssigneeTypeGroup is the type for groups. + AssigneeTypeGroup AssigneeType = "group" +) + +// AssignmentResponse represents an assignment of a role to a user or group. +type AssignmentResponse struct { + ID string `json:"id"` + Type AssigneeType `json:"type"` + Display string `json:"display,omitempty"` +} + +// AssignmentRequest represents an assignment of a role to a user or group. +type AssignmentRequest struct { + ID string `json:"id"` + Type AssigneeType `json:"type"` +} + +// RoleSummaryResponse represents the basic information of a role. +type RoleSummaryResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` +} + +// RoleResponse represents a complete role with permissions. +type RoleResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` +} + +// CreateRoleRequest represents the request body for creating a role. +type CreateRoleRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` + Assignments []AssignmentRequest `json:"assignments,omitempty"` +} + +// CreateRoleResponse represents the response body for creating a role. +type CreateRoleResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` + Assignments []AssignmentResponse `json:"assignments,omitempty"` +} + +// UpdateRoleRequest represents the request body for updating a role. +type UpdateRoleRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` +} + +// AssignmentsRequest represents the request body for adding or removing assignments. +type AssignmentsRequest struct { + Assignments []AssignmentRequest `json:"assignments"` +} + +// LinkResponse represents a pagination link. +type LinkResponse struct { + Href string `json:"href"` + Rel string `json:"rel"` +} + +// RoleListResponse represents the response for listing roles with pagination. +type RoleListResponse struct { + TotalResults int `json:"totalResults"` + StartIndex int `json:"startIndex"` + Count int `json:"count"` + Roles []RoleSummaryResponse `json:"roles"` + Links []LinkResponse `json:"links"` +} + +// AssignmentListResponse represents the response for listing role assignments with pagination. +type AssignmentListResponse struct { + TotalResults int `json:"totalResults"` + StartIndex int `json:"startIndex"` + Count int `json:"count"` + Assignments []AssignmentResponse `json:"assignments"` + Links []LinkResponse `json:"links"` +} + +// Internal service layer structs - used for business logic processing + +// RoleCreationDetail represents the parameters for creating a role. +type RoleCreationDetail struct { + Name string + Description string + OrganizationUnitID string + Permissions []string + Assignments []RoleAssignment +} + +// RoleWithPermissionsAndAssignments represents the parameters for creating a role. +type RoleWithPermissionsAndAssignments struct { + ID string + Name string + Description string + OrganizationUnitID string + Permissions []string + Assignments []RoleAssignment +} + +// RoleAssignment represents an assignment used internally by the service layer. +type RoleAssignment struct { + ID string + Type AssigneeType +} + +// RoleAssignmentWithDisplay represents an assignment used internally by the service layer. +type RoleAssignmentWithDisplay struct { + ID string + Type AssigneeType + Display string +} + +// Role represents basic role information used internally by the service layer. +type Role struct { + ID string + Name string + Description string + OrganizationUnitID string +} + +// RoleWithPermissions represents complete role details used internally by the service layer. +type RoleWithPermissions struct { + ID string + Name string + Description string + OrganizationUnitID string + Permissions []string +} + +// RoleUpdateDetail represents the parameters for creating a role. +type RoleUpdateDetail struct { + Name string + Description string + OrganizationUnitID string + Permissions []string +} + +// Link represents a pagination link. +type Link struct { + Href string + Rel string +} + +// RoleList represents the result of listing roles. +type RoleList struct { + TotalResults int + StartIndex int + Count int + Roles []Role + Links []Link +} + +// AssignmentList represents the result of listing role assignments. +type AssignmentList struct { + TotalResults int + StartIndex int + Count int + Assignments []RoleAssignmentWithDisplay + Links []Link +} diff --git a/backend/internal/role/roleStoreInterface_mock_test.go b/backend/internal/role/roleStoreInterface_mock_test.go new file mode 100644 index 000000000..c4595b4db --- /dev/null +++ b/backend/internal/role/roleStoreInterface_mock_test.go @@ -0,0 +1,902 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package role + +import ( + mock "github.com/stretchr/testify/mock" +) + +// newRoleStoreInterfaceMock creates a new instance of roleStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newRoleStoreInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *roleStoreInterfaceMock { + mock := &roleStoreInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// roleStoreInterfaceMock is an autogenerated mock type for the roleStoreInterface type +type roleStoreInterfaceMock struct { + mock.Mock +} + +type roleStoreInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *roleStoreInterfaceMock) EXPECT() *roleStoreInterfaceMock_Expecter { + return &roleStoreInterfaceMock_Expecter{mock: &_m.Mock} +} + +// AddAssignments provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) AddAssignments(id string, assignments []RoleAssignment) error { + ret := _mock.Called(id, assignments) + + if len(ret) == 0 { + panic("no return value specified for AddAssignments") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, []RoleAssignment) error); ok { + r0 = returnFunc(id, assignments) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// roleStoreInterfaceMock_AddAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddAssignments' +type roleStoreInterfaceMock_AddAssignments_Call struct { + *mock.Call +} + +// AddAssignments is a helper method to define mock.On call +// - id string +// - assignments []RoleAssignment +func (_e *roleStoreInterfaceMock_Expecter) AddAssignments(id interface{}, assignments interface{}) *roleStoreInterfaceMock_AddAssignments_Call { + return &roleStoreInterfaceMock_AddAssignments_Call{Call: _e.mock.On("AddAssignments", id, assignments)} +} + +func (_c *roleStoreInterfaceMock_AddAssignments_Call) Run(run func(id string, assignments []RoleAssignment)) *roleStoreInterfaceMock_AddAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []RoleAssignment + if args[1] != nil { + arg1 = args[1].([]RoleAssignment) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_AddAssignments_Call) Return(err error) *roleStoreInterfaceMock_AddAssignments_Call { + _c.Call.Return(err) + return _c +} + +func (_c *roleStoreInterfaceMock_AddAssignments_Call) RunAndReturn(run func(id string, assignments []RoleAssignment) error) *roleStoreInterfaceMock_AddAssignments_Call { + _c.Call.Return(run) + return _c +} + +// CheckRoleNameExists provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) CheckRoleNameExists(ouID string, name string) (bool, error) { + ret := _mock.Called(ouID, name) + + if len(ret) == 0 { + panic("no return value specified for CheckRoleNameExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string) (bool, error)); ok { + return returnFunc(ouID, name) + } + if returnFunc, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = returnFunc(ouID, name) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { + r1 = returnFunc(ouID, name) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_CheckRoleNameExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckRoleNameExists' +type roleStoreInterfaceMock_CheckRoleNameExists_Call struct { + *mock.Call +} + +// CheckRoleNameExists is a helper method to define mock.On call +// - ouID string +// - name string +func (_e *roleStoreInterfaceMock_Expecter) CheckRoleNameExists(ouID interface{}, name interface{}) *roleStoreInterfaceMock_CheckRoleNameExists_Call { + return &roleStoreInterfaceMock_CheckRoleNameExists_Call{Call: _e.mock.On("CheckRoleNameExists", ouID, name)} +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExists_Call) Run(run func(ouID string, name string)) *roleStoreInterfaceMock_CheckRoleNameExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExists_Call) Return(b bool, err error) *roleStoreInterfaceMock_CheckRoleNameExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExists_Call) RunAndReturn(run func(ouID string, name string) (bool, error)) *roleStoreInterfaceMock_CheckRoleNameExists_Call { + _c.Call.Return(run) + return _c +} + +// CheckRoleNameExistsExcludingID provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) CheckRoleNameExistsExcludingID(ouID string, name string, excludeRoleID string) (bool, error) { + ret := _mock.Called(ouID, name, excludeRoleID) + + if len(ret) == 0 { + panic("no return value specified for CheckRoleNameExistsExcludingID") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string, string) (bool, error)); ok { + return returnFunc(ouID, name, excludeRoleID) + } + if returnFunc, ok := ret.Get(0).(func(string, string, string) bool); ok { + r0 = returnFunc(ouID, name, excludeRoleID) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(string, string, string) error); ok { + r1 = returnFunc(ouID, name, excludeRoleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckRoleNameExistsExcludingID' +type roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call struct { + *mock.Call +} + +// CheckRoleNameExistsExcludingID is a helper method to define mock.On call +// - ouID string +// - name string +// - excludeRoleID string +func (_e *roleStoreInterfaceMock_Expecter) CheckRoleNameExistsExcludingID(ouID interface{}, name interface{}, excludeRoleID interface{}) *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call { + return &roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call{Call: _e.mock.On("CheckRoleNameExistsExcludingID", ouID, name, excludeRoleID)} +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call) Run(run func(ouID string, name string, excludeRoleID string)) *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call) Return(b bool, err error) *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call) RunAndReturn(run func(ouID string, name string, excludeRoleID string) (bool, error)) *roleStoreInterfaceMock_CheckRoleNameExistsExcludingID_Call { + _c.Call.Return(run) + return _c +} + +// CreateRole provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) CreateRole(id string, role RoleCreationDetail) error { + ret := _mock.Called(id, role) + + if len(ret) == 0 { + panic("no return value specified for CreateRole") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, RoleCreationDetail) error); ok { + r0 = returnFunc(id, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// roleStoreInterfaceMock_CreateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRole' +type roleStoreInterfaceMock_CreateRole_Call struct { + *mock.Call +} + +// CreateRole is a helper method to define mock.On call +// - id string +// - role RoleCreationDetail +func (_e *roleStoreInterfaceMock_Expecter) CreateRole(id interface{}, role interface{}) *roleStoreInterfaceMock_CreateRole_Call { + return &roleStoreInterfaceMock_CreateRole_Call{Call: _e.mock.On("CreateRole", id, role)} +} + +func (_c *roleStoreInterfaceMock_CreateRole_Call) Run(run func(id string, role RoleCreationDetail)) *roleStoreInterfaceMock_CreateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 RoleCreationDetail + if args[1] != nil { + arg1 = args[1].(RoleCreationDetail) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_CreateRole_Call) Return(err error) *roleStoreInterfaceMock_CreateRole_Call { + _c.Call.Return(err) + return _c +} + +func (_c *roleStoreInterfaceMock_CreateRole_Call) RunAndReturn(run func(id string, role RoleCreationDetail) error) *roleStoreInterfaceMock_CreateRole_Call { + _c.Call.Return(run) + return _c +} + +// DeleteRole provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) DeleteRole(id string) error { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for DeleteRole") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string) error); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// roleStoreInterfaceMock_DeleteRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteRole' +type roleStoreInterfaceMock_DeleteRole_Call struct { + *mock.Call +} + +// DeleteRole is a helper method to define mock.On call +// - id string +func (_e *roleStoreInterfaceMock_Expecter) DeleteRole(id interface{}) *roleStoreInterfaceMock_DeleteRole_Call { + return &roleStoreInterfaceMock_DeleteRole_Call{Call: _e.mock.On("DeleteRole", id)} +} + +func (_c *roleStoreInterfaceMock_DeleteRole_Call) Run(run func(id string)) *roleStoreInterfaceMock_DeleteRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_DeleteRole_Call) Return(err error) *roleStoreInterfaceMock_DeleteRole_Call { + _c.Call.Return(err) + return _c +} + +func (_c *roleStoreInterfaceMock_DeleteRole_Call) RunAndReturn(run func(id string) error) *roleStoreInterfaceMock_DeleteRole_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthorizedPermissions provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetAuthorizedPermissions(userID string, groupIDs []string, requestedPermissions []string) ([]string, error) { + ret := _mock.Called(userID, groupIDs, requestedPermissions) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizedPermissions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, []string, []string) ([]string, error)); ok { + return returnFunc(userID, groupIDs, requestedPermissions) + } + if returnFunc, ok := ret.Get(0).(func(string, []string, []string) []string); ok { + r0 = returnFunc(userID, groupIDs, requestedPermissions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(string, []string, []string) error); ok { + r1 = returnFunc(userID, groupIDs, requestedPermissions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetAuthorizedPermissions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizedPermissions' +type roleStoreInterfaceMock_GetAuthorizedPermissions_Call struct { + *mock.Call +} + +// GetAuthorizedPermissions is a helper method to define mock.On call +// - userID string +// - groupIDs []string +// - requestedPermissions []string +func (_e *roleStoreInterfaceMock_Expecter) GetAuthorizedPermissions(userID interface{}, groupIDs interface{}, requestedPermissions interface{}) *roleStoreInterfaceMock_GetAuthorizedPermissions_Call { + return &roleStoreInterfaceMock_GetAuthorizedPermissions_Call{Call: _e.mock.On("GetAuthorizedPermissions", userID, groupIDs, requestedPermissions)} +} + +func (_c *roleStoreInterfaceMock_GetAuthorizedPermissions_Call) Run(run func(userID string, groupIDs []string, requestedPermissions []string)) *roleStoreInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetAuthorizedPermissions_Call) Return(strings []string, err error) *roleStoreInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetAuthorizedPermissions_Call) RunAndReturn(run func(userID string, groupIDs []string, requestedPermissions []string) ([]string, error)) *roleStoreInterfaceMock_GetAuthorizedPermissions_Call { + _c.Call.Return(run) + return _c +} + +// GetRole provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetRole(id string) (RoleWithPermissions, error) { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for GetRole") + } + + var r0 RoleWithPermissions + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (RoleWithPermissions, error)); ok { + return returnFunc(id) + } + if returnFunc, ok := ret.Get(0).(func(string) RoleWithPermissions); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Get(0).(RoleWithPermissions) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRole' +type roleStoreInterfaceMock_GetRole_Call struct { + *mock.Call +} + +// GetRole is a helper method to define mock.On call +// - id string +func (_e *roleStoreInterfaceMock_Expecter) GetRole(id interface{}) *roleStoreInterfaceMock_GetRole_Call { + return &roleStoreInterfaceMock_GetRole_Call{Call: _e.mock.On("GetRole", id)} +} + +func (_c *roleStoreInterfaceMock_GetRole_Call) Run(run func(id string)) *roleStoreInterfaceMock_GetRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRole_Call) Return(roleWithPermissions RoleWithPermissions, err error) *roleStoreInterfaceMock_GetRole_Call { + _c.Call.Return(roleWithPermissions, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRole_Call) RunAndReturn(run func(id string) (RoleWithPermissions, error)) *roleStoreInterfaceMock_GetRole_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleAssignments provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetRoleAssignments(id string, limit int, offset int) ([]RoleAssignment, error) { + ret := _mock.Called(id, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetRoleAssignments") + } + + var r0 []RoleAssignment + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, int, int) ([]RoleAssignment, error)); ok { + return returnFunc(id, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int) []RoleAssignment); ok { + r0 = returnFunc(id, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]RoleAssignment) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int) error); ok { + r1 = returnFunc(id, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetRoleAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleAssignments' +type roleStoreInterfaceMock_GetRoleAssignments_Call struct { + *mock.Call +} + +// GetRoleAssignments is a helper method to define mock.On call +// - id string +// - limit int +// - offset int +func (_e *roleStoreInterfaceMock_Expecter) GetRoleAssignments(id interface{}, limit interface{}, offset interface{}) *roleStoreInterfaceMock_GetRoleAssignments_Call { + return &roleStoreInterfaceMock_GetRoleAssignments_Call{Call: _e.mock.On("GetRoleAssignments", id, limit, offset)} +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignments_Call) Run(run func(id string, limit int, offset int)) *roleStoreInterfaceMock_GetRoleAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignments_Call) Return(roleAssignments []RoleAssignment, err error) *roleStoreInterfaceMock_GetRoleAssignments_Call { + _c.Call.Return(roleAssignments, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignments_Call) RunAndReturn(run func(id string, limit int, offset int) ([]RoleAssignment, error)) *roleStoreInterfaceMock_GetRoleAssignments_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleAssignmentsCount provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetRoleAssignmentsCount(id string) (int, error) { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for GetRoleAssignmentsCount") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (int, error)); ok { + return returnFunc(id) + } + if returnFunc, ok := ret.Get(0).(func(string) int); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetRoleAssignmentsCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleAssignmentsCount' +type roleStoreInterfaceMock_GetRoleAssignmentsCount_Call struct { + *mock.Call +} + +// GetRoleAssignmentsCount is a helper method to define mock.On call +// - id string +func (_e *roleStoreInterfaceMock_Expecter) GetRoleAssignmentsCount(id interface{}) *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call { + return &roleStoreInterfaceMock_GetRoleAssignmentsCount_Call{Call: _e.mock.On("GetRoleAssignmentsCount", id)} +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call) Run(run func(id string)) *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call) Return(n int, err error) *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call) RunAndReturn(run func(id string) (int, error)) *roleStoreInterfaceMock_GetRoleAssignmentsCount_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleList provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetRoleList(limit int, offset int) ([]Role, error) { + ret := _mock.Called(limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetRoleList") + } + + var r0 []Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(int, int) ([]Role, error)); ok { + return returnFunc(limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(int, int) []Role); ok { + r0 = returnFunc(limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Role) + } + } + if returnFunc, ok := ret.Get(1).(func(int, int) error); ok { + r1 = returnFunc(limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetRoleList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleList' +type roleStoreInterfaceMock_GetRoleList_Call struct { + *mock.Call +} + +// GetRoleList is a helper method to define mock.On call +// - limit int +// - offset int +func (_e *roleStoreInterfaceMock_Expecter) GetRoleList(limit interface{}, offset interface{}) *roleStoreInterfaceMock_GetRoleList_Call { + return &roleStoreInterfaceMock_GetRoleList_Call{Call: _e.mock.On("GetRoleList", limit, offset)} +} + +func (_c *roleStoreInterfaceMock_GetRoleList_Call) Run(run func(limit int, offset int)) *roleStoreInterfaceMock_GetRoleList_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleList_Call) Return(roles []Role, err error) *roleStoreInterfaceMock_GetRoleList_Call { + _c.Call.Return(roles, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleList_Call) RunAndReturn(run func(limit int, offset int) ([]Role, error)) *roleStoreInterfaceMock_GetRoleList_Call { + _c.Call.Return(run) + return _c +} + +// GetRoleListCount provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) GetRoleListCount() (int, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetRoleListCount") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func() (int, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() int); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_GetRoleListCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRoleListCount' +type roleStoreInterfaceMock_GetRoleListCount_Call struct { + *mock.Call +} + +// GetRoleListCount is a helper method to define mock.On call +func (_e *roleStoreInterfaceMock_Expecter) GetRoleListCount() *roleStoreInterfaceMock_GetRoleListCount_Call { + return &roleStoreInterfaceMock_GetRoleListCount_Call{Call: _e.mock.On("GetRoleListCount")} +} + +func (_c *roleStoreInterfaceMock_GetRoleListCount_Call) Run(run func()) *roleStoreInterfaceMock_GetRoleListCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleListCount_Call) Return(n int, err error) *roleStoreInterfaceMock_GetRoleListCount_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *roleStoreInterfaceMock_GetRoleListCount_Call) RunAndReturn(run func() (int, error)) *roleStoreInterfaceMock_GetRoleListCount_Call { + _c.Call.Return(run) + return _c +} + +// IsRoleExist provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) IsRoleExist(id string) (bool, error) { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for IsRoleExist") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (bool, error)); ok { + return returnFunc(id) + } + if returnFunc, ok := ret.Get(0).(func(string) bool); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// roleStoreInterfaceMock_IsRoleExist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsRoleExist' +type roleStoreInterfaceMock_IsRoleExist_Call struct { + *mock.Call +} + +// IsRoleExist is a helper method to define mock.On call +// - id string +func (_e *roleStoreInterfaceMock_Expecter) IsRoleExist(id interface{}) *roleStoreInterfaceMock_IsRoleExist_Call { + return &roleStoreInterfaceMock_IsRoleExist_Call{Call: _e.mock.On("IsRoleExist", id)} +} + +func (_c *roleStoreInterfaceMock_IsRoleExist_Call) Run(run func(id string)) *roleStoreInterfaceMock_IsRoleExist_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_IsRoleExist_Call) Return(b bool, err error) *roleStoreInterfaceMock_IsRoleExist_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *roleStoreInterfaceMock_IsRoleExist_Call) RunAndReturn(run func(id string) (bool, error)) *roleStoreInterfaceMock_IsRoleExist_Call { + _c.Call.Return(run) + return _c +} + +// RemoveAssignments provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) RemoveAssignments(id string, assignments []RoleAssignment) error { + ret := _mock.Called(id, assignments) + + if len(ret) == 0 { + panic("no return value specified for RemoveAssignments") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, []RoleAssignment) error); ok { + r0 = returnFunc(id, assignments) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// roleStoreInterfaceMock_RemoveAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveAssignments' +type roleStoreInterfaceMock_RemoveAssignments_Call struct { + *mock.Call +} + +// RemoveAssignments is a helper method to define mock.On call +// - id string +// - assignments []RoleAssignment +func (_e *roleStoreInterfaceMock_Expecter) RemoveAssignments(id interface{}, assignments interface{}) *roleStoreInterfaceMock_RemoveAssignments_Call { + return &roleStoreInterfaceMock_RemoveAssignments_Call{Call: _e.mock.On("RemoveAssignments", id, assignments)} +} + +func (_c *roleStoreInterfaceMock_RemoveAssignments_Call) Run(run func(id string, assignments []RoleAssignment)) *roleStoreInterfaceMock_RemoveAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []RoleAssignment + if args[1] != nil { + arg1 = args[1].([]RoleAssignment) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_RemoveAssignments_Call) Return(err error) *roleStoreInterfaceMock_RemoveAssignments_Call { + _c.Call.Return(err) + return _c +} + +func (_c *roleStoreInterfaceMock_RemoveAssignments_Call) RunAndReturn(run func(id string, assignments []RoleAssignment) error) *roleStoreInterfaceMock_RemoveAssignments_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRole provides a mock function for the type roleStoreInterfaceMock +func (_mock *roleStoreInterfaceMock) UpdateRole(id string, role RoleUpdateDetail) error { + ret := _mock.Called(id, role) + + if len(ret) == 0 { + panic("no return value specified for UpdateRole") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, RoleUpdateDetail) error); ok { + r0 = returnFunc(id, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// roleStoreInterfaceMock_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole' +type roleStoreInterfaceMock_UpdateRole_Call struct { + *mock.Call +} + +// UpdateRole is a helper method to define mock.On call +// - id string +// - role RoleUpdateDetail +func (_e *roleStoreInterfaceMock_Expecter) UpdateRole(id interface{}, role interface{}) *roleStoreInterfaceMock_UpdateRole_Call { + return &roleStoreInterfaceMock_UpdateRole_Call{Call: _e.mock.On("UpdateRole", id, role)} +} + +func (_c *roleStoreInterfaceMock_UpdateRole_Call) Run(run func(id string, role RoleUpdateDetail)) *roleStoreInterfaceMock_UpdateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 RoleUpdateDetail + if args[1] != nil { + arg1 = args[1].(RoleUpdateDetail) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *roleStoreInterfaceMock_UpdateRole_Call) Return(err error) *roleStoreInterfaceMock_UpdateRole_Call { + _c.Call.Return(err) + return _c +} + +func (_c *roleStoreInterfaceMock_UpdateRole_Call) RunAndReturn(run func(id string, role RoleUpdateDetail) error) *roleStoreInterfaceMock_UpdateRole_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/role/service.go b/backend/internal/role/service.go new file mode 100644 index 000000000..71397c604 --- /dev/null +++ b/backend/internal/role/service.go @@ -0,0 +1,645 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Package role provides role management functionality. +package role + +import ( + "errors" + "fmt" + + "github.com/asgardeo/thunder/internal/group" + oupkg "github.com/asgardeo/thunder/internal/ou" + serverconst "github.com/asgardeo/thunder/internal/system/constants" + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + "github.com/asgardeo/thunder/internal/system/log" + "github.com/asgardeo/thunder/internal/system/utils" + "github.com/asgardeo/thunder/internal/user" +) + +const loggerComponentName = "RoleMgtService" + +// RoleServiceInterface defines the interface for the role service. +type RoleServiceInterface interface { + GetRoleList(limit, offset int) (*RoleList, *serviceerror.ServiceError) + CreateRole(role RoleCreationDetail) (*RoleWithPermissionsAndAssignments, *serviceerror.ServiceError) + GetRoleWithPermissions(id string) (*RoleWithPermissions, *serviceerror.ServiceError) + UpdateRoleWithPermissions(id string, role RoleUpdateDetail) (*RoleWithPermissions, *serviceerror.ServiceError) + DeleteRole(id string) *serviceerror.ServiceError + GetRoleAssignments(id string, limit, offset int, + includeDisplay bool) (*AssignmentList, *serviceerror.ServiceError) + AddAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError + RemoveAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError + GetAuthorizedPermissions( + userID string, groups []string, requestedPermissions []string, + ) ([]string, *serviceerror.ServiceError) +} + +// roleService is the default implementation of the RoleServiceInterface. +type roleService struct { + roleStore roleStoreInterface + userService user.UserServiceInterface + groupService group.GroupServiceInterface + ouService oupkg.OrganizationUnitServiceInterface +} + +// newRoleService creates a new instance of RoleService with injected dependencies. +func newRoleService( + roleStore roleStoreInterface, + userService user.UserServiceInterface, + groupService group.GroupServiceInterface, + ouService oupkg.OrganizationUnitServiceInterface, +) RoleServiceInterface { + return &roleService{ + roleStore: roleStore, + userService: userService, + groupService: groupService, + ouService: ouService, + } +} + +// GetRoleList retrieves a list of roles. +func (rs *roleService) GetRoleList(limit, offset int) (*RoleList, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + + if err := validatePaginationParams(limit, offset); err != nil { + return nil, err + } + + totalCount, err := rs.roleStore.GetRoleListCount() + if err != nil { + logger.Error("Failed to get role count", log.Error(err)) + return nil, &ErrorInternalServerError + } + + roles, err := rs.roleStore.GetRoleList(limit, offset) + if err != nil { + logger.Error("Failed to list roles", log.Error(err)) + return nil, &ErrorInternalServerError + } + + response := &RoleList{ + TotalResults: totalCount, + Roles: roles, + StartIndex: offset + 1, + Count: len(roles), + Links: buildPaginationLinks("/roles", limit, offset, totalCount), + } + + return response, nil +} + +// CreateRole creates a new role. +func (rs *roleService) CreateRole( + role RoleCreationDetail, +) (*RoleWithPermissionsAndAssignments, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Creating role", log.String("name", role.Name)) + + if err := rs.validateCreateRoleRequest(role); err != nil { + return nil, err + } + + // Validate assignment IDs early to avoid unnecessary database operations + if len(role.Assignments) > 0 { + if err := rs.validateAssignmentIDs(role.Assignments); err != nil { + return nil, err + } + } + + // Validate organization unit exists using OU service + _, svcErr := rs.ouService.GetOrganizationUnit(role.OrganizationUnitID) + if svcErr != nil { + if svcErr.Code == oupkg.ErrorOrganizationUnitNotFound.Code { + logger.Debug("Organization unit not found", log.String("ouID", role.OrganizationUnitID)) + return nil, &ErrorOrganizationUnitNotFound + } + logger.Error("Failed to validate organization unit", log.String("error", svcErr.Error)) + return nil, &ErrorInternalServerError + } + + // Check if role name already exists in the organization unit + nameExists, err := rs.roleStore.CheckRoleNameExists(role.OrganizationUnitID, role.Name) + if err != nil { + logger.Error("Failed to check role name existence", log.Error(err)) + return nil, &ErrorInternalServerError + } + if nameExists { + logger.Debug("Role name already exists in organization unit", + log.String("name", role.Name), log.String("ouID", role.OrganizationUnitID)) + return nil, &ErrorRoleNameConflict + } + + id := utils.GenerateUUID() + if err := rs.roleStore.CreateRole(id, role); err != nil { + logger.Error("Failed to create role", log.Error(err)) + return nil, &ErrorInternalServerError + } + + serviceRole := &RoleWithPermissionsAndAssignments{ + ID: id, + Name: role.Name, + Description: role.Description, + OrganizationUnitID: role.OrganizationUnitID, + Permissions: role.Permissions, + Assignments: role.Assignments, + } + + logger.Debug("Successfully created role", log.String("id", id), log.String("name", role.Name)) + return serviceRole, nil +} + +// GetRoleWithPermissions retrieves a specific role by its id. +func (rs *roleService) GetRoleWithPermissions(id string) (*RoleWithPermissions, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Retrieving role", log.String("id", id)) + + if id == "" { + return nil, &ErrorMissingRoleID + } + + role, err := rs.roleStore.GetRole(id) + if err != nil { + if errors.Is(err, ErrRoleNotFound) { + logger.Debug("Role not found", log.String("id", id)) + return nil, &ErrorRoleNotFound + } + logger.Error("Failed to retrieve role", log.String("id", id), log.Error(err)) + return nil, &ErrorInternalServerError + } + + logger.Debug("Successfully retrieved role", log.String("id", role.ID), log.String("name", role.Name)) + return &role, nil +} + +// UpdateRole updates an existing role. +func (rs *roleService) UpdateRoleWithPermissions( + id string, role RoleUpdateDetail) (*RoleWithPermissions, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Updating role", log.String("id", id), log.String("name", role.Name)) + + if id == "" { + return nil, &ErrorMissingRoleID + } + + if err := rs.validateUpdateRoleRequest(role); err != nil { + return nil, err + } + + exists, err := rs.roleStore.IsRoleExist(id) + if err != nil { + logger.Error("Failed to check role existence", log.String("id", id), log.Error(err)) + return nil, &ErrorInternalServerError + } + if !exists { + logger.Debug("Role not found", log.String("id", id)) + return nil, &ErrorRoleNotFound + } + + // Validate organization unit exists using OU service + _, svcErr := rs.ouService.GetOrganizationUnit(role.OrganizationUnitID) + if svcErr != nil { + if svcErr.Code == oupkg.ErrorOrganizationUnitNotFound.Code { + logger.Debug("Organization unit not found", log.String("ouID", role.OrganizationUnitID)) + return nil, &ErrorOrganizationUnitNotFound + } + logger.Error("Failed to validate organization unit", log.String("error", svcErr.Error)) + return nil, &ErrorInternalServerError + } + + // Check if role name already exists in the organization unit (excluding the current role) + nameExists, err := rs.roleStore.CheckRoleNameExistsExcludingID(role.OrganizationUnitID, role.Name, id) + if err != nil { + logger.Error("Failed to check role name existence", log.Error(err)) + return nil, &ErrorInternalServerError + } + if nameExists { + logger.Debug("Role name already exists in organization unit", + log.String("name", role.Name), log.String("ouID", role.OrganizationUnitID)) + return nil, &ErrorRoleNameConflict + } + + if err := rs.roleStore.UpdateRole(id, role); err != nil { + logger.Error("Failed to update role", log.Error(err)) + return nil, &ErrorInternalServerError + } + + logger.Debug("Successfully updated role", log.String("id", id), log.String("name", role.Name)) + return &RoleWithPermissions{ + ID: id, + Name: role.Name, + Description: role.Description, + OrganizationUnitID: role.OrganizationUnitID, + Permissions: role.Permissions, + }, nil +} + +// DeleteRole delete the specified role by its id. +func (rs *roleService) DeleteRole(id string) *serviceerror.ServiceError { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Deleting role", log.String("id", id)) + + if id == "" { + return &ErrorMissingRoleID + } + + exists, err := rs.roleStore.IsRoleExist(id) + if err != nil { + logger.Error("Failed to check role existence", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + if !exists { + logger.Debug("Role not found", log.String("id", id)) + return nil + } + + // Check if role has any assignments before deleting + assignmentCount, err := rs.roleStore.GetRoleAssignmentsCount(id) + if err != nil { + logger.Error("Failed to get role assignments count", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + + if assignmentCount > 0 { + logger.Debug("Cannot delete role with active assignments", + log.String("id", id), log.Int("assignmentCount", assignmentCount)) + return &ErrorCannotDeleteRole + } + + if err := rs.roleStore.DeleteRole(id); err != nil { + logger.Error("Failed to delete role", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + + logger.Debug("Successfully deleted role", log.String("id", id)) + return nil +} + +// GetRoleAssignments retrieves assignments for a role with pagination. +func (rs *roleService) GetRoleAssignments(id string, limit, offset int, + includeDisplay bool) (*AssignmentList, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + + if err := validatePaginationParams(limit, offset); err != nil { + return nil, err + } + + if id == "" { + return nil, &ErrorMissingRoleID + } + + exists, err := rs.roleStore.IsRoleExist(id) + if err != nil { + logger.Error("Failed to check role existence", log.String("id", id), log.Error(err)) + return nil, &ErrorInternalServerError + } + if !exists { + logger.Debug("Role not found", log.String("id", id)) + return nil, &ErrorRoleNotFound + } + + totalCount, err := rs.roleStore.GetRoleAssignmentsCount(id) + if err != nil { + logger.Error("Failed to get role assignments count", log.String("id", id), log.Error(err)) + return nil, &ErrorInternalServerError + } + + assignments, err := rs.roleStore.GetRoleAssignments(id, limit, offset) + if err != nil { + logger.Error("Failed to get role assignments", log.String("id", id), log.Error(err)) + return nil, &ErrorInternalServerError + } + + // Convert to service layer assignments + serviceAssignments := make([]RoleAssignmentWithDisplay, len(assignments)) + + for i := range assignments { + // Populate display names if requested + displayName := "" + if includeDisplay { + displayName, err = rs.getDisplayNameForAssignment(&assignments[i]) + if err != nil { + logger.Warn("Failed to get display name for assignment", + log.String("assignmentID", assignments[i].ID), + log.String("assignmentType", string(assignments[i].Type)), + log.Error(err)) + // Continue with empty display name rather than failing the entire request + displayName = "" + } + } + serviceAssignments[i].ID = assignments[i].ID + serviceAssignments[i].Type = assignments[i].Type + serviceAssignments[i].Display = displayName + } + baseURL := fmt.Sprintf("/roles/%s/assignments", id) + links := buildPaginationLinks(baseURL, limit, offset, totalCount) + + response := &AssignmentList{ + TotalResults: totalCount, + Assignments: serviceAssignments, + StartIndex: offset + 1, + Count: len(serviceAssignments), + Links: links, + } + + return response, nil +} + +// AddAssignments adds assignments to a role. +func (rs *roleService) AddAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Adding assignments to role", log.String("id", id)) + + if id == "" { + return &ErrorMissingRoleID + } + + if err := rs.validateAssignmentsRequest(assignments); err != nil { + return err + } + + exists, err := rs.roleStore.IsRoleExist(id) + if err != nil { + logger.Error("Failed to check role existence", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + if !exists { + logger.Debug("Role not found", log.String("id", id)) + return &ErrorRoleNotFound + } + + if err := rs.validateAssignmentIDs(assignments); err != nil { + return err + } + + if err := rs.roleStore.AddAssignments(id, assignments); err != nil { + logger.Error("Failed to add assignments to role", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + + logger.Debug("Successfully added assignments to role", log.String("id", id)) + return nil +} + +// RemoveAssignments removes assignments from a role. +func (rs *roleService) RemoveAssignments(id string, assignments []RoleAssignment) *serviceerror.ServiceError { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Removing assignments from role", log.String("id", id)) + + if id == "" { + return &ErrorMissingRoleID + } + + if err := rs.validateAssignmentsRequest(assignments); err != nil { + return err + } + + exists, err := rs.roleStore.IsRoleExist(id) + if err != nil { + logger.Error("Failed to check role existence", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + if !exists { + logger.Debug("Role not found", log.String("id", id)) + return &ErrorRoleNotFound + } + + if err := rs.roleStore.RemoveAssignments(id, assignments); err != nil { + logger.Error("Failed to remove assignments from role", log.String("id", id), log.Error(err)) + return &ErrorInternalServerError + } + + logger.Debug("Successfully removed assignments from role", log.String("id", id)) + return nil +} + +// GetAuthorizedPermissions checks which of the requested permissions are authorized for the user based on their roles. +func (rs *roleService) GetAuthorizedPermissions( + userID string, groups []string, requestedPermissions []string, +) ([]string, *serviceerror.ServiceError) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger.Debug("Authorizing permissions", log.String("userID", userID), log.Int("groupCount", len(groups))) + + // Handle nil groups slice + if groups == nil { + groups = []string{} + } + + // Validate that at least userID or groups is provided + if userID == "" && len(groups) == 0 { + return nil, &ErrorMissingUserOrGroups + } + + // Return empty list if no permissions requested + if len(requestedPermissions) == 0 { + return []string{}, nil + } + + // Get authorized permissions from store + authorizedPermissions, err := rs.roleStore.GetAuthorizedPermissions(userID, groups, requestedPermissions) + if err != nil { + logger.Error("Failed to get authorized permissions", + log.String("userID", userID), + log.Int("groupCount", len(groups)), + log.Error(err)) + return nil, &ErrorInternalServerError + } + + logger.Debug("Retrieved authorized permissions", + log.String("userID", userID), + log.Int("groupCount", len(groups)), + log.Int("requestedCount", len(requestedPermissions)), + log.Int("authorizedCount", len(authorizedPermissions))) + + return authorizedPermissions, nil +} + +// validateCreateRoleRequest validates the create role request. +func (rs *roleService) validateCreateRoleRequest(role RoleCreationDetail) *serviceerror.ServiceError { + if role.Name == "" { + return &ErrorInvalidRequestFormat + } + + if role.OrganizationUnitID == "" { + return &ErrorInvalidRequestFormat + } + + if len(role.Assignments) > 0 { + if err := rs.validateAssignmentsRequest(role.Assignments); err != nil { + return err + } + } + + return nil +} + +// validateUpdateRoleRequest validates the update role request. +func (rs *roleService) validateUpdateRoleRequest(request RoleUpdateDetail) *serviceerror.ServiceError { + if request.Name == "" { + return &ErrorInvalidRequestFormat + } + + if request.OrganizationUnitID == "" { + return &ErrorInvalidRequestFormat + } + + return nil +} + +// validateAssignmentsRequest validates the assignments request. +func (rs *roleService) validateAssignmentsRequest(assignments []RoleAssignment) *serviceerror.ServiceError { + if len(assignments) == 0 { + return &ErrorEmptyAssignments + } + + for _, assignment := range assignments { + if assignment.Type != AssigneeTypeUser && assignment.Type != AssigneeTypeGroup { + return &ErrorInvalidRequestFormat + } + if assignment.ID == "" { + return &ErrorInvalidRequestFormat + } + } + + return nil +} + +// validateAssignmentIDs validates that all provided assignment IDs exist. +func (rs *roleService) validateAssignmentIDs(assignments []RoleAssignment) *serviceerror.ServiceError { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + + var userIDs []string + var groupIDs []string + + // Collect user and group IDs + for _, assignment := range assignments { + switch assignment.Type { + case AssigneeTypeUser: + userIDs = append(userIDs, assignment.ID) + case AssigneeTypeGroup: + groupIDs = append(groupIDs, assignment.ID) + } + } + + // Deduplicate IDs + userIDs = utils.UniqueStrings(userIDs) + groupIDs = utils.UniqueStrings(groupIDs) + + // Validate user IDs using user service + if len(userIDs) > 0 { + invalidUserIDs, svcErr := rs.userService.ValidateUserIDs(userIDs) + if svcErr != nil { + logger.Error("Failed to validate user IDs", log.String("error", svcErr.Error), + log.String("code", svcErr.Code)) + return &ErrorInternalServerError + } + + if len(invalidUserIDs) > 0 { + logger.Debug("Invalid user IDs found", log.Any("invalidUserIDs", invalidUserIDs)) + return &ErrorInvalidAssignmentID + } + } + + // Validate group IDs using group service + if len(groupIDs) > 0 { + if err := rs.groupService.ValidateGroupIDs(groupIDs); err != nil { + if err.Code == group.ErrorInvalidGroupMemberID.Code { + logger.Debug("Invalid group member IDs found") + return &ErrorInvalidAssignmentID + } + logger.Error("Failed to validate group IDs", log.String("error", err.Error)) + return &ErrorInternalServerError + } + } + + return nil +} + +// validatePaginationParams validates pagination parameters. +func validatePaginationParams(limit, offset int) *serviceerror.ServiceError { + if limit < 1 || limit > serverconst.MaxPageSize { + return &ErrorInvalidLimit + } + if offset < 0 { + return &ErrorInvalidOffset + } + return nil +} + +// buildPaginationLinks builds pagination links for the response. +func buildPaginationLinks(base string, limit, offset, totalCount int) []Link { + links := make([]Link, 0) + + if offset > 0 { + links = append(links, Link{ + Href: fmt.Sprintf("%s?offset=0&limit=%d", base, limit), + Rel: "first", + }) + + prevOffset := offset - limit + if prevOffset < 0 { + prevOffset = 0 + } + links = append(links, Link{ + Href: fmt.Sprintf("%s?offset=%d&limit=%d", base, prevOffset, limit), + Rel: "prev", + }) + } + + if offset+limit < totalCount { + nextOffset := offset + limit + links = append(links, Link{ + Href: fmt.Sprintf("%s?offset=%d&limit=%d", base, nextOffset, limit), + Rel: "next", + }) + } + + lastPageOffset := ((totalCount - 1) / limit) * limit + if offset < lastPageOffset { + links = append(links, Link{ + Href: fmt.Sprintf("%s?offset=%d&limit=%d", base, lastPageOffset, limit), + Rel: "last", + }) + } + + return links +} + +// getDisplayNameForAssignment retrieves the display name for a user or group assignment. +func (rs *roleService) getDisplayNameForAssignment(assignment *RoleAssignment) (string, error) { + switch assignment.Type { + case AssigneeTypeUser: + userResp, svcErr := rs.userService.GetUser(assignment.ID) + if svcErr != nil { + return "", fmt.Errorf("failed to get user: %w", errors.New(svcErr.Error)) + } + // Return user ID as display name (since User doesn't have a username field) + return userResp.ID, nil + + case AssigneeTypeGroup: + groupResp, svcErr := rs.groupService.GetGroup(assignment.ID) + if svcErr != nil { + return "", fmt.Errorf("failed to get group: %w", errors.New(svcErr.Error)) + } + // Return group name as display name + return groupResp.Name, nil + + default: + return "", fmt.Errorf("unknown assignment type: %s", assignment.Type) + } +} diff --git a/backend/internal/role/service_test.go b/backend/internal/role/service_test.go new file mode 100644 index 000000000..b1a1d7805 --- /dev/null +++ b/backend/internal/role/service_test.go @@ -0,0 +1,1209 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/asgardeo/thunder/internal/group" + oupkg "github.com/asgardeo/thunder/internal/ou" + serverconst "github.com/asgardeo/thunder/internal/system/constants" + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + "github.com/asgardeo/thunder/internal/user" + "github.com/asgardeo/thunder/tests/mocks/groupmock" + "github.com/asgardeo/thunder/tests/mocks/oumock" + "github.com/asgardeo/thunder/tests/mocks/usermock" +) + +const ( + testUserID1 = "user1" +) + +// Test Suite +type RoleServiceTestSuite struct { + suite.Suite + mockStore *roleStoreInterfaceMock + mockUserService *usermock.UserServiceInterfaceMock + mockGroupService *groupmock.GroupServiceInterfaceMock + mockOUService *oumock.OrganizationUnitServiceInterfaceMock + service RoleServiceInterface +} + +func TestRoleServiceTestSuite(t *testing.T) { + suite.Run(t, new(RoleServiceTestSuite)) +} + +func (suite *RoleServiceTestSuite) SetupTest() { + suite.mockStore = newRoleStoreInterfaceMock(suite.T()) + suite.mockUserService = usermock.NewUserServiceInterfaceMock(suite.T()) + suite.mockGroupService = groupmock.NewGroupServiceInterfaceMock(suite.T()) + suite.mockOUService = oumock.NewOrganizationUnitServiceInterfaceMock(suite.T()) + suite.service = newRoleService( + suite.mockStore, + suite.mockUserService, + suite.mockGroupService, + suite.mockOUService, + ) +} + +// GetRoleList Tests +func (suite *RoleServiceTestSuite) TestGetRoleList_Success() { + expectedRoles := []Role{ + {ID: "role1", Name: "Admin", OrganizationUnitID: "ou1"}, + {ID: "role2", Name: "User", OrganizationUnitID: "ou1"}, + } + + suite.mockStore.On("GetRoleListCount").Return(2, nil) + suite.mockStore.On("GetRoleList", 10, 0).Return(expectedRoles, nil) + + result, err := suite.service.GetRoleList(10, 0) + + suite.Nil(err) + suite.NotNil(result) + suite.Equal(2, result.TotalResults) + suite.Equal(2, result.Count) + suite.Equal(1, result.StartIndex) + suite.Equal(2, len(result.Roles)) + suite.Equal("role1", result.Roles[0].ID) + suite.Equal("Admin", result.Roles[0].Name) + suite.Equal("role2", result.Roles[1].ID) + suite.Equal("User", result.Roles[1].Name) +} + +func (suite *RoleServiceTestSuite) TestGetRoleList_InvalidPagination() { + testCases := []struct { + name string + limit int + offset int + errCode string + }{ + {"InvalidLimit_Zero", 0, 0, ErrorInvalidLimit.Code}, + {"InvalidLimit_TooLarge", serverconst.MaxPageSize + 1, 0, ErrorInvalidLimit.Code}, + {"InvalidOffset_Negative", 10, -1, ErrorInvalidOffset.Code}, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + result, err := suite.service.GetRoleList(tc.limit, tc.offset) + suite.Nil(result) + suite.NotNil(err) + suite.Equal(tc.errCode, err.Code) + }) + } +} + +func (suite *RoleServiceTestSuite) TestGetRoleList_StoreErrors() { + testCases := []struct { + name string + mockSetup func() + }{ + { + name: "CountError", + mockSetup: func() { + suite.mockStore.On("GetRoleListCount").Return(0, errors.New("database error")).Once() + }, + }, + { + name: "GetListError", + mockSetup: func() { + suite.mockStore.On("GetRoleListCount").Return(10, nil).Once() + suite.mockStore.On("GetRoleList", 10, 0).Return([]Role{}, errors.New("database error")).Once() + }, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + tc.mockSetup() + + result, err := suite.service.GetRoleList(10, 0) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) + }) + } +} + +// CreateRole Tests +func (suite *RoleServiceTestSuite) TestCreateRole_Success() { + request := RoleCreationDetail{ + Name: "Test Role", + Description: "Test Description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + Assignments: []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + + ou := oupkg.OrganizationUnit{ID: "ou1", Name: "Test OU"} + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExists", "ou1", "Test Role").Return(false, nil) + suite.mockUserService.On("ValidateUserIDs", []string{testUserID1}).Return([]string{}, nil) + suite.mockStore.On("CreateRole", mock.AnythingOfType("string"), mock.AnythingOfType("RoleCreationDetail")).Return(nil) + + result, err := suite.service.CreateRole(request) + + suite.Nil(err) + suite.NotNil(result) + suite.Equal("Test Role", result.Name) + suite.Equal("Test Description", result.Description) + suite.Equal("ou1", result.OrganizationUnitID) + suite.Equal(2, len(result.Permissions)) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_ValidationErrors() { + testCases := []struct { + name string + request RoleCreationDetail + errCode string + }{ + { + name: "MissingName", + request: RoleCreationDetail{OrganizationUnitID: "ou1", Permissions: []string{"perm1"}}, + errCode: ErrorInvalidRequestFormat.Code, + }, + { + name: "MissingOrgUnit", + request: RoleCreationDetail{Name: "Role", Permissions: []string{"perm1"}}, + errCode: ErrorInvalidRequestFormat.Code, + }, + { + name: "InvalidAssignmentType", + request: RoleCreationDetail{ + Name: "Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: testUserID1, Type: "invalid"}}, + }, + errCode: ErrorInvalidRequestFormat.Code, + }, + { + name: "EmptyAssignmentID", + request: RoleCreationDetail{ + Name: "Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: "", Type: AssigneeTypeUser}}, + }, + errCode: ErrorInvalidRequestFormat.Code, + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + result, err := suite.service.CreateRole(tc.request) + suite.Nil(result) + suite.NotNil(err) + suite.Equal(tc.errCode, err.Code) + }) + } +} + +func (suite *RoleServiceTestSuite) TestCreateRole_OrganizationUnitNotFound() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "nonexistent", + Permissions: []string{"perm1"}, + } + + suite.mockOUService.On("GetOrganizationUnit", "nonexistent"). + Return(oupkg.OrganizationUnit{}, &oupkg.ErrorOrganizationUnitNotFound) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorOrganizationUnitNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_InvalidUserID() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: "invalid_user", Type: AssigneeTypeUser}}, + } + + // Assignment validation now happens before OU and name checks + suite.mockUserService.On("ValidateUserIDs", []string{"invalid_user"}).Return([]string{"invalid_user"}, nil) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInvalidAssignmentID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_InvalidGroupID() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: "invalid_group", Type: AssigneeTypeGroup}}, + } + + // Assignment validation now happens before OU and name checks + suite.mockGroupService.On("ValidateGroupIDs", []string{"invalid_group"}).Return(&group.ErrorInvalidGroupMemberID) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInvalidAssignmentID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_StoreError() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExists", "ou1", "Test Role").Return(false, nil) + suite.mockStore.On("CreateRole", mock.AnythingOfType("string"), + mock.AnythingOfType("RoleCreationDetail")).Return(errors.New("database error")) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_NameConflict() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExists", "ou1", "Test Role").Return(true, nil) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorRoleNameConflict.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestCreateRole_CheckNameExistsError() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExists", "ou1", "Test Role").Return(false, errors.New("database error")) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +// GetRoleWithPermissions Tests +func (suite *RoleServiceTestSuite) TestGetRole_Success() { + expectedRole := RoleWithPermissions{ + ID: "role1", + Name: "Admin", + Description: "Administrator role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + } + + suite.mockStore.On("GetRole", "role1").Return(expectedRole, nil) + + result, err := suite.service.GetRoleWithPermissions("role1") + + suite.Nil(err) + suite.NotNil(result) + suite.Equal(expectedRole.ID, result.ID) + suite.Equal(expectedRole.Name, result.Name) +} + +func (suite *RoleServiceTestSuite) TestGetRole_MissingID() { + result, err := suite.service.GetRoleWithPermissions("") + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRole_NotFound() { + suite.mockStore.On("GetRole", "nonexistent").Return(RoleWithPermissions{}, ErrRoleNotFound) + + result, err := suite.service.GetRoleWithPermissions("nonexistent") + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorRoleNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRole_StoreError() { + suite.mockStore.On("GetRole", "role1").Return(RoleWithPermissions{}, errors.New("database error")) + + result, err := suite.service.GetRoleWithPermissions("role1") + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +// UpdateRole Tests +func (suite *RoleServiceTestSuite) TestUpdateRole_MissingRoleID() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + result, err := suite.service.UpdateRoleWithPermissions("", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_ValidationErrors() { + testCases := []struct { + name string + request RoleUpdateDetail + errCode string + }{ + { + name: "MissingName", + request: RoleUpdateDetail{OrganizationUnitID: "ou1", Permissions: []string{"perm1"}}, + errCode: ErrorInvalidRequestFormat.Code, + }, + { + name: "MissingOrgUnit", + request: RoleUpdateDetail{Name: "Role", Permissions: []string{"perm1"}}, + errCode: ErrorInvalidRequestFormat.Code, + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + result, err := suite.service.UpdateRoleWithPermissions("role1", tc.request) + suite.Nil(result) + suite.NotNil(err) + suite.Equal(tc.errCode, err.Code) + }) + } +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_GetRoleError() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(false, errors.New("database error")) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_OUNotFound() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "nonexistent_ou", + Permissions: []string{"perm1"}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "nonexistent_ou"). + Return(oupkg.OrganizationUnit{}, &oupkg.ErrorOrganizationUnitNotFound) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorOrganizationUnitNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_OUServiceError() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "ou1"). + Return(oupkg.OrganizationUnit{}, &serviceerror.ServiceError{Code: "INTERNAL_ERROR"}) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_UpdateStoreError() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExistsExcludingID", "ou1", "New Name", "role1").Return(false, nil) + suite.mockStore.On("UpdateRole", mock.AnythingOfType("string"), + mock.AnythingOfType("RoleUpdateDetail")).Return(errors.New("update error")) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_Success() { + request := RoleUpdateDetail{ + Name: "New Name", + Description: "Updated description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExistsExcludingID", "ou1", "New Name", "role1").Return(false, nil) + suite.mockStore.On("UpdateRole", mock.AnythingOfType("string"), mock.AnythingOfType("RoleUpdateDetail")).Return(nil) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(err) + suite.NotNil(result) + suite.Equal("New Name", result.Name) + suite.Equal("Updated description", result.Description) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_RoleNotFound() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockStore.On("IsRoleExist", "nonexistent").Return(false, nil) + + result, err := suite.service.UpdateRoleWithPermissions("nonexistent", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorRoleNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_NameConflict() { + request := RoleUpdateDetail{ + Name: "Conflicting Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExistsExcludingID", "ou1", "Conflicting Name", "role1").Return(true, nil) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorRoleNameConflict.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestUpdateRole_CheckNameExistsError() { + request := RoleUpdateDetail{ + Name: "New Name", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + ou := oupkg.OrganizationUnit{ID: "ou1"} + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockOUService.On("GetOrganizationUnit", "ou1").Return(ou, nil) + suite.mockStore.On("CheckRoleNameExistsExcludingID", "ou1", "New Name", "role1"). + Return(false, errors.New("database error")) + + result, err := suite.service.UpdateRoleWithPermissions("role1", request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +// DeleteRole Tests +func (suite *RoleServiceTestSuite) TestDeleteRole_Success() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(0, nil) + suite.mockStore.On("DeleteRole", "role1").Return(nil) + + err := suite.service.DeleteRole("role1") + + suite.Nil(err) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_WithAssignments() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(5, nil) + + err := suite.service.DeleteRole("role1") + + suite.NotNil(err) + suite.Equal(ErrorCannotDeleteRole.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_NotFound_ReturnsNil() { + suite.mockStore.On("IsRoleExist", "nonexistent").Return(false, nil) + + err := suite.service.DeleteRole("nonexistent") + + suite.Nil(err) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_MissingID() { + err := suite.service.DeleteRole("") + + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_GetRoleError() { + suite.mockStore.On("IsRoleExist", "role1").Return(false, errors.New("database error")) + + err := suite.service.DeleteRole("role1") + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_GetAssignmentsCountError() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(0, errors.New("database error")) + + err := suite.service.DeleteRole("role1") + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestDeleteRole_StoreError() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(0, nil) + suite.mockStore.On("DeleteRole", "role1").Return(errors.New("delete error")) + + err := suite.service.DeleteRole("role1") + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +// GetRoleAssignments Tests +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_Success() { + expectedAssignments := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: "group1", Type: AssigneeTypeGroup}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(2, nil) + suite.mockStore.On("GetRoleAssignments", "role1", 10, 0).Return(expectedAssignments, nil) + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, false) + + suite.Nil(err) + suite.NotNil(result) + suite.Equal(2, result.TotalResults) + suite.Equal(2, result.Count) + suite.Equal(2, len(result.Assignments)) + suite.Equal("user1", result.Assignments[0].ID) + suite.Equal(AssigneeTypeUser, result.Assignments[0].Type) + suite.Equal("group1", result.Assignments[1].ID) + suite.Equal(AssigneeTypeGroup, result.Assignments[1].Type) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_MissingID() { + result, err := suite.service.GetRoleAssignments("", 10, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_InvalidPagination() { + result, err := suite.service.GetRoleAssignments("role1", 0, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInvalidLimit.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_RoleNotFound() { + suite.mockStore.On("IsRoleExist", "nonexistent").Return(false, nil) + + result, err := suite.service.GetRoleAssignments("nonexistent", 10, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorRoleNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_GetRoleError() { + suite.mockStore.On("IsRoleExist", "role1").Return(false, errors.New("database error")) + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_CountError() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(0, errors.New("count error")) + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_GetListError() { + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(2, nil) + suite.mockStore.On("GetRoleAssignments", "role1", 10, 0).Return([]RoleAssignment{}, errors.New("list error")) + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, false) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_WithDisplay_Success() { + expectedAssignments := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: "group1", Type: AssigneeTypeGroup}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(2, nil) + suite.mockStore.On("GetRoleAssignments", "role1", 10, 0).Return(expectedAssignments, nil) + suite.mockUserService.On("GetUser", testUserID1).Return(&user.User{ID: testUserID1}, nil).Once() + suite.mockGroupService.On("GetGroup", "group1").Return(&group.Group{Name: "Test Group"}, nil).Once() + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, true) + + suite.Nil(err) + suite.NotNil(result) + suite.Equal(2, result.TotalResults) + suite.Equal(2, result.Count) + suite.Equal(testUserID1, result.Assignments[0].Display) + suite.Equal("Test Group", result.Assignments[1].Display) +} + +func (suite *RoleServiceTestSuite) TestGetRoleAssignments_WithDisplay_FetchErrors() { + testCases := []struct { + name string + assignment RoleAssignment + setupMock func() + expectedDisplay string + }{ + { + name: "User fetch error", + assignment: RoleAssignment{ID: testUserID1, Type: AssigneeTypeUser}, + setupMock: func() { + suite.mockUserService.On("GetUser", testUserID1). + Return(nil, &serviceerror.ServiceError{Code: "USER_NOT_FOUND"}).Once() + }, + expectedDisplay: "", + }, + { + name: "Group fetch error", + assignment: RoleAssignment{ID: "group1", Type: AssigneeTypeGroup}, + setupMock: func() { + suite.mockGroupService.On("GetGroup", "group1"). + Return(nil, &serviceerror.ServiceError{Code: "GROUP_NOT_FOUND"}).Once() + }, + expectedDisplay: "", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + expectedAssignments := []RoleAssignment{tc.assignment} + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil).Once() + suite.mockStore.On("GetRoleAssignmentsCount", "role1").Return(1, nil).Once() + suite.mockStore.On("GetRoleAssignments", "role1", 10, 0). + Return(expectedAssignments, nil).Once() + tc.setupMock() + + result, err := suite.service.GetRoleAssignments("role1", 10, 0, true) + + // Should succeed but with empty display name on error + suite.Nil(err) + suite.NotNil(result) + suite.Equal(1, result.TotalResults) + suite.Equal(1, result.Count) + suite.Equal(tc.expectedDisplay, result.Assignments[0].Display) + }) + } +} + +// AddAssignments Tests +func (suite *RoleServiceTestSuite) TestAddAssignments_MissingRoleID() { + request := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + } + + err := suite.service.AddAssignments("", request) + + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_EmptyAssignments() { + request := []RoleAssignment{} + + err := suite.service.AddAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorEmptyAssignments.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_InvalidAssignmentFormat() { + testCases := []struct { + name string + assignment RoleAssignment + }{ + { + name: "InvalidType", + assignment: RoleAssignment{ID: testUserID1, Type: "invalid_type"}, + }, + { + name: "EmptyID", + assignment: RoleAssignment{ID: "", Type: AssigneeTypeUser}, + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + request := []RoleAssignment{ + tc.assignment, + } + + err := suite.service.AddAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorInvalidRequestFormat.Code, err.Code) + }) + } +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_RoleNotFound() { + request := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "nonexistent").Return(false, nil) + + err := suite.service.AddAssignments("nonexistent", request) + + suite.NotNil(err) + suite.Equal(ErrorRoleNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_GetRoleError() { + request := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(false, errors.New("database error")) + + err := suite.service.AddAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_StoreError() { + request := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockUserService.On("ValidateUserIDs", []string{testUserID1}).Return([]string{}, nil) + suite.mockStore.On("AddAssignments", "role1", request).Return(errors.New("store error")) + + err := suite.service.AddAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestAddAssignments_Success() { + request := []RoleAssignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockUserService.On("ValidateUserIDs", []string{testUserID1}).Return([]string{}, nil) + suite.mockStore.On("AddAssignments", "role1", request).Return(nil) + + err := suite.service.AddAssignments("role1", request) + + suite.Nil(err) +} + +// RemoveAssignments Tests +func (suite *RoleServiceTestSuite) TestRemoveAssignments_MissingRoleID() { + request := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + err := suite.service.RemoveAssignments("", request) + + suite.NotNil(err) + suite.Equal(ErrorMissingRoleID.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestRemoveAssignments_EmptyAssignments() { + request := []RoleAssignment{} + + err := suite.service.RemoveAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorEmptyAssignments.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestRemoveAssignments_RoleNotFound() { + request := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "nonexistent").Return(false, nil) + + err := suite.service.RemoveAssignments("nonexistent", request) + + suite.NotNil(err) + suite.Equal(ErrorRoleNotFound.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestRemoveAssignments_GetRoleError() { + request := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(false, errors.New("database error")) + + err := suite.service.RemoveAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestRemoveAssignments_StoreError() { + request := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("RemoveAssignments", "role1", request).Return(errors.New("store error")) + + err := suite.service.RemoveAssignments("role1", request) + + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestRemoveAssignments_Success() { + request := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockStore.On("IsRoleExist", "role1").Return(true, nil) + suite.mockStore.On("RemoveAssignments", "role1", request).Return(nil) + + err := suite.service.RemoveAssignments("role1", request) + + suite.Nil(err) +} + +// validateAssignmentIDs Tests +func (suite *RoleServiceTestSuite) TestValidateAssignmentIDs_UserServiceError() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: "user1", Type: AssigneeTypeUser}}, + } + + // Assignment validation now happens before OU and name checks + suite.mockUserService.On("ValidateUserIDs", []string{"user1"}). + Return([]string{}, &serviceerror.ServiceError{Code: "INTERNAL_ERROR"}) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +func (suite *RoleServiceTestSuite) TestValidateAssignmentIDs_GroupServiceError() { + request := RoleCreationDetail{ + Name: "Test Role", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + Assignments: []RoleAssignment{{ID: "group1", Type: AssigneeTypeGroup}}, + } + + // Assignment validation now happens before OU and name checks + suite.mockGroupService.On("ValidateGroupIDs", []string{"group1"}). + Return(&serviceerror.ServiceError{Code: "INTERNAL_ERROR"}) + + result, err := suite.service.CreateRole(request) + + suite.Nil(result) + suite.NotNil(err) + suite.Equal(ErrorInternalServerError.Code, err.Code) +} + +// Utility functions tests +func (suite *RoleServiceTestSuite) TestBuildPaginationLinks() { + testCases := []struct { + name string + base string + limit int + offset int + totalCount int + expectFirst bool + expectPrev bool + expectNext bool + expectLast bool + }{ + { + name: "FirstPage", + base: "/roles", + limit: 10, + offset: 0, + totalCount: 30, + expectFirst: false, + expectPrev: false, + expectNext: true, + expectLast: true, + }, + { + name: "MiddlePage", + base: "/roles", + limit: 10, + offset: 10, + totalCount: 30, + expectFirst: true, + expectPrev: true, + expectNext: true, + expectLast: true, + }, + { + name: "LastPage", + base: "/roles", + limit: 10, + offset: 20, + totalCount: 30, + expectFirst: true, + expectPrev: true, + expectNext: false, + expectLast: false, + }, + { + name: "SinglePage", + base: "/roles", + limit: 10, + offset: 0, + totalCount: 5, + expectFirst: false, + expectPrev: false, + expectNext: false, + expectLast: false, + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + links := buildPaginationLinks(tc.base, tc.limit, tc.offset, tc.totalCount) + + hasFirst := false + hasPrev := false + hasNext := false + hasLast := false + + for _, link := range links { + switch link.Rel { + case "first": + hasFirst = true + case "prev": + hasPrev = true + case "next": + hasNext = true + case "last": + hasLast = true + } + } + + suite.Equal(tc.expectFirst, hasFirst, "first link mismatch") + suite.Equal(tc.expectPrev, hasPrev, "prev link mismatch") + suite.Equal(tc.expectNext, hasNext, "next link mismatch") + suite.Equal(tc.expectLast, hasLast, "last link mismatch") + }) + } +} + +// GetAuthorizedPermissions Tests - Consolidated for efficiency while maintaining coverage +func (suite *RoleServiceTestSuite) TestGetAuthorizedPermissions() { + testCases := []struct { + name string + userID string + groups []string + requestedPermissions []string + mockReturn []string + mockError error + expectedPermissions []string + expectedError *serviceerror.ServiceError + skipMock bool + }{ + { + name: "Success_UserAndGroups", + userID: testUserID1, + groups: []string{"group1", "group2"}, + requestedPermissions: []string{"perm1", "perm2", "perm3"}, + mockReturn: []string{"perm1", "perm3"}, + expectedPermissions: []string{"perm1", "perm3"}, + }, + { + name: "Success_UserOnly_NilGroupsNormalized", + userID: testUserID1, + groups: nil, // Tests both nil and empty groups normalization + requestedPermissions: []string{"perm1", "perm2"}, + mockReturn: []string{"perm1"}, + expectedPermissions: []string{"perm1"}, + }, + { + name: "Success_GroupsOnly", + userID: "", + groups: []string{"group1", "group2"}, + requestedPermissions: []string{"perm1", "perm2"}, + mockReturn: []string{"perm1"}, + expectedPermissions: []string{"perm1"}, + }, + { + name: "Success_NoAuthorizedPermissions", + userID: testUserID1, + groups: []string{"group1"}, + requestedPermissions: []string{"perm1", "perm2"}, + mockReturn: []string{}, // User has no permissions + expectedPermissions: []string{}, + }, + { + name: "Success_AllPermissionsAuthorized", + userID: testUserID1, + groups: []string{"group1"}, + requestedPermissions: []string{"perm1", "perm2"}, + mockReturn: []string{"perm1", "perm2"}, // All permissions authorized + expectedPermissions: []string{"perm1", "perm2"}, + }, + { + name: "EmptyAndNilRequestedPermissions_ReturnsEmpty", + userID: testUserID1, + groups: []string{"group1"}, + requestedPermissions: nil, // Also covers empty []string{} case + expectedPermissions: []string{}, + skipMock: true, // No store call for empty permissions + }, + { + name: "MissingUserAndGroups_Error", + userID: "", + groups: nil, // Covers both nil and empty cases + requestedPermissions: []string{"perm1", "perm2"}, + expectedError: &ErrorMissingUserOrGroups, + skipMock: true, + }, + { + name: "StoreError_ReturnsInternalError", + userID: testUserID1, + groups: []string{"group1"}, + requestedPermissions: []string{"perm1", "perm2"}, + mockError: errors.New("database error"), + expectedError: &ErrorInternalServerError, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + if !tc.skipMock { + normalizedGroups := tc.groups + if normalizedGroups == nil { + normalizedGroups = []string{} + } + suite.mockStore.On("GetAuthorizedPermissions", tc.userID, normalizedGroups, tc.requestedPermissions). + Return(tc.mockReturn, tc.mockError).Once() + } + + result, err := suite.service.GetAuthorizedPermissions(tc.userID, tc.groups, tc.requestedPermissions) + + if tc.expectedError != nil { + suite.NotNil(err) + suite.Equal(tc.expectedError.Code, err.Code) + suite.Nil(result) + } else { + suite.Nil(err) + suite.NotNil(result) + if len(tc.requestedPermissions) == 0 { + suite.Equal(0, len(result)) + } else { + suite.Equal(len(tc.expectedPermissions), len(result)) + suite.Equal(tc.expectedPermissions, result) + } + } + }) + } +} diff --git a/backend/internal/role/store.go b/backend/internal/role/store.go new file mode 100644 index 000000000..9f8620fab --- /dev/null +++ b/backend/internal/role/store.go @@ -0,0 +1,548 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "errors" + "fmt" + + dbmodel "github.com/asgardeo/thunder/internal/system/database/model" + "github.com/asgardeo/thunder/internal/system/database/provider" + "github.com/asgardeo/thunder/internal/system/log" +) + +const storeLoggerComponentName = "RoleStore" + +// roleStoreInterface defines the interface for role store operations. +type roleStoreInterface interface { + GetRoleListCount() (int, error) + GetRoleList(limit, offset int) ([]Role, error) + CreateRole(id string, role RoleCreationDetail) error + GetRole(id string) (RoleWithPermissions, error) + IsRoleExist(id string) (bool, error) + GetRoleAssignments(id string, limit, offset int) ([]RoleAssignment, error) + GetRoleAssignmentsCount(id string) (int, error) + UpdateRole(id string, role RoleUpdateDetail) error + DeleteRole(id string) error + AddAssignments(id string, assignments []RoleAssignment) error + RemoveAssignments(id string, assignments []RoleAssignment) error + CheckRoleNameExists(ouID, name string) (bool, error) + CheckRoleNameExistsExcludingID(ouID, name, excludeRoleID string) (bool, error) + GetAuthorizedPermissions(userID string, groupIDs []string, requestedPermissions []string) ([]string, error) +} + +// roleStore is the default implementation of roleStoreInterface. +type roleStore struct { + dbProvider provider.DBProviderInterface +} + +// newRoleStore creates a new instance of roleStore. +func newRoleStore() roleStoreInterface { + return &roleStore{ + dbProvider: provider.GetDBProvider(), + } +} + +// GetRoleListCount retrieves the total count of roles. +func (s *roleStore) GetRoleListCount() (int, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return 0, err + } + + countResults, err := dbClient.Query(queryGetRoleListCount) + if err != nil { + return 0, fmt.Errorf("failed to execute count query: %w", err) + } + + return parseCountResult(countResults) +} + +// GetRoleList retrieves roles with pagination. +func (s *roleStore) GetRoleList(limit, offset int) ([]Role, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return nil, err + } + + results, err := dbClient.Query(queryGetRoleList, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to execute role list query: %w", err) + } + + roles := make([]Role, 0) + for _, row := range results { + role, err := buildRoleBasicInfoFromResultRow(row) + if err != nil { + return nil, fmt.Errorf("failed to build role from result row: %w", err) + } + roles = append(roles, role) + } + + return roles, nil +} + +// CreateRole creates a new role in the database. +func (s *roleStore) CreateRole(id string, role RoleCreationDetail) error { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return err + } + + return s.executeInTransaction(dbClient, func(tx dbmodel.TxInterface) error { + _, err := tx.Exec( + queryCreateRole.Query, + id, + role.OrganizationUnitID, + role.Name, + role.Description, + ) + if err != nil { + return fmt.Errorf("failed to execute query: %w", err) + } + + if err := addPermissionsToRole(tx, id, role.Permissions); err != nil { + return err + } + + if err := addAssignmentsToRole(tx, id, role.Assignments); err != nil { + return err + } + + return nil + }) +} + +// GetRole retrieves a role by its id. +func (s *roleStore) GetRole(id string) (RoleWithPermissions, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return RoleWithPermissions{}, err + } + + results, err := dbClient.Query(queryGetRoleByID, id) + if err != nil { + return RoleWithPermissions{}, fmt.Errorf("failed to execute query: %w", err) + } + + if len(results) == 0 { + return RoleWithPermissions{}, ErrRoleNotFound + } + + if len(results) != 1 { + return RoleWithPermissions{}, fmt.Errorf("unexpected number of results: %d", len(results)) + } + + row := results[0] + roleBasicInfo, err := buildRoleBasicInfoFromResultRow(row) + if err != nil { + return RoleWithPermissions{}, err + } + + permissions, err := s.getRolePermissions(dbClient, id) + if err != nil { + return RoleWithPermissions{}, fmt.Errorf("failed to get role permissions: %w", err) + } + + return RoleWithPermissions{ + ID: roleBasicInfo.ID, + Name: roleBasicInfo.Name, + Description: roleBasicInfo.Description, + OrganizationUnitID: roleBasicInfo.OrganizationUnitID, + Permissions: permissions, + }, nil +} + +// IsRoleExist checks if a role exists by its ID without fetching its details. +func (s *roleStore) IsRoleExist(id string) (bool, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return false, err + } + + results, err := dbClient.Query(queryCheckRoleExists, id) + if err != nil { + return false, fmt.Errorf("failed to check role existence: %w", err) + } + + return parseBoolFromCount(results) +} + +// GetRoleAssignments retrieves assignments for a role with pagination. +func (s *roleStore) GetRoleAssignments(id string, limit, offset int) ([]RoleAssignment, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return nil, err + } + + results, err := dbClient.Query(queryGetRoleAssignments, id, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to get role assignments: %w", err) + } + + assignments := make([]RoleAssignment, 0) + for _, row := range results { + assigneeID, err := parseStringField(row, "assignee_id") + if err != nil { + return nil, err + } + assigneeType, err := parseStringField(row, "assignee_type") + if err != nil { + return nil, err + } + assignments = append(assignments, RoleAssignment{ + ID: assigneeID, + Type: AssigneeType(assigneeType), + }) + } + + return assignments, nil +} + +// GetRoleAssignmentsCount retrieves the total count of assignments for a role. +func (s *roleStore) GetRoleAssignmentsCount(id string) (int, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return 0, err + } + + countResults, err := dbClient.Query(queryGetRoleAssignmentsCount, id) + if err != nil { + return 0, fmt.Errorf("failed to get role assignments count: %w", err) + } + + return parseCountResult(countResults) +} + +// UpdateRole updates an existing role. +func (s *roleStore) UpdateRole(id string, role RoleUpdateDetail) error { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return err + } + + return s.executeInTransaction(dbClient, func(tx dbmodel.TxInterface) error { + result, err := tx.Exec( + queryUpdateRole.Query, + role.OrganizationUnitID, + role.Name, + role.Description, + id, + ) + if err != nil { + return fmt.Errorf("failed to execute query: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return ErrRoleNotFound + } + + if err := updateRolePermissions(tx, id, role.Permissions); err != nil { + return err + } + + return nil + }) +} + +// DeleteRole deletes a role. +func (s *roleStore) DeleteRole(id string) error { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, storeLoggerComponentName)) + + dbClient, err := s.getIdentityDBClient() + if err != nil { + return err + } + + rowsAffected, err := dbClient.Execute(queryDeleteRole, id) + if err != nil { + return fmt.Errorf("failed to execute query: %w", err) + } + + if rowsAffected == 0 { + logger.Debug("Role not found with id: " + id) + } + + return nil +} + +// AddAssignments adds assignments to a role. +func (s *roleStore) AddAssignments(id string, assignments []RoleAssignment) error { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return err + } + + return s.executeInTransaction(dbClient, func(tx dbmodel.TxInterface) error { + return addAssignmentsToRole(tx, id, assignments) + }) +} + +// RemoveAssignments removes assignments from a role. +func (s *roleStore) RemoveAssignments(id string, assignments []RoleAssignment) error { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return err + } + + return s.executeInTransaction(dbClient, func(tx dbmodel.TxInterface) error { + for _, assignment := range assignments { + _, err := tx.Exec(queryDeleteRoleAssignmentsByIDs.Query, id, assignment.Type, assignment.ID) + if err != nil { + return fmt.Errorf("failed to remove assignment from role: %w", err) + } + } + return nil + }) +} + +// getRolePermissions retrieves all permissions for a role. +func (s *roleStore) getRolePermissions(dbClient provider.DBClientInterface, id string) ([]string, error) { + results, err := dbClient.Query(queryGetRolePermissions, id) + if err != nil { + return nil, fmt.Errorf("failed to get role permissions: %w", err) + } + + permissions := make([]string, 0) + for _, row := range results { + permission, ok := row["permission"].(string) + if !ok { + return nil, fmt.Errorf("failed to parse permission as string") + } + permissions = append(permissions, permission) + } + + return permissions, nil +} + +// buildRoleSummaryFromResultRow constructs a Role from a database result row. +func buildRoleBasicInfoFromResultRow(row map[string]interface{}) (Role, error) { + fields, err := parseStringFields(row, "role_id", "name", "description", "ou_id") + if err != nil { + return Role{}, err + } + + return Role{ + ID: fields[0], + Name: fields[1], + Description: fields[2], + OrganizationUnitID: fields[3], + }, nil +} + +// addPermissionsToRole adds a list of permissions to a role. +func addPermissionsToRole( + tx dbmodel.TxInterface, + id string, + permissions []string, +) error { + for _, permission := range permissions { + _, err := tx.Exec(queryCreateRolePermission.Query, id, permission) + if err != nil { + return fmt.Errorf("failed to add permission to role: %w", err) + } + } + return nil +} + +// addAssignmentsToRole adds a list of assignments to a role. +func addAssignmentsToRole( + tx dbmodel.TxInterface, + id string, + assignments []RoleAssignment, +) error { + for _, assignment := range assignments { + _, err := tx.Exec(queryCreateRoleAssignment.Query, id, assignment.Type, assignment.ID) + if err != nil { + return fmt.Errorf("failed to add assignment to role: %w", err) + } + } + return nil +} + +// updateRolePermissions updates the permissions assigned to the role by first deleting existing permissions and +// then adding new ones. +func updateRolePermissions( + tx dbmodel.TxInterface, + id string, + permissions []string, +) error { + _, err := tx.Exec(queryDeleteRolePermissions.Query, id) + if err != nil { + return fmt.Errorf("failed to delete existing role permissions: %w", err) + } + + err = addPermissionsToRole(tx, id, permissions) + if err != nil { + return fmt.Errorf("failed to assign permissions to role: %w", err) + } + return nil +} + +// CheckRoleNameExists checks if a role with the given name exists in the specified organization unit. +func (s *roleStore) CheckRoleNameExists(ouID, name string) (bool, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return false, err + } + + results, err := dbClient.Query(queryCheckRoleNameExists, ouID, name) + if err != nil { + return false, fmt.Errorf("failed to check role name existence: %w", err) + } + + return parseBoolFromCount(results) +} + +// CheckRoleNameExistsExcludingID checks if a role with the given name exists in the specified organization unit, +// excluding the role with the given ID. +func (s *roleStore) CheckRoleNameExistsExcludingID(ouID, name, excludeRoleID string) (bool, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return false, err + } + + results, err := dbClient.Query(queryCheckRoleNameExistsExcludingID, ouID, name, excludeRoleID) + if err != nil { + return false, fmt.Errorf("failed to check role name existence: %w", err) + } + + return parseBoolFromCount(results) +} + +// GetAuthorizedPermissions retrieves the permissions that a user is authorized for based on their +// direct role assignments and group memberships. +func (s *roleStore) GetAuthorizedPermissions( + userID string, + groupIDs []string, + requestedPermissions []string, +) ([]string, error) { + dbClient, err := s.getIdentityDBClient() + if err != nil { + return nil, err + } + + // Handle nil groupIDs slice + if groupIDs == nil { + groupIDs = []string{} + } + + // Build dynamic query based on provided parameters + query, args := buildAuthorizedPermissionsQuery(userID, groupIDs, requestedPermissions) + + results, err := dbClient.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("failed to get authorized permissions: %w", err) + } + + permissions := make([]string, 0) + for _, row := range results { + if permission, ok := row["permission"].(string); ok { + permissions = append(permissions, permission) + } + } + + return permissions, nil +} + +// getIdentityDBClient is a helper method to get the database client for the identity database. +func (s *roleStore) getIdentityDBClient() (provider.DBClientInterface, error) { + dbClient, err := s.dbProvider.GetDBClient("identity") + if err != nil { + return nil, fmt.Errorf("failed to get database client: %w", err) + } + return dbClient, nil +} + +// parseCountResult parses a count result from a database query result. +func parseCountResult(results []map[string]interface{}) (int, error) { + if len(results) == 0 { + return 0, nil + } + + if countVal, ok := results[0]["total"].(int64); ok { + return int(countVal), nil + } + return 0, fmt.Errorf("failed to parse total from query result") +} + +// parseBoolFromCount parses a count result and returns true if count > 0. +func parseBoolFromCount(results []map[string]interface{}) (bool, error) { + if len(results) == 0 { + return false, nil + } + + if countVal, ok := results[0]["count"].(int64); ok { + return countVal > 0, nil + } + return false, fmt.Errorf("failed to parse count from query result") +} + +// parseStringField extracts a string field from a database result row. +func parseStringField(row map[string]interface{}, fieldName string) (string, error) { + value, ok := row[fieldName].(string) + if !ok { + return "", fmt.Errorf("failed to parse %s as string", fieldName) + } + return value, nil +} + +// parseStringFields extracts multiple string fields from a database result row. +func parseStringFields(row map[string]interface{}, fieldNames ...string) ([]string, error) { + result := make([]string, len(fieldNames)) + for i, fieldName := range fieldNames { + value, err := parseStringField(row, fieldName) + if err != nil { + return nil, err + } + result[i] = value + } + return result, nil +} + +// executeInTransaction executes a function within a database transaction. +// It automatically handles transaction begin, commit, and rollback on error. +func (s *roleStore) executeInTransaction( + dbClient provider.DBClientInterface, + fn func(tx dbmodel.TxInterface) error, +) error { + tx, err := dbClient.BeginTx() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + err = fn(tx) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return errors.Join(err, fmt.Errorf("failed to rollback transaction: %w", rollbackErr)) + } + return err + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} diff --git a/backend/internal/role/store_test.go b/backend/internal/role/store_test.go new file mode 100644 index 000000000..9574d8918 --- /dev/null +++ b/backend/internal/role/store_test.go @@ -0,0 +1,678 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + dbmodel "github.com/asgardeo/thunder/internal/system/database/model" + "github.com/asgardeo/thunder/tests/mocks/database/clientmock" + "github.com/asgardeo/thunder/tests/mocks/database/modelmock" + "github.com/asgardeo/thunder/tests/mocks/database/providermock" +) + +// mockResult is a simple mock implementation of sql.Result. +type mockResult struct { + lastInsertID int64 + rowsAffected int64 +} + +func (m *mockResult) LastInsertId() (int64, error) { + return m.lastInsertID, nil +} + +func (m *mockResult) RowsAffected() (int64, error) { + return m.rowsAffected, nil +} + +var _ sql.Result = (*mockResult)(nil) + +// RoleStoreTestSuite is the test suite for roleStore. +type RoleStoreTestSuite struct { + suite.Suite + mockDBProvider *providermock.DBProviderInterfaceMock + mockDBClient *clientmock.DBClientInterfaceMock + mockTx *modelmock.TxInterfaceMock + store *roleStore +} + +// TestRoleStoreTestSuite runs the test suite. +func TestRoleStoreTestSuite(t *testing.T) { + suite.Run(t, new(RoleStoreTestSuite)) +} + +// SetupTest sets up the test suite. +func (suite *RoleStoreTestSuite) SetupTest() { + suite.mockDBProvider = providermock.NewDBProviderInterfaceMock(suite.T()) + suite.mockDBClient = clientmock.NewDBClientInterfaceMock(suite.T()) + suite.mockTx = modelmock.NewTxInterfaceMock(suite.T()) + suite.store = &roleStore{ + dbProvider: suite.mockDBProvider, + } +} + +func (suite *RoleStoreTestSuite) TestGetRoleListCount_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleListCount).Return([]map[string]interface{}{ + {"total": int64(10)}, + }, nil) + + count, err := suite.store.GetRoleListCount() + + suite.NoError(err) + suite.Equal(10, count) +} + +func (suite *RoleStoreTestSuite) TestGetRoleListCount_QueryError() { + queryError := errors.New("query error") + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleListCount).Return(nil, queryError) + + count, err := suite.store.GetRoleListCount() + + suite.Error(err) + suite.Equal(0, count) + suite.Contains(err.Error(), "failed to execute count query") +} + +func (suite *RoleStoreTestSuite) TestGetRoleList_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleList, 10, 0).Return([]map[string]interface{}{ + {"role_id": "role1", "name": "Admin", "description": "Admin role", "ou_id": "ou1"}, + {"role_id": "role2", "name": "User", "description": "User role", "ou_id": "ou1"}, + }, nil) + + roles, err := suite.store.GetRoleList(10, 0) + + suite.NoError(err) + suite.Len(roles, 2) + suite.Equal("role1", roles[0].ID) + suite.Equal("Admin", roles[0].Name) +} + +func (suite *RoleStoreTestSuite) TestGetRoleList_QueryError() { + queryError := errors.New("query error") + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleList, 10, 0).Return(nil, queryError) + + roles, err := suite.store.GetRoleList(10, 0) + + suite.Error(err) + suite.Nil(roles) +} + +func (suite *RoleStoreTestSuite) TestGetRoleList_InvalidRowData() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleList, 10, 0).Return([]map[string]interface{}{ + {"role_id": 123, "name": "Admin", "description": "Admin role", "ou_id": "ou1"}, // Invalid role_id type + }, nil) + + roles, err := suite.store.GetRoleList(10, 0) + + suite.Error(err) + suite.Nil(roles) + suite.Contains(err.Error(), "failed to build role from result row") +} + +func (suite *RoleStoreTestSuite) TestCreateRole_Success() { + roleDetail := RoleCreationDetail{ + Name: "Test Role", + Description: "Test Description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1", "perm2"}, + Assignments: []RoleAssignment{{ID: "user1", Type: AssigneeTypeUser}}, + } + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryCreateRole.Query, mock.Anything, "ou1", "Test Role", "Test Description"). + Return(&mockResult{}, nil) + suite.mockTx.On("Exec", queryCreateRolePermission.Query, mock.Anything, "perm1").Return(&mockResult{}, nil) + suite.mockTx.On("Exec", queryCreateRolePermission.Query, mock.Anything, "perm2").Return(&mockResult{}, nil) + suite.mockTx.On("Exec", queryCreateRoleAssignment.Query, mock.Anything, AssigneeTypeUser, "user1"). + Return(&mockResult{}, nil) + suite.mockTx.On("Commit").Return(nil) + + err := suite.store.CreateRole("role1", roleDetail) + + suite.NoError(err) +} + +func (suite *RoleStoreTestSuite) TestCreateRole_ExecError() { + roleDetail := RoleCreationDetail{ + Name: "Test Role", + Description: "Test Description", + OrganizationUnitID: "ou1", + Permissions: []string{}, + Assignments: []RoleAssignment{}, + } + + execError := errors.New("insert failed") + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryCreateRole.Query, mock.Anything, "ou1", "Test Role", "Test Description"). + Return(nil, execError) + suite.mockTx.On("Rollback").Return(nil) + + err := suite.store.CreateRole("role1", roleDetail) + + suite.Error(err) + suite.Contains(err.Error(), "failed to execute query") +} + +func (suite *RoleStoreTestSuite) TestGetRole_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleByID, "role1").Return([]map[string]interface{}{ + {"role_id": "role1", "name": "Admin", "description": "Admin role", "ou_id": "ou1"}, + }, nil) + suite.mockDBClient.On("Query", queryGetRolePermissions, "role1").Return([]map[string]interface{}{ + {"permission": "perm1"}, + {"permission": "perm2"}, + }, nil) + + role, err := suite.store.GetRole("role1") + + suite.NoError(err) + suite.Equal("role1", role.ID) + suite.Equal("Admin", role.Name) + suite.Len(role.Permissions, 2) +} + +func (suite *RoleStoreTestSuite) TestGetRole_NotFound() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleByID, "nonexistent").Return([]map[string]interface{}{}, nil) + + role, err := suite.store.GetRole("nonexistent") + + suite.Error(err) + suite.Equal(ErrRoleNotFound, err) + suite.Empty(role.ID) +} + +func (suite *RoleStoreTestSuite) TestGetRole_MultipleResults() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleByID, "role1").Return([]map[string]interface{}{ + {"role_id": "role1", "name": "Admin", "description": "Admin role", "ou_id": "ou1"}, + {"role_id": "role1", "name": "Admin", "description": "Admin role", "ou_id": "ou1"}, + }, nil) + + role, err := suite.store.GetRole("role1") + + suite.Error(err) + suite.Contains(err.Error(), "unexpected number of results") + suite.Empty(role.ID) +} + +func (suite *RoleStoreTestSuite) TestIsRoleExist_Exists() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryCheckRoleExists, "role1").Return([]map[string]interface{}{ + {"count": int64(1)}, + }, nil) + + exists, err := suite.store.IsRoleExist("role1") + + suite.NoError(err) + suite.True(exists) +} + +func (suite *RoleStoreTestSuite) TestIsRoleExist_DoesNotExist() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryCheckRoleExists, "nonexistent").Return([]map[string]interface{}{ + {"count": int64(0)}, + }, nil) + + exists, err := suite.store.IsRoleExist("nonexistent") + + suite.NoError(err) + suite.False(exists) +} + +func (suite *RoleStoreTestSuite) TestGetRoleAssignments_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleAssignments, "role1", 10, 0).Return([]map[string]interface{}{ + {"assignee_id": "user1", "assignee_type": "user"}, + {"assignee_id": "group1", "assignee_type": "group"}, + }, nil) + + assignments, err := suite.store.GetRoleAssignments("role1", 10, 0) + + suite.NoError(err) + suite.Len(assignments, 2) + suite.Equal("user1", assignments[0].ID) + suite.Equal(AssigneeTypeUser, assignments[0].Type) +} + +func (suite *RoleStoreTestSuite) TestGetRoleAssignmentsCount_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetRoleAssignmentsCount, "role1").Return([]map[string]interface{}{ + {"total": int64(5)}, + }, nil) + + count, err := suite.store.GetRoleAssignmentsCount("role1") + + suite.NoError(err) + suite.Equal(5, count) +} + +func (suite *RoleStoreTestSuite) TestUpdateRole_Success() { + roleDetail := RoleUpdateDetail{ + Name: "Updated Role", + Description: "Updated Description", + OrganizationUnitID: "ou1", + Permissions: []string{"perm1"}, + } + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryUpdateRole.Query, "ou1", "Updated Role", "Updated Description", "role1"). + Return(&mockResult{rowsAffected: 1}, nil) + suite.mockTx.On("Exec", queryDeleteRolePermissions.Query, "role1").Return(&mockResult{}, nil) + suite.mockTx.On("Exec", queryCreateRolePermission.Query, "role1", "perm1").Return(&mockResult{}, nil) + suite.mockTx.On("Commit").Return(nil) + + err := suite.store.UpdateRole("role1", roleDetail) + + suite.NoError(err) +} + +func (suite *RoleStoreTestSuite) TestUpdateRole_NotFound() { + roleDetail := RoleUpdateDetail{ + Name: "Updated Role", + Description: "Updated Description", + OrganizationUnitID: "ou1", + Permissions: []string{}, + } + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryUpdateRole.Query, "ou1", "Updated Role", "Updated Description", "nonexistent"). + Return(&mockResult{rowsAffected: 0}, nil) + suite.mockTx.On("Rollback").Return(nil) + + err := suite.store.UpdateRole("nonexistent", roleDetail) + + suite.Error(err) + suite.Equal(ErrRoleNotFound, err) +} + +func (suite *RoleStoreTestSuite) TestDeleteRole_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Execute", queryDeleteRole, "role1").Return(int64(1), nil) + + err := suite.store.DeleteRole("role1") + + suite.NoError(err) +} + +func (suite *RoleStoreTestSuite) TestDeleteRole_ExecuteError() { + execError := errors.New("delete failed") + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Execute", queryDeleteRole, "role1").Return(int64(0), execError) + + err := suite.store.DeleteRole("role1") + + suite.Error(err) +} + +func (suite *RoleStoreTestSuite) TestAddAssignments_Success() { + assignments := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryCreateRoleAssignment.Query, "role1", AssigneeTypeUser, "user1"). + Return(&mockResult{}, nil) + suite.mockTx.On("Commit").Return(nil) + + err := suite.store.AddAssignments("role1", assignments) + + suite.NoError(err) +} + +func (suite *RoleStoreTestSuite) TestRemoveAssignments_Success() { + assignments := []RoleAssignment{ + {ID: "user1", Type: AssigneeTypeUser}, + } + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Exec", queryDeleteRoleAssignmentsByIDs.Query, "role1", AssigneeTypeUser, "user1"). + Return(&mockResult{}, nil) + suite.mockTx.On("Commit").Return(nil) + + err := suite.store.RemoveAssignments("role1", assignments) + + suite.NoError(err) +} + +func (suite *RoleStoreTestSuite) TestCheckRoleNameExists_Exists() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryCheckRoleNameExists, "ou1", "Admin").Return([]map[string]interface{}{ + {"count": int64(1)}, + }, nil) + + exists, err := suite.store.CheckRoleNameExists("ou1", "Admin") + + suite.NoError(err) + suite.True(exists) +} + +func (suite *RoleStoreTestSuite) TestCheckRoleNameExistsExcludingID_DoesNotExist() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryCheckRoleNameExistsExcludingID, "ou1", "Admin", "role1"). + Return([]map[string]interface{}{ + {"count": int64(0)}, + }, nil) + + exists, err := suite.store.CheckRoleNameExistsExcludingID("ou1", "Admin", "role1") + + suite.NoError(err) + suite.False(exists) +} + +func (suite *RoleStoreTestSuite) TestGetAuthorizedPermissions_Success() { + userID := "user1" + groupIDs := []string{"group1"} + requestedPermissions := []string{"perm1", "perm2"} + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + []map[string]interface{}{ + {"permission": "perm1"}, + }, nil) + + permissions, err := suite.store.GetAuthorizedPermissions(userID, groupIDs, requestedPermissions) + + suite.NoError(err) + suite.Len(permissions, 1) + suite.Equal("perm1", permissions[0]) +} + +func (suite *RoleStoreTestSuite) TestGetAuthorizedPermissions_NilGroupsHandled() { + userID := "user1" + requestedPermissions := []string{"perm1"} + + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return([]map[string]interface{}{ + {"permission": "perm1"}, + }, nil) + + permissions, err := suite.store.GetAuthorizedPermissions(userID, nil, requestedPermissions) + + suite.NoError(err) + suite.Len(permissions, 1) +} + +// Test buildRoleBasicInfoFromResultRow + +func (suite *RoleStoreTestSuite) TestBuildRoleBasicInfoFromResultRow_Success() { + row := map[string]interface{}{ + "role_id": "role1", + "name": "Admin", + "description": "Admin role", + "ou_id": "ou1", + } + + role, err := buildRoleBasicInfoFromResultRow(row) + + suite.NoError(err) + suite.Equal("role1", role.ID) + suite.Equal("Admin", role.Name) + suite.Equal("Admin role", role.Description) + suite.Equal("ou1", role.OrganizationUnitID) +} + +func (suite *RoleStoreTestSuite) TestBuildRoleBasicInfoFromResultRow_InvalidData() { + row := map[string]interface{}{ + "role_id": 123, // Invalid type + "name": "Admin", + "description": "Admin role", + "ou_id": "ou1", + } + + role, err := buildRoleBasicInfoFromResultRow(row) + + suite.Error(err) + suite.Empty(role.ID) +} + +// Test Helper Functions + +func (suite *RoleStoreTestSuite) TestGetIdentityDBClient_Success() { + suite.mockDBProvider.On("GetDBClient", "identity").Return(suite.mockDBClient, nil) + + client, err := suite.store.getIdentityDBClient() + + suite.NoError(err) + suite.NotNil(client) + suite.Equal(suite.mockDBClient, client) +} + +func (suite *RoleStoreTestSuite) TestGetIdentityDBClient_Error() { + dbError := errors.New("database connection error") + suite.mockDBProvider.On("GetDBClient", "identity").Return(nil, dbError) + + client, err := suite.store.getIdentityDBClient() + + suite.Error(err) + suite.Nil(client) + suite.Contains(err.Error(), "failed to get database client") +} + +func (suite *RoleStoreTestSuite) TestParseCountResult_Success() { + results := []map[string]interface{}{ + {"total": int64(42)}, + } + + count, err := parseCountResult(results) + + suite.NoError(err) + suite.Equal(42, count) +} + +func (suite *RoleStoreTestSuite) TestParseCountResult_EmptyResults() { + results := []map[string]interface{}{} + + count, err := parseCountResult(results) + + suite.NoError(err) + suite.Equal(0, count) +} + +func (suite *RoleStoreTestSuite) TestParseCountResult_TypeAssertionError() { + results := []map[string]interface{}{ + {"total": "not_a_number"}, + } + + count, err := parseCountResult(results) + + suite.Error(err) + suite.Equal(0, count) + suite.Contains(err.Error(), "failed to parse total") +} + +func (suite *RoleStoreTestSuite) TestParseBoolFromCount_True() { + results := []map[string]interface{}{ + {"count": int64(5)}, + } + + exists, err := parseBoolFromCount(results) + + suite.NoError(err) + suite.True(exists) +} + +func (suite *RoleStoreTestSuite) TestParseBoolFromCount_False() { + results := []map[string]interface{}{ + {"count": int64(0)}, + } + + exists, err := parseBoolFromCount(results) + + suite.NoError(err) + suite.False(exists) +} + +func (suite *RoleStoreTestSuite) TestParseBoolFromCount_EmptyResults() { + results := []map[string]interface{}{} + + exists, err := parseBoolFromCount(results) + + suite.NoError(err) + suite.False(exists) +} + +func (suite *RoleStoreTestSuite) TestParseBoolFromCount_TypeError() { + results := []map[string]interface{}{ + {"count": "invalid"}, + } + + exists, err := parseBoolFromCount(results) + + suite.Error(err) + suite.False(exists) +} + +func (suite *RoleStoreTestSuite) TestParseStringField_Success() { + row := map[string]interface{}{ + "name": "test_value", + } + + value, err := parseStringField(row, "name") + + suite.NoError(err) + suite.Equal("test_value", value) +} + +func (suite *RoleStoreTestSuite) TestParseStringField_TypeError() { + row := map[string]interface{}{ + "name": 123, + } + + value, err := parseStringField(row, "name") + + suite.Error(err) + suite.Empty(value) + suite.Contains(err.Error(), "failed to parse name") +} + +func (suite *RoleStoreTestSuite) TestParseStringFields_Success() { + row := map[string]interface{}{ + "role_id": "role1", + "name": "Admin", + "description": "Admin role", + "ou_id": "ou1", + } + + values, err := parseStringFields(row, "role_id", "name", "description", "ou_id") + + suite.NoError(err) + suite.Len(values, 4) + suite.Equal("role1", values[0]) + suite.Equal("Admin", values[1]) + suite.Equal("Admin role", values[2]) + suite.Equal("ou1", values[3]) +} + +func (suite *RoleStoreTestSuite) TestParseStringFields_PartialError() { + row := map[string]interface{}{ + "role_id": "role1", + "name": 123, // Invalid type + } + + values, err := parseStringFields(row, "role_id", "name") + + suite.Error(err) + suite.Nil(values) + suite.Contains(err.Error(), "failed to parse name") +} + +func (suite *RoleStoreTestSuite) TestExecuteInTransaction_Success() { + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Commit").Return(nil) + + operationCalled := false + err := suite.store.executeInTransaction(suite.mockDBClient, func(tx dbmodel.TxInterface) error { + operationCalled = true + return nil + }) + + suite.NoError(err) + suite.True(operationCalled) +} + +func (suite *RoleStoreTestSuite) TestExecuteInTransaction_BeginError() { + beginError := errors.New("begin transaction failed") + suite.mockDBClient.On("BeginTx").Return(nil, beginError) + + err := suite.store.executeInTransaction(suite.mockDBClient, func(tx dbmodel.TxInterface) error { + suite.Fail("Operation should not be called") + return nil + }) + + suite.Error(err) + suite.Contains(err.Error(), "failed to begin transaction") +} + +func (suite *RoleStoreTestSuite) TestExecuteInTransaction_OperationError() { + operationError := errors.New("operation failed") + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Rollback").Return(nil) + + err := suite.store.executeInTransaction(suite.mockDBClient, func(tx dbmodel.TxInterface) error { + return operationError + }) + + suite.Error(err) + suite.Equal(operationError, err) +} + +func (suite *RoleStoreTestSuite) TestExecuteInTransaction_RollbackError() { + operationError := errors.New("operation failed") + rollbackError := errors.New("rollback failed") + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Rollback").Return(rollbackError) + + err := suite.store.executeInTransaction(suite.mockDBClient, func(tx dbmodel.TxInterface) error { + return operationError + }) + + suite.Error(err) + suite.Contains(err.Error(), "operation failed") + suite.Contains(err.Error(), "rollback") +} + +func (suite *RoleStoreTestSuite) TestExecuteInTransaction_CommitError() { + commitError := errors.New("commit failed") + suite.mockDBClient.On("BeginTx").Return(suite.mockTx, nil) + suite.mockTx.On("Commit").Return(commitError) + + err := suite.store.executeInTransaction(suite.mockDBClient, func(tx dbmodel.TxInterface) error { + return nil + }) + + suite.Error(err) + suite.Contains(err.Error(), "failed to commit transaction") +} diff --git a/backend/internal/role/storeconstants.go b/backend/internal/role/storeconstants.go new file mode 100644 index 000000000..7628e35ed --- /dev/null +++ b/backend/internal/role/storeconstants.go @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "fmt" + "strings" + + dbmodel "github.com/asgardeo/thunder/internal/system/database/model" +) + +var ( + // The table name "ROLE" is quoted to handle reserved keywords in SQL. + // Hence, all queries involving the "ROLE" table use quoted identifiers. + // queryCreateRole creates a new role. + queryCreateRole = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-01", + Query: `INSERT INTO "ROLE" (ROLE_ID, OU_ID, NAME, DESCRIPTION) VALUES ($1, $2, $3, $4)`, + } + + // queryGetRoleByID retrieves a role by ID. + queryGetRoleByID = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-02", + Query: `SELECT ROLE_ID, OU_ID, NAME, DESCRIPTION FROM "ROLE" WHERE ROLE_ID = $1`, + } + + // queryGetRoleList retrieves a list of roles with pagination. + queryGetRoleList = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-03", + Query: `SELECT ROLE_ID, OU_ID, NAME, DESCRIPTION FROM "ROLE" ORDER BY CREATED_AT DESC LIMIT $1 OFFSET $2`, + } + + // queryGetRoleListCount retrieves the total count of roles. + queryGetRoleListCount = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-04", + Query: `SELECT COUNT(*) as total FROM "ROLE"`, + } + + // queryUpdateRole updates a role. + queryUpdateRole = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-05", + Query: `UPDATE "ROLE" SET OU_ID = $1, NAME = $2, DESCRIPTION = $3 WHERE ROLE_ID = $4`, + } + + // queryDeleteRole deletes a role. + queryDeleteRole = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-06", + Query: `DELETE FROM "ROLE" WHERE ROLE_ID = $1`, + } + + // queryCreateRolePermission creates a new role permission. + queryCreateRolePermission = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-07", + Query: `INSERT INTO ROLE_PERMISSION (ROLE_ID, PERMISSION) VALUES ($1, $2)`, + } + + // queryGetRolePermissions retrieves all permissions for a role. + queryGetRolePermissions = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-08", + Query: `SELECT PERMISSION FROM ROLE_PERMISSION WHERE ROLE_ID = $1 ORDER BY CREATED_AT`, + } + + // queryDeleteRolePermissions deletes all permissions for a role. + queryDeleteRolePermissions = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-09", + Query: `DELETE FROM ROLE_PERMISSION WHERE ROLE_ID = $1`, + } + + // queryCreateRoleAssignment creates a new role assignment. + queryCreateRoleAssignment = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-10", + Query: `INSERT INTO ROLE_ASSIGNMENT (ROLE_ID, ASSIGNEE_TYPE, ASSIGNEE_ID) VALUES ($1, $2, $3)`, + } + + // queryGetRoleAssignments retrieves all assignments for a role with pagination. + queryGetRoleAssignments = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-11", + Query: `SELECT ASSIGNEE_ID, ASSIGNEE_TYPE FROM ROLE_ASSIGNMENT + WHERE ROLE_ID = $1 ORDER BY CREATED_AT LIMIT $2 OFFSET $3`, + } + + // queryGetRoleAssignmentsCount retrieves the total count of assignments for a role. + queryGetRoleAssignmentsCount = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-12", + Query: `SELECT COUNT(*) as total FROM ROLE_ASSIGNMENT WHERE ROLE_ID = $1`, + } + + // queryDeleteRoleAssignmentsByIDs deletes specific assignments for a role. + queryDeleteRoleAssignmentsByIDs = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-13", + Query: `DELETE FROM ROLE_ASSIGNMENT WHERE ROLE_ID = $1 AND ASSIGNEE_TYPE = $2 AND ASSIGNEE_ID = $3`, + } + + // queryCheckRoleNameExists checks if a role name already exists for a given organization unit. + queryCheckRoleNameExists = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-14", + Query: `SELECT COUNT(*) as count FROM "ROLE" WHERE OU_ID = $1 AND NAME = $2`, + } + + // queryCheckRoleNameExistsExcludingID checks if a role name exists for an OU excluding a specific role ID. + queryCheckRoleNameExistsExcludingID = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-15", + Query: `SELECT COUNT(*) as count FROM "ROLE" WHERE OU_ID = $1 AND NAME = $2 AND ROLE_ID != $3`, + } + + // queryCheckRoleExists checks if a role exists by its ID. + queryCheckRoleExists = dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-16", + Query: `SELECT COUNT(*) as count FROM "ROLE" WHERE ROLE_ID = $1`, + } +) + +// buildAuthorizedPermissionsQuery constructs a database-specific query to retrieve authorized permissions +// for a user and/or groups from their assigned roles. +// It builds separate queries for PostgreSQL and SQLite to handle array parameters correctly. +func buildAuthorizedPermissionsQuery( + userID string, + groupIDs []string, + requestedPermissions []string, +) (dbmodel.DBQuery, []interface{}) { + // Base query structure + baseQuery := `SELECT DISTINCT rp.PERMISSION + FROM ROLE_PERMISSION rp + INNER JOIN ROLE_ASSIGNMENT ra ON rp.ROLE_ID = ra.ROLE_ID + WHERE ` + + var postgresWhere []string + var sqliteWhere []string + + // Pre-allocate args slice with estimated capacity + argsCapacity := len(groupIDs) + len(requestedPermissions) + if userID != "" { + argsCapacity++ + } + args := make([]interface{}, 0, argsCapacity) + paramIndex := 1 + + // Build user condition if userID is provided + if userID != "" { + postgresWhere = append(postgresWhere, + fmt.Sprintf("(ra.ASSIGNEE_TYPE = 'user' AND ra.ASSIGNEE_ID = $%d)", paramIndex)) + sqliteWhere = append(sqliteWhere, + "(ra.ASSIGNEE_TYPE = 'user' AND ra.ASSIGNEE_ID = ?)") + args = append(args, userID) + paramIndex++ + } + + // Build group condition if groupIDs are provided + if len(groupIDs) > 0 { + groupPlaceholdersPostgres := make([]string, len(groupIDs)) + groupPlaceholdersSqlite := make([]string, len(groupIDs)) + + for i, groupID := range groupIDs { + groupPlaceholdersPostgres[i] = fmt.Sprintf("$%d", paramIndex+i) + groupPlaceholdersSqlite[i] = "?" + args = append(args, groupID) + } + + postgresWhere = append(postgresWhere, + fmt.Sprintf("(ra.ASSIGNEE_TYPE = 'group' AND ra.ASSIGNEE_ID IN (%s))", + strings.Join(groupPlaceholdersPostgres, ","))) + sqliteWhere = append(sqliteWhere, + fmt.Sprintf("(ra.ASSIGNEE_TYPE = 'group' AND ra.ASSIGNEE_ID IN (%s))", + strings.Join(groupPlaceholdersSqlite, ","))) + paramIndex += len(groupIDs) + } + + // Build permission condition + permPlaceholdersPostgres := make([]string, len(requestedPermissions)) + permPlaceholdersSqlite := make([]string, len(requestedPermissions)) + + for i, perm := range requestedPermissions { + permPlaceholdersPostgres[i] = fmt.Sprintf("$%d", paramIndex+i) + permPlaceholdersSqlite[i] = "?" + args = append(args, perm) + } + + // Construct PostgreSQL query + postgresQuery := baseQuery + + "(" + strings.Join(postgresWhere, " OR ") + ") AND " + + fmt.Sprintf("rp.PERMISSION IN (%s)", strings.Join(permPlaceholdersPostgres, ",")) + + " ORDER BY rp.PERMISSION" + + // Construct SQLite query + sqliteQuery := baseQuery + + "(" + strings.Join(sqliteWhere, " OR ") + ") AND " + + fmt.Sprintf("rp.PERMISSION IN (%s)", strings.Join(permPlaceholdersSqlite, ",")) + + " ORDER BY rp.PERMISSION" + + query := dbmodel.DBQuery{ + ID: "RLQ-ROLE_MGT-20", + Query: postgresQuery, + PostgresQuery: postgresQuery, + SQLiteQuery: sqliteQuery, + } + + return query, args +} diff --git a/backend/internal/system/error/serviceerror/error.go b/backend/internal/system/error/serviceerror/error.go index 921c2f51f..4c5342a63 100644 --- a/backend/internal/system/error/serviceerror/error.go +++ b/backend/internal/system/error/serviceerror/error.go @@ -50,3 +50,10 @@ func CustomServiceError(svcError ServiceError, errorDesc string) *ServiceError { } return err } + +// Server errors +var ( + // EncodingError is the error returned when encoding the response. + ErrorEncodingError = "{Code: \"ENC-5000\",Error: \"Encoding error\"," + + "ErrorDescription: \"An error occurred while encoding the response\"}" +) diff --git a/backend/internal/system/utils/sliceutil.go b/backend/internal/system/utils/sliceutil.go index 16666dfae..da3e5cebf 100644 --- a/backend/internal/system/utils/sliceutil.go +++ b/backend/internal/system/utils/sliceutil.go @@ -79,3 +79,23 @@ func MergeInterfaceMaps(dst, src map[string]interface{}) map[string]interface{} } return dst } + +// UniqueStrings returns a slice containing only unique values from the input slice. +// The order of elements is not guaranteed. +func UniqueStrings(input []string) []string { + if input == nil { + return nil + } + + seen := make(map[string]bool, len(input)) + result := make([]string, 0, len(input)) + + for _, item := range input { + if !seen[item] { + seen[item] = true + result = append(result, item) + } + } + + return result +} diff --git a/backend/internal/system/utils/sliceutil_test.go b/backend/internal/system/utils/sliceutil_test.go new file mode 100644 index 000000000..1fc4d7bc3 --- /dev/null +++ b/backend/internal/system/utils/sliceutil_test.go @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type SliceUtilTestSuite struct { + suite.Suite +} + +func TestSliceUtilTestSuite(t *testing.T) { + suite.Run(t, new(SliceUtilTestSuite)) +} + +func (suite *SliceUtilTestSuite) TestUniqueStrings() { + tests := []struct { + name string + input []string + expected []string + }{ + { + name: "Empty slice", + input: []string{}, + expected: []string{}, + }, + { + name: "Nil slice", + input: nil, + expected: nil, + }, + { + name: "No duplicates", + input: []string{"a", "b", "c"}, + expected: []string{"a", "b", "c"}, + }, + { + name: "With duplicates", + input: []string{"a", "b", "a", "c", "b"}, + expected: []string{"a", "b", "c"}, + }, + { + name: "All duplicates", + input: []string{"a", "a", "a"}, + expected: []string{"a"}, + }, + { + name: "Single element", + input: []string{"a"}, + expected: []string{"a"}, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + result := UniqueStrings(tt.input) + + if tt.expected == nil { + assert.Nil(suite.T(), result) + return + } + + // Check length matches + assert.Equal(suite.T(), len(tt.expected), len(result)) + + // Convert result to map for order-independent comparison + resultMap := make(map[string]bool) + for _, v := range result { + resultMap[v] = true + } + + // Verify all expected values are present + for _, v := range tt.expected { + assert.True(suite.T(), resultMap[v], "Expected value %s not found in result", v) + } + }) + } +} + +func (suite *SliceUtilTestSuite) TestDeepCopyMapOfStrings() { + tests := []struct { + name string + input map[string]string + expected map[string]string + }{ + { + name: "Nil map", + input: nil, + expected: nil, + }, + { + name: "Empty map", + input: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "Map with values", + input: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + expected: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + result := DeepCopyMapOfStrings(tt.input) + + if tt.expected == nil { + assert.Nil(suite.T(), result) + return + } + + assert.Equal(suite.T(), tt.expected, result) + + // Verify it's a deep copy (modifying original doesn't affect copy) + if len(tt.input) > 0 { + for k := range tt.input { + tt.input[k] = "modified" + assert.NotEqual(suite.T(), "modified", result[k]) + break + } + } + }) + } +} + +func (suite *SliceUtilTestSuite) TestDeepCopyMapOfStringSlices() { + tests := []struct { + name string + input map[string][]string + expected map[string][]string + }{ + { + name: "Nil map", + input: nil, + expected: nil, + }, + { + name: "Empty map", + input: map[string][]string{}, + expected: map[string][]string{}, + }, + { + name: "Map with values", + input: map[string][]string{ + "key1": {"value1", "value2"}, + "key2": {"value3"}, + }, + expected: map[string][]string{ + "key1": {"value1", "value2"}, + "key2": {"value3"}, + }, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + result := DeepCopyMapOfStringSlices(tt.input) + + if tt.expected == nil { + assert.Nil(suite.T(), result) + return + } + + assert.Equal(suite.T(), tt.expected, result) + + // Verify it's a deep copy (modifying original doesn't affect copy) + if len(tt.input) > 0 { + for k := range tt.input { + if len(tt.input[k]) > 0 { + tt.input[k][0] = "modified" + assert.NotEqual(suite.T(), "modified", result[k][0]) + } + break + } + } + }) + } +} + +func (suite *SliceUtilTestSuite) TestMergeInterfaceMaps() { + tests := []struct { + name string + dst map[string]interface{} + src map[string]interface{} + expected map[string]interface{} + }{ + { + name: "Both nil", + dst: nil, + src: nil, + expected: map[string]interface{}{}, + }, + { + name: "Dst nil, src with values", + dst: nil, + src: map[string]interface{}{"key1": "value1"}, + expected: map[string]interface{}{"key1": "value1"}, + }, + { + name: "Dst with values, src nil", + dst: map[string]interface{}{"key1": "value1"}, + src: nil, + expected: map[string]interface{}{"key1": "value1"}, + }, + { + name: "Both with non-overlapping keys", + dst: map[string]interface{}{"key1": "value1"}, + src: map[string]interface{}{"key2": "value2"}, + expected: map[string]interface{}{"key1": "value1", "key2": "value2"}, + }, + { + name: "Both with overlapping keys - src overrides dst", + dst: map[string]interface{}{"key1": "value1", "key2": "value2"}, + src: map[string]interface{}{"key2": "newValue2", "key3": "value3"}, + expected: map[string]interface{}{"key1": "value1", "key2": "newValue2", "key3": "value3"}, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + result := MergeInterfaceMaps(tt.dst, tt.src) + assert.Equal(suite.T(), tt.expected, result) + }) + } +} diff --git a/backend/tests/mocks/groupmock/GroupServiceInterface_mock.go b/backend/tests/mocks/groupmock/GroupServiceInterface_mock.go new file mode 100644 index 000000000..e032e204a --- /dev/null +++ b/backend/tests/mocks/groupmock/GroupServiceInterface_mock.go @@ -0,0 +1,634 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package groupmock + +import ( + "github.com/asgardeo/thunder/internal/group" + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + mock "github.com/stretchr/testify/mock" +) + +// NewGroupServiceInterfaceMock creates a new instance of GroupServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewGroupServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *GroupServiceInterfaceMock { + mock := &GroupServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// GroupServiceInterfaceMock is an autogenerated mock type for the GroupServiceInterface type +type GroupServiceInterfaceMock struct { + mock.Mock +} + +type GroupServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *GroupServiceInterfaceMock) EXPECT() *GroupServiceInterfaceMock_Expecter { + return &GroupServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// CreateGroup provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) CreateGroup(request group.CreateGroupRequest) (*group.Group, *serviceerror.ServiceError) { + ret := _mock.Called(request) + + if len(ret) == 0 { + panic("no return value specified for CreateGroup") + } + + var r0 *group.Group + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(group.CreateGroupRequest) (*group.Group, *serviceerror.ServiceError)); ok { + return returnFunc(request) + } + if returnFunc, ok := ret.Get(0).(func(group.CreateGroupRequest) *group.Group); ok { + r0 = returnFunc(request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.Group) + } + } + if returnFunc, ok := ret.Get(1).(func(group.CreateGroupRequest) *serviceerror.ServiceError); ok { + r1 = returnFunc(request) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_CreateGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateGroup' +type GroupServiceInterfaceMock_CreateGroup_Call struct { + *mock.Call +} + +// CreateGroup is a helper method to define mock.On call +// - request group.CreateGroupRequest +func (_e *GroupServiceInterfaceMock_Expecter) CreateGroup(request interface{}) *GroupServiceInterfaceMock_CreateGroup_Call { + return &GroupServiceInterfaceMock_CreateGroup_Call{Call: _e.mock.On("CreateGroup", request)} +} + +func (_c *GroupServiceInterfaceMock_CreateGroup_Call) Run(run func(request group.CreateGroupRequest)) *GroupServiceInterfaceMock_CreateGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 group.CreateGroupRequest + if args[0] != nil { + arg0 = args[0].(group.CreateGroupRequest) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_CreateGroup_Call) Return(group1 *group.Group, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_CreateGroup_Call { + _c.Call.Return(group1, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_CreateGroup_Call) RunAndReturn(run func(request group.CreateGroupRequest) (*group.Group, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_CreateGroup_Call { + _c.Call.Return(run) + return _c +} + +// CreateGroupByPath provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) CreateGroupByPath(handlePath string, request group.CreateGroupByPathRequest) (*group.Group, *serviceerror.ServiceError) { + ret := _mock.Called(handlePath, request) + + if len(ret) == 0 { + panic("no return value specified for CreateGroupByPath") + } + + var r0 *group.Group + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, group.CreateGroupByPathRequest) (*group.Group, *serviceerror.ServiceError)); ok { + return returnFunc(handlePath, request) + } + if returnFunc, ok := ret.Get(0).(func(string, group.CreateGroupByPathRequest) *group.Group); ok { + r0 = returnFunc(handlePath, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.Group) + } + } + if returnFunc, ok := ret.Get(1).(func(string, group.CreateGroupByPathRequest) *serviceerror.ServiceError); ok { + r1 = returnFunc(handlePath, request) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_CreateGroupByPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateGroupByPath' +type GroupServiceInterfaceMock_CreateGroupByPath_Call struct { + *mock.Call +} + +// CreateGroupByPath is a helper method to define mock.On call +// - handlePath string +// - request group.CreateGroupByPathRequest +func (_e *GroupServiceInterfaceMock_Expecter) CreateGroupByPath(handlePath interface{}, request interface{}) *GroupServiceInterfaceMock_CreateGroupByPath_Call { + return &GroupServiceInterfaceMock_CreateGroupByPath_Call{Call: _e.mock.On("CreateGroupByPath", handlePath, request)} +} + +func (_c *GroupServiceInterfaceMock_CreateGroupByPath_Call) Run(run func(handlePath string, request group.CreateGroupByPathRequest)) *GroupServiceInterfaceMock_CreateGroupByPath_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 group.CreateGroupByPathRequest + if args[1] != nil { + arg1 = args[1].(group.CreateGroupByPathRequest) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_CreateGroupByPath_Call) Return(group1 *group.Group, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_CreateGroupByPath_Call { + _c.Call.Return(group1, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_CreateGroupByPath_Call) RunAndReturn(run func(handlePath string, request group.CreateGroupByPathRequest) (*group.Group, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_CreateGroupByPath_Call { + _c.Call.Return(run) + return _c +} + +// DeleteGroup provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) DeleteGroup(groupID string) *serviceerror.ServiceError { + ret := _mock.Called(groupID) + + if len(ret) == 0 { + panic("no return value specified for DeleteGroup") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) *serviceerror.ServiceError); ok { + r0 = returnFunc(groupID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// GroupServiceInterfaceMock_DeleteGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteGroup' +type GroupServiceInterfaceMock_DeleteGroup_Call struct { + *mock.Call +} + +// DeleteGroup is a helper method to define mock.On call +// - groupID string +func (_e *GroupServiceInterfaceMock_Expecter) DeleteGroup(groupID interface{}) *GroupServiceInterfaceMock_DeleteGroup_Call { + return &GroupServiceInterfaceMock_DeleteGroup_Call{Call: _e.mock.On("DeleteGroup", groupID)} +} + +func (_c *GroupServiceInterfaceMock_DeleteGroup_Call) Run(run func(groupID string)) *GroupServiceInterfaceMock_DeleteGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_DeleteGroup_Call) Return(serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_DeleteGroup_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_DeleteGroup_Call) RunAndReturn(run func(groupID string) *serviceerror.ServiceError) *GroupServiceInterfaceMock_DeleteGroup_Call { + _c.Call.Return(run) + return _c +} + +// GetGroup provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) GetGroup(groupID string) (*group.Group, *serviceerror.ServiceError) { + ret := _mock.Called(groupID) + + if len(ret) == 0 { + panic("no return value specified for GetGroup") + } + + var r0 *group.Group + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) (*group.Group, *serviceerror.ServiceError)); ok { + return returnFunc(groupID) + } + if returnFunc, ok := ret.Get(0).(func(string) *group.Group); ok { + r0 = returnFunc(groupID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.Group) + } + } + if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { + r1 = returnFunc(groupID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_GetGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroup' +type GroupServiceInterfaceMock_GetGroup_Call struct { + *mock.Call +} + +// GetGroup is a helper method to define mock.On call +// - groupID string +func (_e *GroupServiceInterfaceMock_Expecter) GetGroup(groupID interface{}) *GroupServiceInterfaceMock_GetGroup_Call { + return &GroupServiceInterfaceMock_GetGroup_Call{Call: _e.mock.On("GetGroup", groupID)} +} + +func (_c *GroupServiceInterfaceMock_GetGroup_Call) Run(run func(groupID string)) *GroupServiceInterfaceMock_GetGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroup_Call) Return(group1 *group.Group, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_GetGroup_Call { + _c.Call.Return(group1, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroup_Call) RunAndReturn(run func(groupID string) (*group.Group, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_GetGroup_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupList provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) GetGroupList(limit int, offset int) (*group.GroupListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupList") + } + + var r0 *group.GroupListResponse + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(int, int) (*group.GroupListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(int, int) *group.GroupListResponse); ok { + r0 = returnFunc(limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.GroupListResponse) + } + } + if returnFunc, ok := ret.Get(1).(func(int, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(limit, offset) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_GetGroupList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupList' +type GroupServiceInterfaceMock_GetGroupList_Call struct { + *mock.Call +} + +// GetGroupList is a helper method to define mock.On call +// - limit int +// - offset int +func (_e *GroupServiceInterfaceMock_Expecter) GetGroupList(limit interface{}, offset interface{}) *GroupServiceInterfaceMock_GetGroupList_Call { + return &GroupServiceInterfaceMock_GetGroupList_Call{Call: _e.mock.On("GetGroupList", limit, offset)} +} + +func (_c *GroupServiceInterfaceMock_GetGroupList_Call) Run(run func(limit int, offset int)) *GroupServiceInterfaceMock_GetGroupList_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupList_Call) Return(groupListResponse *group.GroupListResponse, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_GetGroupList_Call { + _c.Call.Return(groupListResponse, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupList_Call) RunAndReturn(run func(limit int, offset int) (*group.GroupListResponse, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_GetGroupList_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupMembers provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) GetGroupMembers(groupID string, limit int, offset int) (*group.MemberListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(groupID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupMembers") + } + + var r0 *group.MemberListResponse + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, int, int) (*group.MemberListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(groupID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int) *group.MemberListResponse); ok { + r0 = returnFunc(groupID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.MemberListResponse) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(groupID, limit, offset) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_GetGroupMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupMembers' +type GroupServiceInterfaceMock_GetGroupMembers_Call struct { + *mock.Call +} + +// GetGroupMembers is a helper method to define mock.On call +// - groupID string +// - limit int +// - offset int +func (_e *GroupServiceInterfaceMock_Expecter) GetGroupMembers(groupID interface{}, limit interface{}, offset interface{}) *GroupServiceInterfaceMock_GetGroupMembers_Call { + return &GroupServiceInterfaceMock_GetGroupMembers_Call{Call: _e.mock.On("GetGroupMembers", groupID, limit, offset)} +} + +func (_c *GroupServiceInterfaceMock_GetGroupMembers_Call) Run(run func(groupID string, limit int, offset int)) *GroupServiceInterfaceMock_GetGroupMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupMembers_Call) Return(memberListResponse *group.MemberListResponse, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_GetGroupMembers_Call { + _c.Call.Return(memberListResponse, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupMembers_Call) RunAndReturn(run func(groupID string, limit int, offset int) (*group.MemberListResponse, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_GetGroupMembers_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupsByPath provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) GetGroupsByPath(handlePath string, limit int, offset int) (*group.GroupListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(handlePath, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupsByPath") + } + + var r0 *group.GroupListResponse + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, int, int) (*group.GroupListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(handlePath, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int) *group.GroupListResponse); ok { + r0 = returnFunc(handlePath, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.GroupListResponse) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(handlePath, limit, offset) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_GetGroupsByPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupsByPath' +type GroupServiceInterfaceMock_GetGroupsByPath_Call struct { + *mock.Call +} + +// GetGroupsByPath is a helper method to define mock.On call +// - handlePath string +// - limit int +// - offset int +func (_e *GroupServiceInterfaceMock_Expecter) GetGroupsByPath(handlePath interface{}, limit interface{}, offset interface{}) *GroupServiceInterfaceMock_GetGroupsByPath_Call { + return &GroupServiceInterfaceMock_GetGroupsByPath_Call{Call: _e.mock.On("GetGroupsByPath", handlePath, limit, offset)} +} + +func (_c *GroupServiceInterfaceMock_GetGroupsByPath_Call) Run(run func(handlePath string, limit int, offset int)) *GroupServiceInterfaceMock_GetGroupsByPath_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupsByPath_Call) Return(groupListResponse *group.GroupListResponse, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_GetGroupsByPath_Call { + _c.Call.Return(groupListResponse, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_GetGroupsByPath_Call) RunAndReturn(run func(handlePath string, limit int, offset int) (*group.GroupListResponse, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_GetGroupsByPath_Call { + _c.Call.Return(run) + return _c +} + +// UpdateGroup provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) UpdateGroup(groupID string, request group.UpdateGroupRequest) (*group.Group, *serviceerror.ServiceError) { + ret := _mock.Called(groupID, request) + + if len(ret) == 0 { + panic("no return value specified for UpdateGroup") + } + + var r0 *group.Group + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, group.UpdateGroupRequest) (*group.Group, *serviceerror.ServiceError)); ok { + return returnFunc(groupID, request) + } + if returnFunc, ok := ret.Get(0).(func(string, group.UpdateGroupRequest) *group.Group); ok { + r0 = returnFunc(groupID, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*group.Group) + } + } + if returnFunc, ok := ret.Get(1).(func(string, group.UpdateGroupRequest) *serviceerror.ServiceError); ok { + r1 = returnFunc(groupID, request) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// GroupServiceInterfaceMock_UpdateGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateGroup' +type GroupServiceInterfaceMock_UpdateGroup_Call struct { + *mock.Call +} + +// UpdateGroup is a helper method to define mock.On call +// - groupID string +// - request group.UpdateGroupRequest +func (_e *GroupServiceInterfaceMock_Expecter) UpdateGroup(groupID interface{}, request interface{}) *GroupServiceInterfaceMock_UpdateGroup_Call { + return &GroupServiceInterfaceMock_UpdateGroup_Call{Call: _e.mock.On("UpdateGroup", groupID, request)} +} + +func (_c *GroupServiceInterfaceMock_UpdateGroup_Call) Run(run func(groupID string, request group.UpdateGroupRequest)) *GroupServiceInterfaceMock_UpdateGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 group.UpdateGroupRequest + if args[1] != nil { + arg1 = args[1].(group.UpdateGroupRequest) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_UpdateGroup_Call) Return(group1 *group.Group, serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_UpdateGroup_Call { + _c.Call.Return(group1, serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_UpdateGroup_Call) RunAndReturn(run func(groupID string, request group.UpdateGroupRequest) (*group.Group, *serviceerror.ServiceError)) *GroupServiceInterfaceMock_UpdateGroup_Call { + _c.Call.Return(run) + return _c +} + +// ValidateGroupIDs provides a mock function for the type GroupServiceInterfaceMock +func (_mock *GroupServiceInterfaceMock) ValidateGroupIDs(groupIDs []string) *serviceerror.ServiceError { + ret := _mock.Called(groupIDs) + + if len(ret) == 0 { + panic("no return value specified for ValidateGroupIDs") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func([]string) *serviceerror.ServiceError); ok { + r0 = returnFunc(groupIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// GroupServiceInterfaceMock_ValidateGroupIDs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateGroupIDs' +type GroupServiceInterfaceMock_ValidateGroupIDs_Call struct { + *mock.Call +} + +// ValidateGroupIDs is a helper method to define mock.On call +// - groupIDs []string +func (_e *GroupServiceInterfaceMock_Expecter) ValidateGroupIDs(groupIDs interface{}) *GroupServiceInterfaceMock_ValidateGroupIDs_Call { + return &GroupServiceInterfaceMock_ValidateGroupIDs_Call{Call: _e.mock.On("ValidateGroupIDs", groupIDs)} +} + +func (_c *GroupServiceInterfaceMock_ValidateGroupIDs_Call) Run(run func(groupIDs []string)) *GroupServiceInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []string + if args[0] != nil { + arg0 = args[0].([]string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *GroupServiceInterfaceMock_ValidateGroupIDs_Call) Return(serviceError *serviceerror.ServiceError) *GroupServiceInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *GroupServiceInterfaceMock_ValidateGroupIDs_Call) RunAndReturn(run func(groupIDs []string) *serviceerror.ServiceError) *GroupServiceInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/tests/mocks/groupmock/groupStoreInterface_mock.go b/backend/tests/mocks/groupmock/groupStoreInterface_mock.go new file mode 100644 index 000000000..d57c8eafa --- /dev/null +++ b/backend/tests/mocks/groupmock/groupStoreInterface_mock.go @@ -0,0 +1,821 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package groupmock + +import ( + "github.com/asgardeo/thunder/internal/group" + mock "github.com/stretchr/testify/mock" +) + +// newGroupStoreInterfaceMock creates a new instance of groupStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newGroupStoreInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *groupStoreInterfaceMock { + mock := &groupStoreInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// groupStoreInterfaceMock is an autogenerated mock type for the groupStoreInterface type +type groupStoreInterfaceMock struct { + mock.Mock +} + +type groupStoreInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *groupStoreInterfaceMock) EXPECT() *groupStoreInterfaceMock_Expecter { + return &groupStoreInterfaceMock_Expecter{mock: &_m.Mock} +} + +// CheckGroupNameConflictForCreate provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) CheckGroupNameConflictForCreate(name string, organizationUnitID string) error { + ret := _mock.Called(name, organizationUnitID) + + if len(ret) == 0 { + panic("no return value specified for CheckGroupNameConflictForCreate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, string) error); ok { + r0 = returnFunc(name, organizationUnitID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckGroupNameConflictForCreate' +type groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call struct { + *mock.Call +} + +// CheckGroupNameConflictForCreate is a helper method to define mock.On call +// - name string +// - organizationUnitID string +func (_e *groupStoreInterfaceMock_Expecter) CheckGroupNameConflictForCreate(name interface{}, organizationUnitID interface{}) *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call { + return &groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call{Call: _e.mock.On("CheckGroupNameConflictForCreate", name, organizationUnitID)} +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call) Run(run func(name string, organizationUnitID string)) *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call) Return(err error) *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call) RunAndReturn(run func(name string, organizationUnitID string) error) *groupStoreInterfaceMock_CheckGroupNameConflictForCreate_Call { + _c.Call.Return(run) + return _c +} + +// CheckGroupNameConflictForUpdate provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) CheckGroupNameConflictForUpdate(name string, organizationUnitID string, groupID string) error { + ret := _mock.Called(name, organizationUnitID, groupID) + + if len(ret) == 0 { + panic("no return value specified for CheckGroupNameConflictForUpdate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, string, string) error); ok { + r0 = returnFunc(name, organizationUnitID, groupID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckGroupNameConflictForUpdate' +type groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call struct { + *mock.Call +} + +// CheckGroupNameConflictForUpdate is a helper method to define mock.On call +// - name string +// - organizationUnitID string +// - groupID string +func (_e *groupStoreInterfaceMock_Expecter) CheckGroupNameConflictForUpdate(name interface{}, organizationUnitID interface{}, groupID interface{}) *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call { + return &groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call{Call: _e.mock.On("CheckGroupNameConflictForUpdate", name, organizationUnitID, groupID)} +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call) Run(run func(name string, organizationUnitID string, groupID string)) *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call) Return(err error) *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call) RunAndReturn(run func(name string, organizationUnitID string, groupID string) error) *groupStoreInterfaceMock_CheckGroupNameConflictForUpdate_Call { + _c.Call.Return(run) + return _c +} + +// CreateGroup provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) CreateGroup(group1 group.GroupDAO) error { + ret := _mock.Called(group1) + + if len(ret) == 0 { + panic("no return value specified for CreateGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(group.GroupDAO) error); ok { + r0 = returnFunc(group1) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// groupStoreInterfaceMock_CreateGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateGroup' +type groupStoreInterfaceMock_CreateGroup_Call struct { + *mock.Call +} + +// CreateGroup is a helper method to define mock.On call +// - group1 group.GroupDAO +func (_e *groupStoreInterfaceMock_Expecter) CreateGroup(group1 interface{}) *groupStoreInterfaceMock_CreateGroup_Call { + return &groupStoreInterfaceMock_CreateGroup_Call{Call: _e.mock.On("CreateGroup", group1)} +} + +func (_c *groupStoreInterfaceMock_CreateGroup_Call) Run(run func(group1 group.GroupDAO)) *groupStoreInterfaceMock_CreateGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 group.GroupDAO + if args[0] != nil { + arg0 = args[0].(group.GroupDAO) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_CreateGroup_Call) Return(err error) *groupStoreInterfaceMock_CreateGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *groupStoreInterfaceMock_CreateGroup_Call) RunAndReturn(run func(group1 group.GroupDAO) error) *groupStoreInterfaceMock_CreateGroup_Call { + _c.Call.Return(run) + return _c +} + +// DeleteGroup provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) DeleteGroup(id string) error { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for DeleteGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string) error); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// groupStoreInterfaceMock_DeleteGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteGroup' +type groupStoreInterfaceMock_DeleteGroup_Call struct { + *mock.Call +} + +// DeleteGroup is a helper method to define mock.On call +// - id string +func (_e *groupStoreInterfaceMock_Expecter) DeleteGroup(id interface{}) *groupStoreInterfaceMock_DeleteGroup_Call { + return &groupStoreInterfaceMock_DeleteGroup_Call{Call: _e.mock.On("DeleteGroup", id)} +} + +func (_c *groupStoreInterfaceMock_DeleteGroup_Call) Run(run func(id string)) *groupStoreInterfaceMock_DeleteGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_DeleteGroup_Call) Return(err error) *groupStoreInterfaceMock_DeleteGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *groupStoreInterfaceMock_DeleteGroup_Call) RunAndReturn(run func(id string) error) *groupStoreInterfaceMock_DeleteGroup_Call { + _c.Call.Return(run) + return _c +} + +// GetGroup provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroup(id string) (group.GroupDAO, error) { + ret := _mock.Called(id) + + if len(ret) == 0 { + panic("no return value specified for GetGroup") + } + + var r0 group.GroupDAO + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (group.GroupDAO, error)); ok { + return returnFunc(id) + } + if returnFunc, ok := ret.Get(0).(func(string) group.GroupDAO); ok { + r0 = returnFunc(id) + } else { + r0 = ret.Get(0).(group.GroupDAO) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroup' +type groupStoreInterfaceMock_GetGroup_Call struct { + *mock.Call +} + +// GetGroup is a helper method to define mock.On call +// - id string +func (_e *groupStoreInterfaceMock_Expecter) GetGroup(id interface{}) *groupStoreInterfaceMock_GetGroup_Call { + return &groupStoreInterfaceMock_GetGroup_Call{Call: _e.mock.On("GetGroup", id)} +} + +func (_c *groupStoreInterfaceMock_GetGroup_Call) Run(run func(id string)) *groupStoreInterfaceMock_GetGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroup_Call) Return(groupDAO group.GroupDAO, err error) *groupStoreInterfaceMock_GetGroup_Call { + _c.Call.Return(groupDAO, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroup_Call) RunAndReturn(run func(id string) (group.GroupDAO, error)) *groupStoreInterfaceMock_GetGroup_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupList provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupList(limit int, offset int) ([]group.GroupBasicDAO, error) { + ret := _mock.Called(limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupList") + } + + var r0 []group.GroupBasicDAO + var r1 error + if returnFunc, ok := ret.Get(0).(func(int, int) ([]group.GroupBasicDAO, error)); ok { + return returnFunc(limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(int, int) []group.GroupBasicDAO); ok { + r0 = returnFunc(limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]group.GroupBasicDAO) + } + } + if returnFunc, ok := ret.Get(1).(func(int, int) error); ok { + r1 = returnFunc(limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupList' +type groupStoreInterfaceMock_GetGroupList_Call struct { + *mock.Call +} + +// GetGroupList is a helper method to define mock.On call +// - limit int +// - offset int +func (_e *groupStoreInterfaceMock_Expecter) GetGroupList(limit interface{}, offset interface{}) *groupStoreInterfaceMock_GetGroupList_Call { + return &groupStoreInterfaceMock_GetGroupList_Call{Call: _e.mock.On("GetGroupList", limit, offset)} +} + +func (_c *groupStoreInterfaceMock_GetGroupList_Call) Run(run func(limit int, offset int)) *groupStoreInterfaceMock_GetGroupList_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupList_Call) Return(groupBasicDAOs []group.GroupBasicDAO, err error) *groupStoreInterfaceMock_GetGroupList_Call { + _c.Call.Return(groupBasicDAOs, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupList_Call) RunAndReturn(run func(limit int, offset int) ([]group.GroupBasicDAO, error)) *groupStoreInterfaceMock_GetGroupList_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupListCount provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupListCount() (int, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetGroupListCount") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func() (int, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() int); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupListCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupListCount' +type groupStoreInterfaceMock_GetGroupListCount_Call struct { + *mock.Call +} + +// GetGroupListCount is a helper method to define mock.On call +func (_e *groupStoreInterfaceMock_Expecter) GetGroupListCount() *groupStoreInterfaceMock_GetGroupListCount_Call { + return &groupStoreInterfaceMock_GetGroupListCount_Call{Call: _e.mock.On("GetGroupListCount")} +} + +func (_c *groupStoreInterfaceMock_GetGroupListCount_Call) Run(run func()) *groupStoreInterfaceMock_GetGroupListCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupListCount_Call) Return(n int, err error) *groupStoreInterfaceMock_GetGroupListCount_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupListCount_Call) RunAndReturn(run func() (int, error)) *groupStoreInterfaceMock_GetGroupListCount_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupMemberCount provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupMemberCount(groupID string) (int, error) { + ret := _mock.Called(groupID) + + if len(ret) == 0 { + panic("no return value specified for GetGroupMemberCount") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (int, error)); ok { + return returnFunc(groupID) + } + if returnFunc, ok := ret.Get(0).(func(string) int); ok { + r0 = returnFunc(groupID) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(groupID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupMemberCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupMemberCount' +type groupStoreInterfaceMock_GetGroupMemberCount_Call struct { + *mock.Call +} + +// GetGroupMemberCount is a helper method to define mock.On call +// - groupID string +func (_e *groupStoreInterfaceMock_Expecter) GetGroupMemberCount(groupID interface{}) *groupStoreInterfaceMock_GetGroupMemberCount_Call { + return &groupStoreInterfaceMock_GetGroupMemberCount_Call{Call: _e.mock.On("GetGroupMemberCount", groupID)} +} + +func (_c *groupStoreInterfaceMock_GetGroupMemberCount_Call) Run(run func(groupID string)) *groupStoreInterfaceMock_GetGroupMemberCount_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupMemberCount_Call) Return(n int, err error) *groupStoreInterfaceMock_GetGroupMemberCount_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupMemberCount_Call) RunAndReturn(run func(groupID string) (int, error)) *groupStoreInterfaceMock_GetGroupMemberCount_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupMembers provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupMembers(groupID string, limit int, offset int) ([]group.Member, error) { + ret := _mock.Called(groupID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupMembers") + } + + var r0 []group.Member + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, int, int) ([]group.Member, error)); ok { + return returnFunc(groupID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int) []group.Member); ok { + r0 = returnFunc(groupID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]group.Member) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int) error); ok { + r1 = returnFunc(groupID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupMembers' +type groupStoreInterfaceMock_GetGroupMembers_Call struct { + *mock.Call +} + +// GetGroupMembers is a helper method to define mock.On call +// - groupID string +// - limit int +// - offset int +func (_e *groupStoreInterfaceMock_Expecter) GetGroupMembers(groupID interface{}, limit interface{}, offset interface{}) *groupStoreInterfaceMock_GetGroupMembers_Call { + return &groupStoreInterfaceMock_GetGroupMembers_Call{Call: _e.mock.On("GetGroupMembers", groupID, limit, offset)} +} + +func (_c *groupStoreInterfaceMock_GetGroupMembers_Call) Run(run func(groupID string, limit int, offset int)) *groupStoreInterfaceMock_GetGroupMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupMembers_Call) Return(members []group.Member, err error) *groupStoreInterfaceMock_GetGroupMembers_Call { + _c.Call.Return(members, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupMembers_Call) RunAndReturn(run func(groupID string, limit int, offset int) ([]group.Member, error)) *groupStoreInterfaceMock_GetGroupMembers_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupsByOrganizationUnit provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupsByOrganizationUnit(organizationUnitID string, limit int, offset int) ([]group.GroupBasicDAO, error) { + ret := _mock.Called(organizationUnitID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for GetGroupsByOrganizationUnit") + } + + var r0 []group.GroupBasicDAO + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, int, int) ([]group.GroupBasicDAO, error)); ok { + return returnFunc(organizationUnitID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(string, int, int) []group.GroupBasicDAO); ok { + r0 = returnFunc(organizationUnitID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]group.GroupBasicDAO) + } + } + if returnFunc, ok := ret.Get(1).(func(string, int, int) error); ok { + r1 = returnFunc(organizationUnitID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupsByOrganizationUnit' +type groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call struct { + *mock.Call +} + +// GetGroupsByOrganizationUnit is a helper method to define mock.On call +// - organizationUnitID string +// - limit int +// - offset int +func (_e *groupStoreInterfaceMock_Expecter) GetGroupsByOrganizationUnit(organizationUnitID interface{}, limit interface{}, offset interface{}) *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call { + return &groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call{Call: _e.mock.On("GetGroupsByOrganizationUnit", organizationUnitID, limit, offset)} +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call) Run(run func(organizationUnitID string, limit int, offset int)) *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call) Return(groupBasicDAOs []group.GroupBasicDAO, err error) *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call { + _c.Call.Return(groupBasicDAOs, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call) RunAndReturn(run func(organizationUnitID string, limit int, offset int) ([]group.GroupBasicDAO, error)) *groupStoreInterfaceMock_GetGroupsByOrganizationUnit_Call { + _c.Call.Return(run) + return _c +} + +// GetGroupsByOrganizationUnitCount provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) GetGroupsByOrganizationUnitCount(organizationUnitID string) (int, error) { + ret := _mock.Called(organizationUnitID) + + if len(ret) == 0 { + panic("no return value specified for GetGroupsByOrganizationUnitCount") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (int, error)); ok { + return returnFunc(organizationUnitID) + } + if returnFunc, ok := ret.Get(0).(func(string) int); ok { + r0 = returnFunc(organizationUnitID) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(organizationUnitID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGroupsByOrganizationUnitCount' +type groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call struct { + *mock.Call +} + +// GetGroupsByOrganizationUnitCount is a helper method to define mock.On call +// - organizationUnitID string +func (_e *groupStoreInterfaceMock_Expecter) GetGroupsByOrganizationUnitCount(organizationUnitID interface{}) *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call { + return &groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call{Call: _e.mock.On("GetGroupsByOrganizationUnitCount", organizationUnitID)} +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call) Run(run func(organizationUnitID string)) *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call) Return(n int, err error) *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call) RunAndReturn(run func(organizationUnitID string) (int, error)) *groupStoreInterfaceMock_GetGroupsByOrganizationUnitCount_Call { + _c.Call.Return(run) + return _c +} + +// UpdateGroup provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) UpdateGroup(group1 group.GroupDAO) error { + ret := _mock.Called(group1) + + if len(ret) == 0 { + panic("no return value specified for UpdateGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(group.GroupDAO) error); ok { + r0 = returnFunc(group1) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// groupStoreInterfaceMock_UpdateGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateGroup' +type groupStoreInterfaceMock_UpdateGroup_Call struct { + *mock.Call +} + +// UpdateGroup is a helper method to define mock.On call +// - group1 group.GroupDAO +func (_e *groupStoreInterfaceMock_Expecter) UpdateGroup(group1 interface{}) *groupStoreInterfaceMock_UpdateGroup_Call { + return &groupStoreInterfaceMock_UpdateGroup_Call{Call: _e.mock.On("UpdateGroup", group1)} +} + +func (_c *groupStoreInterfaceMock_UpdateGroup_Call) Run(run func(group1 group.GroupDAO)) *groupStoreInterfaceMock_UpdateGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 group.GroupDAO + if args[0] != nil { + arg0 = args[0].(group.GroupDAO) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_UpdateGroup_Call) Return(err error) *groupStoreInterfaceMock_UpdateGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *groupStoreInterfaceMock_UpdateGroup_Call) RunAndReturn(run func(group1 group.GroupDAO) error) *groupStoreInterfaceMock_UpdateGroup_Call { + _c.Call.Return(run) + return _c +} + +// ValidateGroupIDs provides a mock function for the type groupStoreInterfaceMock +func (_mock *groupStoreInterfaceMock) ValidateGroupIDs(groupIDs []string) ([]string, error) { + ret := _mock.Called(groupIDs) + + if len(ret) == 0 { + panic("no return value specified for ValidateGroupIDs") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func([]string) ([]string, error)); ok { + return returnFunc(groupIDs) + } + if returnFunc, ok := ret.Get(0).(func([]string) []string); ok { + r0 = returnFunc(groupIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func([]string) error); ok { + r1 = returnFunc(groupIDs) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// groupStoreInterfaceMock_ValidateGroupIDs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateGroupIDs' +type groupStoreInterfaceMock_ValidateGroupIDs_Call struct { + *mock.Call +} + +// ValidateGroupIDs is a helper method to define mock.On call +// - groupIDs []string +func (_e *groupStoreInterfaceMock_Expecter) ValidateGroupIDs(groupIDs interface{}) *groupStoreInterfaceMock_ValidateGroupIDs_Call { + return &groupStoreInterfaceMock_ValidateGroupIDs_Call{Call: _e.mock.On("ValidateGroupIDs", groupIDs)} +} + +func (_c *groupStoreInterfaceMock_ValidateGroupIDs_Call) Run(run func(groupIDs []string)) *groupStoreInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []string + if args[0] != nil { + arg0 = args[0].([]string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *groupStoreInterfaceMock_ValidateGroupIDs_Call) Return(strings []string, err error) *groupStoreInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *groupStoreInterfaceMock_ValidateGroupIDs_Call) RunAndReturn(run func(groupIDs []string) ([]string, error)) *groupStoreInterfaceMock_ValidateGroupIDs_Call { + _c.Call.Return(run) + return _c +} diff --git a/tests/integration/resources/dbscripts/thunderdb/postgres.sql b/tests/integration/resources/dbscripts/thunderdb/postgres.sql index db03c4630..f6984fad7 100644 --- a/tests/integration/resources/dbscripts/thunderdb/postgres.sql +++ b/tests/integration/resources/dbscripts/thunderdb/postgres.sql @@ -55,6 +55,50 @@ CREATE TABLE GROUP_MEMBER_REFERENCE ( FOREIGN KEY (GROUP_ID) REFERENCES "GROUP" (GROUP_ID) ON DELETE CASCADE ); +-- Table to store Roles +CREATE TABLE "ROLE" ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) UNIQUE NOT NULL, + OU_ID VARCHAR(36) NOT NULL, + NAME VARCHAR(50) NOT NULL, + DESCRIPTION VARCHAR(255), + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + UPDATED_AT TIMESTAMPTZ DEFAULT NOW(), + CONSTRAINT unique_role_ou_name UNIQUE (OU_ID, NAME) +); + +-- Table to store Role permissions +CREATE TABLE ROLE_PERMISSION ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) NOT NULL, + PERMISSION VARCHAR(100) NOT NULL, + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_permission UNIQUE (ROLE_ID, PERMISSION) +); + +-- Table to store Role assignments (to users and groups) +CREATE TABLE ROLE_ASSIGNMENT ( + ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ROLE_ID VARCHAR(36) NOT NULL, + ASSIGNEE_TYPE VARCHAR(5) NOT NULL CHECK (ASSIGNEE_TYPE IN ('user', 'group')), + ASSIGNEE_ID VARCHAR(36) NOT NULL, + CREATED_AT TIMESTAMPTZ DEFAULT NOW(), + UPDATED_AT TIMESTAMPTZ DEFAULT NOW(), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_assignment UNIQUE (ROLE_ID, ASSIGNEE_TYPE, ASSIGNEE_ID) +); + +-- Indexes for authorization queries + +-- Index for finding all roles assigned to a specific assignee +CREATE INDEX idx_role_assignment_assignee +ON ROLE_ASSIGNMENT (ASSIGNEE_ID, ASSIGNEE_TYPE); + +-- Index for finding all permissions for a specific role +CREATE INDEX idx_role_permission_role +ON ROLE_PERMISSION (ROLE_ID); + -- Table to store basic service provider (app) details. CREATE TABLE SP_APP ( ID INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, diff --git a/tests/integration/resources/dbscripts/thunderdb/sqlite.sql b/tests/integration/resources/dbscripts/thunderdb/sqlite.sql index 9f6edd7e9..ae79918f9 100644 --- a/tests/integration/resources/dbscripts/thunderdb/sqlite.sql +++ b/tests/integration/resources/dbscripts/thunderdb/sqlite.sql @@ -45,6 +45,50 @@ CREATE TABLE GROUP_MEMBER_REFERENCE ( FOREIGN KEY (GROUP_ID) REFERENCES "GROUP" (GROUP_ID) ON DELETE CASCADE ); +-- Table to store Roles +CREATE TABLE "ROLE" ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) UNIQUE NOT NULL, + OU_ID VARCHAR(36) NOT NULL, + NAME VARCHAR(50) NOT NULL, + DESCRIPTION VARCHAR(255), + CREATED_AT TEXT DEFAULT (datetime('now')), + UPDATED_AT TEXT DEFAULT (datetime('now')), + CONSTRAINT unique_role_ou_name UNIQUE (OU_ID, NAME) +); + +-- Table to store Role permissions +CREATE TABLE ROLE_PERMISSION ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) NOT NULL, + PERMISSION VARCHAR(100) NOT NULL, + CREATED_AT TEXT DEFAULT (datetime('now')), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_permission UNIQUE (ROLE_ID, PERMISSION) +); + +-- Table to store Role assignments (to users and groups) +CREATE TABLE ROLE_ASSIGNMENT ( + ID INTEGER PRIMARY KEY AUTOINCREMENT, + ROLE_ID VARCHAR(36) NOT NULL, + ASSIGNEE_TYPE VARCHAR(5) NOT NULL CHECK (ASSIGNEE_TYPE IN ('user', 'group')), + ASSIGNEE_ID VARCHAR(36) NOT NULL, + CREATED_AT TEXT DEFAULT (datetime('now')), + UPDATED_AT TEXT DEFAULT (datetime('now')), + FOREIGN KEY (ROLE_ID) REFERENCES "ROLE" (ROLE_ID) ON DELETE CASCADE, + CONSTRAINT unique_role_assignment UNIQUE (ROLE_ID, ASSIGNEE_TYPE, ASSIGNEE_ID) +); + +-- Indexes for authorization queries + +-- Index for finding all roles assigned to a specific assignee +CREATE INDEX idx_role_assignment_assignee +ON ROLE_ASSIGNMENT (ASSIGNEE_ID, ASSIGNEE_TYPE); + +-- Index for finding all permissions for a specific role +CREATE INDEX idx_role_permission_role +ON ROLE_PERMISSION (ROLE_ID); + -- Table to store basic service provider (app) details. CREATE TABLE SP_APP ( ID INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/tests/integration/role/model.go b/tests/integration/role/model.go new file mode 100644 index 000000000..49a27c619 --- /dev/null +++ b/tests/integration/role/model.go @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +// AssigneeType represents the type of assignee (user or group) +type AssigneeType string + +const ( + AssigneeTypeUser AssigneeType = "user" + AssigneeTypeGroup AssigneeType = "group" +) + +// Assignment represents a role assignment +type Assignment struct { + ID string `json:"id"` + Type AssigneeType `json:"type"` + Display string `json:"display,omitempty"` // Display name (only included with include=display parameter) +} + +// CreateRoleRequest represents the request to create a role +type CreateRoleRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` + Assignments []Assignment `json:"assignments,omitempty"` +} + +// UpdateRoleRequest represents the request to update a role +type UpdateRoleRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` +} + +// Role represents a complete role resource +type Role struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` + Permissions []string `json:"permissions"` + Assignments []Assignment `json:"assignments,omitempty"` +} + +// RoleSummary represents a minimal role information +type RoleSummary struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitID string `json:"ouId"` +} + +// Link represents a pagination link +type Link struct { + Rel string `json:"rel"` + Href string `json:"href"` +} + +// RoleListResponse represents the paginated list of roles +type RoleListResponse struct { + TotalResults int `json:"totalResults"` + StartIndex int `json:"startIndex"` + Count int `json:"count"` + Links []Link `json:"links,omitempty"` + Roles []RoleSummary `json:"roles"` +} + +// AssignmentsRequest represents add/remove assignments request +type AssignmentsRequest struct { + Assignments []Assignment `json:"assignments"` +} + +// AssignmentListResponse represents the paginated list of assignments +type AssignmentListResponse struct { + TotalResults int `json:"totalResults"` + StartIndex int `json:"startIndex"` + Count int `json:"count"` + Links []Link `json:"links,omitempty"` + Assignments []Assignment `json:"assignments"` +} + +// ErrorResponse represents an error response from the API +type ErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` + Description string `json:"description,omitempty"` +} diff --git a/tests/integration/role/roleapi_test.go b/tests/integration/role/roleapi_test.go new file mode 100644 index 000000000..75efefcaf --- /dev/null +++ b/tests/integration/role/roleapi_test.go @@ -0,0 +1,952 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package role + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + + "github.com/asgardeo/thunder/tests/integration/testutils" + "github.com/stretchr/testify/suite" +) + +const ( + testServerURL = "https://localhost:8095" + rolesBasePath = "/roles" +) + +var ( + testOU = testutils.OrganizationUnit{ + Handle: "test-role-ou", + Name: "Test Organization Unit for Roles", + Description: "Organization unit created for role API testing", + Parent: nil, + } + + testUser1 = testutils.User{ + Type: "person", + Attributes: json.RawMessage(`{ + "email": "roleuser1@example.com", + "firstName": "Role", + "lastName": "User1", + "password": "TestPassword123!" + }`), + } + + testUser2 = testutils.User{ + Type: "person", + Attributes: json.RawMessage(`{ + "email": "roleuser2@example.com", + "firstName": "Role", + "lastName": "User2", + "password": "TestPassword123!" + }`), + } + + testGroup = testutils.Group{ + Name: "Test Role Group", + Description: "Group created for role API testing", + } + + testRole = CreateRoleRequest{ + Name: "Test Admin Role", + Description: "Admin role for testing", + Permissions: []string{"read:users", "write:users", "delete:users"}, + } +) + +var ( + testOUID string + testUserID1 string + testUserID2 string + testGroupID string + sharedRoleID string // Shared role created in SetupSuite for tests that need a pre-existing role +) + +type RoleAPITestSuite struct { + suite.Suite + client *http.Client +} + +func TestRoleAPITestSuite(t *testing.T) { + suite.Run(t, new(RoleAPITestSuite)) +} + +func (suite *RoleAPITestSuite) SetupSuite() { + // Create HTTP client that skips TLS verification for testing + suite.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + + // Create test organization unit + ouID, err := testutils.CreateOrganizationUnit(testOU) + suite.Require().NoError(err, "Failed to create test organization unit") + testOUID = ouID + + // Create test users + user1 := testUser1 + user1.OrganizationUnit = testOUID + userID1, err := testutils.CreateUser(user1) + suite.Require().NoError(err, "Failed to create test user 1") + testUserID1 = userID1 + + user2 := testUser2 + user2.OrganizationUnit = testOUID + userID2, err := testutils.CreateUser(user2) + suite.Require().NoError(err, "Failed to create test user 2") + testUserID2 = userID2 + + // Create test group + groupToCreate := testGroup + groupToCreate.OrganizationUnitId = testOUID + groupID, err := testutils.CreateGroup(groupToCreate) + suite.Require().NoError(err, "Failed to create test group") + testGroupID = groupID + + // Create a shared role that can be used by multiple tests + sharedRole := testRole + sharedRole.OrganizationUnitID = testOUID + role, err := suite.createRole(sharedRole) + suite.Require().NoError(err, "Failed to create shared role") + sharedRoleID = role.ID +} + +func (suite *RoleAPITestSuite) TearDownSuite() { + // Cleanup in reverse order + if sharedRoleID != "" { + _ = suite.deleteRole(sharedRoleID) + } + if testGroupID != "" { + _ = testutils.DeleteGroup(testGroupID) + } + if testUserID2 != "" { + _ = testutils.DeleteUser(testUserID2) + } + if testUserID1 != "" { + _ = testutils.DeleteUser(testUserID1) + } + if testOUID != "" { + _ = testutils.DeleteOrganizationUnit(testOUID) + } +} + +// Test 1: Create Role +func (suite *RoleAPITestSuite) TestCreateRole_Success() { + roleRequest := CreateRoleRequest{ + Name: "Test Create Role Success", + Description: "Test role created in TestCreateRole_Success", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data", "write:data"}, + } + + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + suite.Require().NotNil(role) + + suite.NotEmpty(role.ID) + suite.Equal(roleRequest.Name, role.Name) + suite.Equal(roleRequest.Description, role.Description) + suite.Equal(roleRequest.OrganizationUnitID, role.OrganizationUnitID) + suite.Equal(len(roleRequest.Permissions), len(role.Permissions)) + + // Cleanup + _ = suite.deleteRole(role.ID) +} + +// Test 2: Create Role with Assignments +func (suite *RoleAPITestSuite) TestCreateRole_WithAssignments() { + roleRequest := CreateRoleRequest{ + Name: "Test Role With Assignments", + Description: "Role with initial assignments", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + suite.Require().NotNil(role) + + suite.Equal(1, len(role.Assignments)) + suite.Equal(testUserID1, role.Assignments[0].ID) + suite.Equal(AssigneeTypeUser, role.Assignments[0].Type) + + // Cleanup + _ = suite.deleteRole(role.ID) +} + +// Test 3: Create Role without Permissions +func (suite *RoleAPITestSuite) TestCreateRole_WithoutPermissions() { + roleRequest := CreateRoleRequest{ + Name: "Test Role Without Permissions", + Description: "Role without permissions", + OrganizationUnitID: testOUID, + Permissions: []string{}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + suite.Require().NotNil(role) + + suite.Equal(1, len(role.Assignments)) + suite.Equal(testUserID1, role.Assignments[0].ID) + suite.Equal(AssigneeTypeUser, role.Assignments[0].Type) + + // Cleanup + _ = suite.deleteRole(role.ID) +} + +// Test 4: Create Role - Validation Errors +func (suite *RoleAPITestSuite) TestCreateRole_ValidationErrors() { + testCases := []struct { + name string + roleRequest CreateRoleRequest + expectedErr string + }{ + { + name: "Missing Name", + roleRequest: CreateRoleRequest{ + OrganizationUnitID: testOUID, + Permissions: []string{"perm1"}, + }, + expectedErr: "ROL-1001", + }, + { + name: "Missing OrganizationUnitID", + roleRequest: CreateRoleRequest{ + Name: "Test Role", + Permissions: []string{"perm1"}, + }, + expectedErr: "ROL-1001", + }, + { + name: "Invalid Organization Unit", + roleRequest: CreateRoleRequest{ + Name: "Test Role", + OrganizationUnitID: "nonexistent-ou", + Permissions: []string{"perm1"}, + }, + expectedErr: "ROL-1005", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + role, err := suite.createRole(tc.roleRequest) + suite.Error(err) + suite.Nil(role) + suite.Contains(err.Error(), tc.expectedErr) + }) + } +} + +// Test 5: Get Role +func (suite *RoleAPITestSuite) TestGetRole_Success() { + suite.Require().NotEmpty(sharedRoleID, "Shared role must be created in SetupSuite") + + role, err := suite.getRole(sharedRoleID) + suite.Require().NoError(err) + suite.Require().NotNil(role) + + suite.Equal(sharedRoleID, role.ID) + suite.Equal(testRole.Name, role.Name) + suite.Equal(testRole.Description, role.Description) +} + +// Test 6: Get Role - Not Found +func (suite *RoleAPITestSuite) TestGetRole_NotFound() { + role, err := suite.getRole("nonexistent-role-id") + suite.Error(err) + suite.Nil(role) + suite.Contains(err.Error(), "ROL-1003") +} + +// Test 7: List Roles +func (suite *RoleAPITestSuite) TestListRoles_Success() { + suite.Require().NotEmpty(sharedRoleID, "Shared role must be created in SetupSuite") + + response, err := suite.listRoles(0, 30) + suite.Require().NoError(err) + suite.Require().NotNil(response) + + suite.GreaterOrEqual(response.TotalResults, 1) + suite.GreaterOrEqual(response.Count, 1) + suite.NotEmpty(response.Roles) + + // Verify our shared role is in the list + found := false + for _, role := range response.Roles { + if role.ID == sharedRoleID { + found = true + suite.Equal(testRole.Name, role.Name) + break + } + } + suite.True(found, "Shared role should be in the list") +} + +// Test 8: List Roles - Pagination +func (suite *RoleAPITestSuite) TestListRoles_Pagination() { + // Create additional roles for pagination testing + role1Request := CreateRoleRequest{ + Name: "Pagination Test Role 1", + OrganizationUnitID: testOUID, + Permissions: []string{"perm1"}, + } + role2Request := CreateRoleRequest{ + Name: "Pagination Test Role 2", + OrganizationUnitID: testOUID, + Permissions: []string{"perm2"}, + } + + role1, err := suite.createRole(role1Request) + suite.Require().NoError(err) + defer suite.deleteRole(role1.ID) + + role2, err := suite.createRole(role2Request) + suite.Require().NoError(err) + defer suite.deleteRole(role2.ID) + + // Test pagination with limit + response, err := suite.listRoles(0, 2) + suite.Require().NoError(err) + suite.LessOrEqual(response.Count, 2) + + // Test with offset + response2, err := suite.listRoles(1, 2) + suite.Require().NoError(err) + suite.NotNil(response2) +} + +// Test 9: Update Role +func (suite *RoleAPITestSuite) TestUpdateRole_Success() { + suite.Require().NotEmpty(sharedRoleID, "Shared role must be created in SetupSuite") + + updateRequest := UpdateRoleRequest{ + Name: "Updated Admin Role", + Description: "Updated description", + OrganizationUnitID: testOUID, + Permissions: []string{"read:users", "write:users", "delete:users", "admin:all"}, + } + + role, err := suite.updateRole(sharedRoleID, updateRequest) + suite.Require().NoError(err) + suite.Require().NotNil(role) + + suite.Equal(sharedRoleID, role.ID) + suite.Equal(updateRequest.Name, role.Name) + suite.Equal(updateRequest.Description, role.Description) + suite.Equal(4, len(role.Permissions)) +} + +// Test 10: Update Role - Not Found +func (suite *RoleAPITestSuite) TestUpdateRole_NotFound() { + updateRequest := UpdateRoleRequest{ + Name: "Updated Role", + OrganizationUnitID: testOUID, + Permissions: []string{"perm1"}, + } + + role, err := suite.updateRole("nonexistent-role-id", updateRequest) + suite.Error(err) + suite.Nil(role) + suite.Contains(err.Error(), "ROL-1003") +} + +// Test 11: Add Assignments - User +func (suite *RoleAPITestSuite) TestAddAssignments_User() { + // Create a role for this test + roleRequest := CreateRoleRequest{ + Name: "Test Role for User Assignment", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + assignmentsRequest := AssignmentsRequest{ + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + + err = suite.addAssignments(role.ID, assignmentsRequest) + suite.Require().NoError(err) + + // Verify assignments were added + assignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + suite.Equal(1, assignments.TotalResults) + suite.Equal(testUserID1, assignments.Assignments[0].ID) + suite.Equal(AssigneeTypeUser, assignments.Assignments[0].Type) +} + +// Test 12: Add Assignments - Group +func (suite *RoleAPITestSuite) TestAddAssignments_Group() { + // Create a role for this test + roleRequest := CreateRoleRequest{ + Name: "Test Role for Group Assignment", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + assignmentsRequest := AssignmentsRequest{ + Assignments: []Assignment{ + {ID: testGroupID, Type: AssigneeTypeGroup}, + }, + } + + err = suite.addAssignments(role.ID, assignmentsRequest) + suite.Require().NoError(err) + + // Verify assignments + assignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + suite.Equal(1, assignments.TotalResults) // Group only + + // Check group assignment exists + groupFound := false + for _, assignment := range assignments.Assignments { + if assignment.ID == testGroupID && assignment.Type == AssigneeTypeGroup { + groupFound = true + break + } + } + suite.True(groupFound, "Group assignment should exist") +} + +// Test 13: Add Assignments - Multiple +func (suite *RoleAPITestSuite) TestAddAssignments_Multiple() { + // Create a new role for this test + roleRequest := CreateRoleRequest{ + Name: "Multi Assignment Role", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + assignmentsRequest := AssignmentsRequest{ + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: testUserID2, Type: AssigneeTypeUser}, + {ID: testGroupID, Type: AssigneeTypeGroup}, + }, + } + + err = suite.addAssignments(role.ID, assignmentsRequest) + suite.Require().NoError(err) + + // Verify all assignments + assignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + suite.Equal(3, assignments.TotalResults) +} + +// Test 14: Add Assignments - Invalid User +func (suite *RoleAPITestSuite) TestAddAssignments_InvalidUser() { + // Create a role for this test + roleRequest := CreateRoleRequest{ + Name: "Test Role for Invalid Assignment", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + assignmentsRequest := AssignmentsRequest{ + Assignments: []Assignment{ + {ID: "nonexistent-user-id", Type: AssigneeTypeUser}, + }, + } + + err = suite.addAssignments(role.ID, assignmentsRequest) + suite.Error(err) + suite.Contains(err.Error(), "ROL-1007") +} + +// Test 15: Get Role Assignments +func (suite *RoleAPITestSuite) TestGetRoleAssignments_Success() { + // Create a role with an assignment for this test + roleRequest := CreateRoleRequest{ + Name: "Test Role for Get Assignments", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + assignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + suite.Require().NotNil(assignments) + suite.GreaterOrEqual(assignments.TotalResults, 0) +} + +// Test 16: Get Role Assignments - Pagination +func (suite *RoleAPITestSuite) TestGetRoleAssignments_Pagination() { + // Create a role with multiple assignments for pagination testing + roleRequest := CreateRoleRequest{ + Name: "Test Role for Pagination", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: testUserID2, Type: AssigneeTypeUser}, + }, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + // Test with small page size + assignments, err := suite.getRoleAssignments(role.ID, 0, 1) + suite.Require().NoError(err) + suite.LessOrEqual(assignments.Count, 1) + + // Test with offset + if assignments.TotalResults > 1 { + assignments2, err := suite.getRoleAssignments(role.ID, 1, 1) + suite.Require().NoError(err) + suite.NotNil(assignments2) + } +} + +// Test 17: Remove Assignments +func (suite *RoleAPITestSuite) TestRemoveAssignments_Success() { + // Create a role with assignments for this test + roleRequest := CreateRoleRequest{ + Name: "Test Role for Remove Assignments", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: testUserID2, Type: AssigneeTypeUser}, + }, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + // Get current assignments + beforeAssignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + initialCount := beforeAssignments.TotalResults + + suite.Require().Greater(initialCount, 0, "Should have assignments to remove") + + // Remove first assignment + assignmentToRemove := beforeAssignments.Assignments[0] + removeRequest := AssignmentsRequest{ + Assignments: []Assignment{assignmentToRemove}, + } + + err = suite.removeAssignments(role.ID, removeRequest) + suite.Require().NoError(err) + + // Verify assignment was removed + afterAssignments, err := suite.getRoleAssignments(role.ID, 0, 30) + suite.Require().NoError(err) + suite.Equal(initialCount-1, afterAssignments.TotalResults) +} + +// Test 18: Delete Role with Assignments +func (suite *RoleAPITestSuite) TestDeleteRole_WithAssignments() { + // Create a role with assignments + roleRequest := CreateRoleRequest{ + Name: "Role to Delete with Assignments", + OrganizationUnitID: testOUID, + Permissions: []string{"perm1"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + + // Try to delete - should fail because it has assignments + err = suite.deleteRole(role.ID) + suite.Require().Error(err, "Delete should fail when role has assignments") + suite.Contains(err.Error(), "ROL-1006", "Should return cannot delete role error") + + // Remove assignments first + removeRequest := AssignmentsRequest{ + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + }, + } + err = suite.removeAssignments(role.ID, removeRequest) + suite.Require().NoError(err) + + // Now delete should succeed + err = suite.deleteRole(role.ID) + suite.NoError(err) +} + +// Test 19: Delete Role - Success +func (suite *RoleAPITestSuite) TestDeleteRole_Success() { + // Create a role without assignments + roleRequest := CreateRoleRequest{ + Name: "Role to Delete", + OrganizationUnitID: testOUID, + Permissions: []string{"perm1"}, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + + // Delete the role + err = suite.deleteRole(role.ID) + suite.NoError(err) + + // Verify role is deleted + deletedRole, err := suite.getRole(role.ID) + suite.Error(err) + suite.Nil(deletedRole) + suite.Contains(err.Error(), "ROL-1003") +} + +// Test 20: Delete Role - Not Found (Should return success for idempotency) +func (suite *RoleAPITestSuite) TestDeleteRole_NotFound() { + err := suite.deleteRole("nonexistent-role-id") + // As per service implementation, delete returns nil for non-existent roles + suite.NoError(err) +} + +// Test 21: Get Role Assignments with Display Names +func (suite *RoleAPITestSuite) TestGetRoleAssignments_WithDisplay() { + // Create a role with both user and group assignments + roleRequest := CreateRoleRequest{ + Name: "Test Role for Display Names", + OrganizationUnitID: testOUID, + Permissions: []string{"read:data"}, + Assignments: []Assignment{ + {ID: testUserID1, Type: AssigneeTypeUser}, + {ID: testGroupID, Type: AssigneeTypeGroup}, + }, + } + role, err := suite.createRole(roleRequest) + suite.Require().NoError(err) + defer suite.deleteRole(role.ID) + + // Get assignments without display parameter + assignmentsWithoutDisplay, err := suite.getRoleAssignmentsWithInclude(role.ID, 0, 30, "") + suite.Require().NoError(err) + suite.Require().NotNil(assignmentsWithoutDisplay) + suite.Equal(2, assignmentsWithoutDisplay.TotalResults) + + // Verify display names are not included + for _, assignment := range assignmentsWithoutDisplay.Assignments { + suite.Empty(assignment.Display, "Display field should be empty without include=display parameter") + } + + // Get assignments with include=display parameter + assignmentsWithDisplay, err := suite.getRoleAssignmentsWithInclude(role.ID, 0, 30, "display") + suite.Require().NoError(err) + suite.Require().NotNil(assignmentsWithDisplay) + suite.Equal(2, assignmentsWithDisplay.TotalResults) + + // Verify display names are included + userFound := false + groupFound := false + for _, assignment := range assignmentsWithDisplay.Assignments { + suite.NotEmpty(assignment.Display, "Display field should be populated with include=display parameter") + + if assignment.Type == AssigneeTypeUser && assignment.ID == testUserID1 { + userFound = true + // Display name for user should be the user ID (as per implementation) + suite.Equal(testUserID1, assignment.Display) + } + + if assignment.Type == AssigneeTypeGroup && assignment.ID == testGroupID { + groupFound = true + // Display name for group should be the group name + suite.Equal(testGroup.Name, assignment.Display) + } + } + + suite.True(userFound, "User assignment should be found") + suite.True(groupFound, "Group assignment should be found") +} + +// Helper methods + +func (suite *RoleAPITestSuite) createRole(request CreateRoleRequest) (*Role, error) { + body, err := json.Marshal(request) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", testServerURL+rolesBasePath, bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := suite.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusCreated { + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return nil, fmt.Errorf("failed to create role: %s - %s", errResp.Code, errResp.Message) + } + + var role Role + if err := json.Unmarshal(respBody, &role); err != nil { + return nil, err + } + + return &role, nil +} + +func (suite *RoleAPITestSuite) getRole(roleID string) (*Role, error) { + req, err := http.NewRequest("GET", fmt.Sprintf("%s%s/%s", testServerURL, rolesBasePath, roleID), nil) + if err != nil { + return nil, err + } + + resp, err := suite.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return nil, fmt.Errorf("failed to get role: %s - %s", errResp.Code, errResp.Message) + } + + var role Role + if err := json.Unmarshal(respBody, &role); err != nil { + return nil, err + } + + return &role, nil +} + +func (suite *RoleAPITestSuite) listRoles(offset, limit int) (*RoleListResponse, error) { + url := fmt.Sprintf("%s%s?offset=%d&limit=%d", testServerURL, rolesBasePath, offset, limit) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + resp, err := suite.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return nil, fmt.Errorf("failed to list roles: %s - %s", errResp.Code, errResp.Message) + } + + var response RoleListResponse + if err := json.Unmarshal(respBody, &response); err != nil { + return nil, err + } + + return &response, nil +} + +func (suite *RoleAPITestSuite) updateRole(roleID string, request UpdateRoleRequest) (*Role, error) { + body, err := json.Marshal(request) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", fmt.Sprintf("%s%s/%s", testServerURL, rolesBasePath, roleID), + bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := suite.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return nil, fmt.Errorf("failed to update role: %s - %s", errResp.Code, errResp.Message) + } + + var role Role + if err := json.Unmarshal(respBody, &role); err != nil { + return nil, err + } + + return &role, nil +} + +func (suite *RoleAPITestSuite) deleteRole(roleID string) error { + req, err := http.NewRequest("DELETE", fmt.Sprintf("%s%s/%s", testServerURL, rolesBasePath, roleID), nil) + if err != nil { + return err + } + + resp, err := suite.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + respBody, _ := io.ReadAll(resp.Body) + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return fmt.Errorf("failed to delete role: %s - %s", errResp.Code, errResp.Message) + } + + return nil +} + +func (suite *RoleAPITestSuite) addAssignments(roleID string, request AssignmentsRequest) error { + body, err := json.Marshal(request) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s%s/%s/assignments/add", testServerURL, rolesBasePath, roleID), + bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := suite.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + respBody, _ := io.ReadAll(resp.Body) + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return fmt.Errorf("failed to add assignments: %s - %s", errResp.Code, errResp.Message) + } + + return nil +} + +func (suite *RoleAPITestSuite) removeAssignments(roleID string, request AssignmentsRequest) error { + body, err := json.Marshal(request) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s%s/%s/assignments/remove", testServerURL, rolesBasePath, roleID), + bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := suite.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + respBody, _ := io.ReadAll(resp.Body) + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return fmt.Errorf("failed to remove assignments: %s - %s", errResp.Code, errResp.Message) + } + + return nil +} + +func (suite *RoleAPITestSuite) getRoleAssignments(roleID string, offset, limit int) (*AssignmentListResponse, error) { + return suite.getRoleAssignmentsWithInclude(roleID, offset, limit, "") +} + +func (suite *RoleAPITestSuite) getRoleAssignmentsWithInclude(roleID string, offset, limit int, + include string) (*AssignmentListResponse, error) { + url := fmt.Sprintf("%s%s/%s/assignments?offset=%d&limit=%d", testServerURL, rolesBasePath, roleID, offset, limit) + if include != "" { + url = fmt.Sprintf("%s&include=%s", url, include) + } + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + resp, err := suite.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + json.Unmarshal(respBody, &errResp) + return nil, fmt.Errorf("failed to get role assignments: %s - %s", errResp.Code, errResp.Message) + } + + var response AssignmentListResponse + if err := json.Unmarshal(respBody, &response); err != nil { + return nil, err + } + + return &response, nil +} diff --git a/tests/integration/testutils/apiutils.go b/tests/integration/testutils/apiutils.go index 309c22b1e..3e95e00aa 100644 --- a/tests/integration/testutils/apiutils.go +++ b/tests/integration/testutils/apiutils.go @@ -494,3 +494,61 @@ func FindUserByAttribute(key, value string) (*User, error) { } return nil, nil } + +// CreateGroup creates a group via API and returns the group ID +func CreateGroup(group Group) (string, error) { + groupJSON, err := json.Marshal(group) + if err != nil { + return "", fmt.Errorf("failed to marshal group: %w", err) + } + + req, err := http.NewRequest("POST", TestServerURL+"/groups", bytes.NewReader(groupJSON)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := getHTTPClient() + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("expected status 201, got %d. Response: %s", resp.StatusCode, string(bodyBytes)) + } + + var createdGroup map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&createdGroup) + if err != nil { + return "", fmt.Errorf("failed to parse response body: %w", err) + } + + groupID, ok := createdGroup["id"].(string) + if !ok { + return "", fmt.Errorf("response does not contain id") + } + return groupID, nil +} + +// DeleteGroup deletes a group by ID +func DeleteGroup(groupID string) error { + req, err := http.NewRequest("DELETE", TestServerURL+"/groups/"+groupID, nil) + if err != nil { + return fmt.Errorf("failed to create delete request: %w", err) + } + + client := getHTTPClient() + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + return fmt.Errorf("expected status 204 or 200, got %d", resp.StatusCode) + } + return nil +} diff --git a/tests/integration/testutils/models.go b/tests/integration/testutils/models.go index db2cbe035..1ddba3111 100644 --- a/tests/integration/testutils/models.go +++ b/tests/integration/testutils/models.go @@ -106,3 +106,11 @@ type AuthenticationResponse struct { OrganizationUnit string `json:"organization_unit"` Assertion string `json:"assertion,omitempty"` } + +// Group represents a group in the system +type Group struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + OrganizationUnitId string `json:"organizationUnitId,omitempty"` +}