diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d7564c3535e..095a108ff99 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -29,7 +31,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -97,7 +98,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp t.Cleanup(ctrl.Finish) permissionsManagerMock := permissions.NewMockManager(ctrl) - peersmanager := peers.NewManager(store, permissionsManagerMock) + peersmanager := peers.NewManager(store) settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersmanager) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9fa4e51b26c..b13bfa7b4cf 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -25,6 +25,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/management-integrations/integrations" @@ -57,7 +58,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) jobManager := job.NewJobManager(nil, store, peersManager) cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) diff --git a/client/server/server_test.go b/client/server/server_test.go index 77299757566..6627ba875e7 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" @@ -38,7 +40,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -305,7 +306,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve t.Cleanup(ctrl.Finish) permissionsManagerMock := permissions.NewMockManager(ctrl) - peersManager := peers.NewManager(store, permissionsManagerMock) + peersManager := peers.NewManager(store) settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManager) diff --git a/go.mod b/go.mod index 1b5861a378e..3691b5d1492 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 diff --git a/go.sum b/go.sum index 3772946e1c4..4f7c7b803b6 100644 --- a/go.sum +++ b/go.sum @@ -453,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index d3f8f44ff6d..23e5330b9b9 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -16,9 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) @@ -38,17 +35,15 @@ type Manager interface { type managerImpl struct { store store.Store - permissionsManager permissions.Manager integratedPeerValidator integrated_validator.IntegratedValidator accountManager account.Manager networkMapController network_map.Controller } -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, + store: store, } } @@ -65,28 +60,10 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) { } func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) } func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - } - return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") } diff --git a/management/server/permissions/manager.go b/management/internals/modules/permissions/manager.go similarity index 52% rename from management/server/permissions/manager.go rename to management/internals/modules/permissions/manager.go index e6bdd20259d..99558da0255 100644 --- a/management/server/permissions/manager.go +++ b/management/internals/modules/permissions/manager.go @@ -4,20 +4,29 @@ package permissions import ( "context" + "net/http" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/roles" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/permissions/roles" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) +// AuthErrorHandler is called when an auth error occurs during permission validation. +// If it returns true, the error is considered handled and the default error response is skipped. +type AuthErrorHandler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool + type Manager interface { + WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error @@ -36,6 +45,51 @@ func NewManager(store store.Store) Manager { } } +// WithPermission wraps an HTTP handler with permission checking logic. +// An optional AuthErrorHandler can be provided to intercept auth errors before the default response is written. +func (m *managerImpl) WithPermission( + module modules.Module, + operation operations.Operation, + handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), + authErrHandler ...AuthErrorHandler, +) http.HandlerFunc { + var onAuthErr AuthErrorHandler + if len(authErrHandler) > 0 { + onAuthErr = authErrHandler[0] + } + + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err) + util.WriteError(r.Context(), err, w) + return + } + + allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation) + if err != nil { + if onAuthErr != nil && onAuthErr(w, r, &userAuth, err) { + return + } + log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err) + util.WriteError(r.Context(), status.NewPermissionValidationError(err), w) + return + } + + if !allowed { + permErr := status.NewPermissionDeniedError() + if onAuthErr != nil && onAuthErr(w, r, &userAuth, permErr) { + return + } + log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module) + util.WriteError(r.Context(), permErr, w) + return + } + + handlerFunc(w, r, &userAuth) + } +} + func (m *managerImpl) ValidateUserPermissions( ctx context.Context, accountID string, @@ -68,10 +122,6 @@ func (m *managerImpl) ValidateUserPermissions( return false, err } - if operation == operations.Read && user.IsServiceUser { - return true, nil // this should be replaced by proper granular access role - } - role, ok := roles.RolesMap[user.Role] if !ok { return false, status.NewUserRoleNotFoundError(string(user.Role)) @@ -127,3 +177,17 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR func (m *managerImpl) SetAccountManager(accountManager account.Manager) { // no-op } + +// WrapHandler wraps a handler that expects UserAuth with context extraction. +// Unlike WithPermission, it does not perform any permission checks. +func WrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err) + util.WriteError(r.Context(), err, w) + return + } + h(w, r, &userAuth) + } +} diff --git a/management/server/permissions/manager_mock.go b/management/internals/modules/permissions/manager_mock.go similarity index 75% rename from management/server/permissions/manager_mock.go rename to management/internals/modules/permissions/manager_mock.go index ec9f263f965..64f0a508ff5 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/internals/modules/permissions/manager_mock.go @@ -5,15 +5,18 @@ package permissions import ( - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - account "github.com/netbirdio/netbird/management/server/account" - modules "github.com/netbirdio/netbird/management/server/permissions/modules" - operations "github.com/netbirdio/netbird/management/server/permissions/operations" - roles "github.com/netbirdio/netbird/management/server/permissions/roles" - types "github.com/netbirdio/netbird/management/server/types" + "context" + "net/http" + "reflect" + + "github.com/golang/mock/gomock" + + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/roles" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" ) // MockManager is a mock of Manager interface. @@ -108,3 +111,22 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation) } + +// WithPermission mocks base method. +func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc { + m.ctrl.T.Helper() + varargs := []interface{}{module, operation, handlerFunc} + for _, a := range authErrHandler { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WithPermission", varargs...) + ret0, _ := ret[0].(http.HandlerFunc) + return ret0 +} + +// WithPermission indicates an expected call of WithPermission. +func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}, authErrHandler ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{module, operation, handlerFunc}, authErrHandler...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), varargs...) +} diff --git a/management/server/permissions/modules/module.go b/management/internals/modules/permissions/modules/module.go similarity index 100% rename from management/server/permissions/modules/module.go rename to management/internals/modules/permissions/modules/module.go diff --git a/management/server/permissions/operations/operation.go b/management/internals/modules/permissions/operations/operation.go similarity index 100% rename from management/server/permissions/operations/operation.go rename to management/internals/modules/permissions/operations/operation.go diff --git a/management/server/permissions/roles/admin.go b/management/internals/modules/permissions/roles/admin.go similarity index 69% rename from management/server/permissions/roles/admin.go rename to management/internals/modules/permissions/roles/admin.go index af3a81297c6..5e3e5cc9569 100644 --- a/management/server/permissions/roles/admin.go +++ b/management/internals/modules/permissions/roles/admin.go @@ -1,8 +1,8 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) @@ -18,7 +18,7 @@ var Admin = RolePermissions{ modules.Accounts: { operations.Read: true, operations.Create: false, - operations.Update: false, + operations.Update: true, operations.Delete: false, }, }, diff --git a/management/server/permissions/roles/auditor.go b/management/internals/modules/permissions/roles/auditor.go similarity index 78% rename from management/server/permissions/roles/auditor.go rename to management/internals/modules/permissions/roles/auditor.go index 33d8651f4ef..7d762e2fa22 100644 --- a/management/server/permissions/roles/auditor.go +++ b/management/internals/modules/permissions/roles/auditor.go @@ -1,7 +1,7 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/management/server/permissions/roles/network_admin.go b/management/internals/modules/permissions/roles/network_admin.go similarity index 93% rename from management/server/permissions/roles/network_admin.go rename to management/internals/modules/permissions/roles/network_admin.go index 8f69d46ad68..23afb618167 100644 --- a/management/server/permissions/roles/network_admin.go +++ b/management/internals/modules/permissions/roles/network_admin.go @@ -1,8 +1,8 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/management/server/permissions/roles/owner.go b/management/internals/modules/permissions/roles/owner.go similarity index 78% rename from management/server/permissions/roles/owner.go rename to management/internals/modules/permissions/roles/owner.go index 668470e47e5..d1eb4c70c2b 100644 --- a/management/server/permissions/roles/owner.go +++ b/management/internals/modules/permissions/roles/owner.go @@ -1,7 +1,7 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/management/server/permissions/roles/role_permissions.go b/management/internals/modules/permissions/roles/role_permissions.go similarity index 76% rename from management/server/permissions/roles/role_permissions.go rename to management/internals/modules/permissions/roles/role_permissions.go index 754e568f579..296acf04a56 100644 --- a/management/server/permissions/roles/role_permissions.go +++ b/management/internals/modules/permissions/roles/role_permissions.go @@ -1,8 +1,8 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/management/server/permissions/roles/user.go b/management/internals/modules/permissions/roles/user.go similarity index 78% rename from management/server/permissions/roles/user.go rename to management/internals/modules/permissions/roles/user.go index bb3df0aeab5..1df61ff6e1b 100644 --- a/management/server/permissions/roles/user.go +++ b/management/internals/modules/permissions/roles/user.go @@ -1,7 +1,7 @@ package roles import ( - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/api.go b/management/internals/modules/reverseproxy/accesslogs/manager/api.go index 1e1414ca505..4710fe68077 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/api.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/api.go @@ -5,8 +5,11 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -15,21 +18,15 @@ type handler struct { manager accesslogs.Manager } -func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) { +func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS") + router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS") } -func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var filter accesslogs.AccessLogFilter filter.ParseFromRequest(r) diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index 59d7704ebdd..8c1af1235cd 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -9,25 +9,19 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" ) type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - geo geolocation.Geolocation - cleanupCancel context.CancelFunc + store store.Store + geo geolocation.Geolocation + cleanupCancel context.CancelFunc } -func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager { +func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - geo: geo, + store: store, + geo: geo, } } @@ -63,14 +57,6 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac // GetAllAccessLogs retrieves access logs for an account with pagination and filtering func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, 0, status.NewPermissionValidationError(err) - } - if !ok { - return nil, 0, status.NewPermissionDeniedError() - } - if err := m.resolveUserFilters(ctx, accountID, filter); err != nil { log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err) } diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 4493ef0ad66..bde36aaa9da 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -6,8 +6,11 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,15 +20,15 @@ type handler struct { manager Manager } -func RegisterEndpoints(router *mux.Router, manager Manager) { +func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS") - router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS") - router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS") - router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS") + router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS") + router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS") + router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST } func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { @@ -56,13 +59,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain { return resp } -func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -77,13 +74,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, ret) } -func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiReverseProxiesDomainsJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -99,13 +90,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, domainToApi(domain)) } -func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domainID := mux.Vars(r)["domainId"] if domainID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) @@ -120,13 +105,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } -func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domainID := mux.Vars(r)["domainId"] if domainID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 2c4c1372e21..6bf04ec35ca 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -11,11 +11,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/status" ) type store interface { @@ -37,32 +33,22 @@ type proxyManager interface { } type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - permissionsManager permissions.Manager - accountManager account.Manager + store store + validator domain.Validator + proxyManager proxyManager + accountManager account.Manager } -func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager { +func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager { return Manager{ - store: store, - proxyManager: proxyMgr, - validator: domain.Validator{Resolver: net.DefaultResolver}, - permissionsManager: permissionsManager, - accountManager: accountManager, + store: store, + proxyManager: proxyMgr, + validator: domain.Validator{Resolver: net.DefaultResolver}, + accountManager: accountManager, } } func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - domains, err := m.store.ListCustomDomains(ctx, accountID) if err != nil { return nil, fmt.Errorf("list custom domains: %w", err) @@ -118,14 +104,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d } func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - // Verify the target cluster is in the available clusters allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) if err != nil { @@ -159,14 +137,6 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName } func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - d, err := m.store.GetCustomDomain(ctx, accountID, domainID) if err != nil { return fmt.Errorf("get domain from store: %w", err) @@ -183,21 +153,6 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s } func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) - if err != nil { - log.WithFields(log.Fields{ - "accountID": accountID, - "domainID": domainID, - }).WithError(err).Error("validate domain") - return - } - if !ok { - log.WithFields(log.Fields{ - "accountID": accountID, - "domainID": domainID, - }).WithError(err).Error("validate domain") - } - log.WithFields(log.Fields{ "accountID": accountID, "domainID": domainID, diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 339c82446ca..5d04d8e47b4 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -23,7 +23,7 @@ type Proxy struct { LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time - Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index cd81efa88dd..4b02cfb1edb 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -6,12 +6,14 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" - nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -30,25 +32,19 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() - domainmanager.RegisterEndpoints(domainRouter, domainManager) + domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager) - accesslogsmanager.RegisterEndpoints(router, accessLogsManager) + accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager) - router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") - router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") - router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS") + router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -63,13 +59,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiServices) } -func (h *handler) createService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.ServiceRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -77,12 +67,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { } service := new(rpservice.Service) + var err error if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } - if err = service.Validate(); err != nil { + if err := service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -96,13 +87,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse()) } -func (h *handler) getService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) @@ -118,13 +103,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, service.ToAPIResponse()) } -func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) @@ -139,12 +118,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { service := new(rpservice.Service) service.ID = serviceID + var err error if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } - if err = service.Validate(); err != nil { + if err := service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -158,13 +138,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse()) } -func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) @@ -179,13 +153,7 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 28461641dda..c9976bf0e45 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" ) @@ -86,18 +85,17 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } mgr := &Manager{ - store: testStore, - accountManager: accountMgr, - permissionsManager: permissions.NewManager(testStore), - proxyController: mockCtrl, - capabilities: mockCaps, - clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, + store: testStore, + accountManager: accountMgr, + proxyController: mockCtrl, + capabilities: mockCaps, + clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, } mgr.exposeReaper = &exposeReaper{manager: mgr} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ed9d4201be2..78cd927583d 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -21,9 +21,6 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) @@ -82,24 +79,22 @@ type CapabilityProvider interface { } type Manager struct { - store store.Store - accountManager account.Manager - permissionsManager permissions.Manager - proxyController proxy.Controller - capabilities CapabilityProvider - clusterDeriver ClusterDeriver - exposeReaper *exposeReaper + store store.Store + accountManager account.Manager + proxyController proxy.Controller + capabilities CapabilityProvider + clusterDeriver ClusterDeriver + exposeReaper *exposeReaper } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { +func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { mgr := &Manager{ - store: store, - accountManager: accountManager, - permissionsManager: permissionsManager, - proxyController: proxyController, - capabilities: capabilities, - clusterDeriver: clusterDeriver, + store: store, + accountManager: accountManager, + proxyController: proxyController, + capabilities: capabilities, + clusterDeriver: clusterDeriver, } mgr.exposeReaper = &exposeReaper{manager: mgr} return mgr @@ -112,26 +107,10 @@ func (m *Manager) StartExposeReaper(ctx context.Context) { // GetActiveClusters returns all active proxy clusters with their connected proxy count. func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetActiveProxyClusters(ctx) } func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -185,14 +164,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s * } func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return nil, fmt.Errorf("failed to get service: %w", err) @@ -206,14 +177,6 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s } func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { return nil, err } @@ -224,7 +187,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) - err = m.replaceHostByLookup(ctx, accountID, s) + err := m.replaceHostByLookup(ctx, accountID, s) if err != nil { return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } @@ -491,14 +454,6 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St } func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - if err := service.Auth.HashSecrets(); err != nil { return nil, fmt.Errorf("hash secrets: %w", err) } @@ -785,16 +740,8 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes. } func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - var s *service.Service - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -825,16 +772,8 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI } func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - var services []*service.Service - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID) if err != nil { @@ -1119,7 +1058,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr } groupIDs := make([]string, 0, len(groupNames)) for _, groupName := range groupNames { - g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator) + g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID) if err != nil { return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 54ac8ab182e..2147574d7c6 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -23,9 +23,6 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" @@ -700,12 +697,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) require.NoError(t, err) - permsMgr := permissions.NewManager(testStore) - accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, - GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, } @@ -718,10 +713,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { require.NoError(t, err) mgr := &Manager{ - store: testStore, - accountManager: accountMgr, - permissionsManager: permsMgr, - proxyController: proxyController, + store: testStore, + accountManager: accountMgr, + proxyController: proxyController, clusterDeriver: &testClusterDeriver{ domains: []string{"test.netbird.io"}, }, @@ -1130,7 +1124,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockPerms := permissions.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl) tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) @@ -1141,10 +1134,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) mgr := &Manager{ - store: sqlStore, - permissionsManager: mockPerms, - accountManager: mockAcct, - proxyController: proxyController, + store: sqlStore, + accountManager: mockAcct, + proxyController: proxyController, } service := &rpservice.Service{ @@ -1167,9 +1159,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion") - mockPerms.EXPECT(). - ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete). - Return(true, nil) mockAcct.EXPECT(). StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) mockAcct.EXPECT(). diff --git a/management/internals/modules/zones/manager/api.go b/management/internals/modules/zones/manager/api.go index 919d77d61d4..0fddb45a940 100644 --- a/management/internals/modules/zones/manager/api.go +++ b/management/internals/modules/zones/manager/api.go @@ -6,8 +6,11 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/zones" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,25 +20,19 @@ type handler struct { manager zones.Manager } -func RegisterEndpoints(router *mux.Router, manager zones.Manager) { +func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiZones) } -func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiDnsZonesJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { zone := new(zones.Zone) zone.FromAPIRequest(&req) - if err = zone.Validate(); err != nil { + if err := zone.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse()) } -func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse()) } -func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { } var req api.PutApiDnsZonesZoneIdJSONRequestBody - if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } @@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { zone.FromAPIRequest(&req) zone.ID = zoneID - if err = zone.Validate(); err != nil { + if err := zone.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse()) } -func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) return } - if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { + if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/internals/modules/zones/manager/manager.go b/management/internals/modules/zones/manager/manager.go index 8548dd48cee..6b8f5fbb9b7 100644 --- a/management/internals/modules/zones/manager/manager.go +++ b/management/internals/modules/zones/manager/manager.go @@ -7,62 +7,34 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) type managerImpl struct { - store store.Store - accountManager account.Manager - permissionsManager permissions.Manager - dnsDomain string + store store.Store + accountManager account.Manager + dnsDomain string } -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager { +func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager { return &managerImpl{ - store: store, - accountManager: accountManager, - permissionsManager: permissionsManager, - dnsDomain: dnsDomain, + store: store, + accountManager: accountManager, + dnsDomain: dnsDomain, } } func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) } func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID) } func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - + var err error if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil { return nil, err } @@ -102,14 +74,6 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, } func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID) if err != nil { return nil, fmt.Errorf("failed to get zone: %w", err) @@ -150,14 +114,6 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, } func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) if err != nil { return fmt.Errorf("failed to get zone: %w", err) diff --git a/management/internals/modules/zones/manager/manager_test.go b/management/internals/modules/zones/manager/manager_test.go index b45ec787417..df39f8a8d10 100644 --- a/management/internals/modules/zones/manager/manager_test.go +++ b/management/internals/modules/zones/manager/manager_test.go @@ -13,9 +13,6 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" @@ -29,7 +26,7 @@ const ( testDNSDomain = "netbird.selfhosted" ) -func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { +func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) { t.Helper() ctx := context.Background() @@ -49,23 +46,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun ctrl := gomock.NewController(t) mockAccountManager := &mock_server.MockAccountManager{} - mockPermissionsManager := permissions.NewMockManager(ctrl) - manager := &managerImpl{ - store: testStore, - accountManager: mockAccountManager, - permissionsManager: mockPermissionsManager, - dnsDomain: testDNSDomain, - } + manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl) - return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup + return manager, testStore, mockAccountManager, ctrl, cleanup } func TestManagerImpl_GetAllZones(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -77,10 +68,6 @@ func TestManagerImpl_GetAllZones(t *testing.T) { err = testStore.CreateZone(ctx, zone2) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(true, nil) - result, err := manager.GetAllZones(ctx, testAccountID, testUserID) require.NoError(t, err) assert.Len(t, result, 2) @@ -88,43 +75,13 @@ func TestManagerImpl_GetAllZones(t *testing.T) { assert.Equal(t, zone2.ID, result[1].ID) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, nil) - - result, err := manager.GetAllZones(ctx, testAccountID, testUserID) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - - t.Run("permission validation error", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, status.Errorf(status.Internal, "permission check failed")) - - result, err := manager.GetAllZones(ctx, testAccountID, testUserID) - require.Error(t, err) - assert.Nil(t, result) - }) } func TestManagerImpl_GetZone(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -132,10 +89,6 @@ func TestManagerImpl_GetZone(t *testing.T) { err := testStore.CreateZone(ctx, zone) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(true, nil) - result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID) require.NoError(t, err) assert.Equal(t, zone.ID, result.ID) @@ -143,29 +96,13 @@ func TestManagerImpl_GetZone(t *testing.T) { assert.Equal(t, zone.Domain, result.Domain) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, nil) - - result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) } func TestManagerImpl_CreateZone(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -177,10 +114,6 @@ func TestManagerImpl_CreateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testAccountID, accountID) @@ -199,31 +132,8 @@ func TestManagerImpl_CreateZone(t *testing.T) { assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - inputZone := &zones.Zone{ - Name: "New Zone", - Domain: "new.example.com", - DistributionGroups: []string{testGroupID}, - } - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(false, nil) - - result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("invalid group", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -233,17 +143,13 @@ func TestManagerImpl_CreateZone(t *testing.T) { DistributionGroups: []string{"invalid-group"}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) require.Error(t, err) assert.Nil(t, result) }) t.Run("duplicate domain", func(t *testing.T) { - manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -259,10 +165,6 @@ func TestManagerImpl_CreateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) require.Error(t, err) assert.Nil(t, result) @@ -273,7 +175,7 @@ func TestManagerImpl_CreateZone(t *testing.T) { }) t.Run("peer DNS domain conflict", func(t *testing.T) { - manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -291,10 +193,6 @@ func TestManagerImpl_CreateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) require.Error(t, err) assert.Nil(t, result) @@ -305,7 +203,7 @@ func TestManagerImpl_CreateZone(t *testing.T) { }) t.Run("default DNS domain conflict", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -317,10 +215,6 @@ func TestManagerImpl_CreateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) require.Error(t, err) assert.Nil(t, result) @@ -335,7 +229,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -352,10 +246,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - storeEventCalled := false mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { storeEventCalled = true @@ -375,7 +265,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) { }) t.Run("domain change not allowed", func(t *testing.T) { - manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -392,10 +282,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) { DistributionGroups: []string{testGroupID}, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) require.Error(t, err) assert.Nil(t, result) @@ -405,31 +291,8 @@ func TestManagerImpl_UpdateZone(t *testing.T) { assert.Equal(t, status.InvalidArgument, s.Type()) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - updatedZone := &zones.Zone{ - ID: testZoneID, - Name: "Updated Name", - Domain: "example.com", - } - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(false, nil) - - result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("zone not found", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -439,10 +302,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) { Domain: "example.com", } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) require.Error(t, err) assert.Nil(t, result) @@ -453,7 +312,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) { ctx := context.Background() t.Run("success with records", func(t *testing.T) { - manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -469,10 +328,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) { err = testStore.CreateDNSRecord(ctx, record2) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(true, nil) - storeEventCallCount := 0 mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { storeEventCallCount++ @@ -493,7 +348,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) { }) t.Run("success without records", func(t *testing.T) { - manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -501,10 +356,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) { err := testStore.CreateZone(ctx, zone) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(true, nil) - storeEventCalled := false mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { storeEventCalled = true @@ -522,31 +373,11 @@ func TestManagerImpl_DeleteZone(t *testing.T) { require.Error(t, err) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(false, nil) - - err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID) - require.Error(t, err) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("zone not found", func(t *testing.T) { - manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(true, nil) - err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone") require.Error(t, err) }) diff --git a/management/internals/modules/zones/records/manager/api.go b/management/internals/modules/zones/records/manager/api.go index f8ecfef7d55..aa746799bb6 100644 --- a/management/internals/modules/zones/records/manager/api.go +++ b/management/internals/modules/zones/records/manager/api.go @@ -6,8 +6,11 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/internals/modules/zones/records" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,25 +20,19 @@ type handler struct { manager records.Manager } -func RegisterEndpoints(router *mux.Router, manager records.Manager) { +func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRecords) } -func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { record := new(records.Record) record.FromAPIRequest(&req) - if err = record.Validate(); err != nil { + if err := record.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse()) } -func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, record.ToAPIResponse()) } -func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { } var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody - if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } @@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { record.FromAPIRequest(&req) record.ID = recordID - if err = record.Validate(); err != nil { + if err := record.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse()) } -func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { return } - if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { + if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/internals/modules/zones/records/manager/manager.go b/management/internals/modules/zones/records/manager/manager.go index 5374a2ef2a3..dad8a8a1499 100644 --- a/management/internals/modules/zones/records/manager/manager.go +++ b/management/internals/modules/zones/records/manager/manager.go @@ -9,64 +9,36 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) type managerImpl struct { - store store.Store - accountManager account.Manager - permissionsManager permissions.Manager + store store.Store + accountManager account.Manager } -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager { +func NewManager(store store.Store, accountManager account.Manager) records.Manager { return &managerImpl{ - store: store, - accountManager: accountManager, - permissionsManager: permissionsManager, + store: store, + accountManager: accountManager, } } func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID) } func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID) } func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - var zone *zones.Zone record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL) - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) if err != nil { return fmt.Errorf("failed to get zone: %w", err) @@ -101,18 +73,11 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI } func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - var zone *zones.Zone var record *records.Record - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) if err != nil { return fmt.Errorf("failed to get zone: %w", err) @@ -160,18 +125,11 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI } func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - var record *records.Record var zone *zones.Zone - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) if err != nil { return fmt.Errorf("failed to get zone: %w", err) diff --git a/management/internals/modules/zones/records/manager/manager_test.go b/management/internals/modules/zones/records/manager/manager_test.go index 0a962e0f4ac..db040c13c47 100644 --- a/management/internals/modules/zones/records/manager/manager_test.go +++ b/management/internals/modules/zones/records/manager/manager_test.go @@ -12,12 +12,8 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -27,7 +23,7 @@ const ( testGroupID = "test-group-id" ) -func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { +func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) { t.Helper() ctx := context.Background() @@ -51,22 +47,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv ctrl := gomock.NewController(t) mockAccountManager := &mock_server.MockAccountManager{} - mockPermissionsManager := permissions.NewMockManager(ctrl) - manager := &managerImpl{ - store: testStore, - accountManager: mockAccountManager, - permissionsManager: mockPermissionsManager, - } + manager := NewManager(testStore, mockAccountManager).(*managerImpl) - return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup + return manager, testStore, zone, mockAccountManager, ctrl, cleanup } func TestManagerImpl_GetAllRecords(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -78,10 +69,6 @@ func TestManagerImpl_GetAllRecords(t *testing.T) { err = testStore.CreateDNSRecord(ctx, record2) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(true, nil) - result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) require.NoError(t, err) assert.Len(t, result, 2) @@ -89,43 +76,13 @@ func TestManagerImpl_GetAllRecords(t *testing.T) { assert.Equal(t, record2.ID, result[1].ID) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, nil) - - result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - - t.Run("permission validation error", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, status.Errorf(status.Internal, "permission check failed")) - - result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) - require.Error(t, err) - assert.Nil(t, result) - }) } func TestManagerImpl_GetRecord(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -133,10 +90,6 @@ func TestManagerImpl_GetRecord(t *testing.T) { err := testStore.CreateDNSRecord(ctx, record) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(true, nil) - result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID) require.NoError(t, err) assert.Equal(t, record.ID, result.ID) @@ -146,29 +99,13 @@ func TestManagerImpl_GetRecord(t *testing.T) { assert.Equal(t, record.TTL, result.TTL) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). - Return(false, nil) - - result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) } func TestManagerImpl_CreateRecord(t *testing.T) { ctx := context.Background() t.Run("success - A record", func(t *testing.T) { - manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -179,10 +116,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testAccountID, accountID) @@ -202,7 +135,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) { }) t.Run("success - AAAA record", func(t *testing.T) { - manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -213,10 +146,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 600, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testAccountID, accountID) @@ -231,7 +160,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) { }) t.Run("success - CNAME record", func(t *testing.T) { - manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -242,10 +171,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testAccountID, accountID) @@ -259,32 +184,8 @@ func TestManagerImpl_CreateRecord(t *testing.T) { assert.Equal(t, inputRecord.Content, result.Content) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - inputRecord := &records.Record{ - Name: "api.example.com", - Type: records.RecordTypeA, - Content: "192.168.1.1", - TTL: 300, - } - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(false, nil) - - result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("record name not in zone", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -295,10 +196,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) require.Error(t, err) assert.Nil(t, result) @@ -306,7 +203,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) { }) t.Run("duplicate record", func(t *testing.T) { - manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -321,10 +218,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) require.Error(t, err) assert.Nil(t, result) @@ -332,7 +225,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) { }) t.Run("CNAME conflict with existing A record", func(t *testing.T) { - manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -347,10 +240,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). - Return(true, nil) - result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) require.Error(t, err) assert.Nil(t, result) @@ -362,7 +251,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -378,10 +267,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { TTL: 600, // Changed TTL } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - storeEventCalled := false mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { storeEventCalled = true @@ -400,7 +285,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { }) t.Run("update only TTL - no validation", func(t *testing.T) { - manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -416,10 +301,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { TTL: 600, // Only TTL changed } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { // Event should be stored } @@ -430,33 +311,8 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { assert.Equal(t, 600, result.TTL) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - updatedRecord := &records.Record{ - ID: testRecordID, - Name: "api.example.com", - Type: records.RecordTypeA, - Content: "192.168.1.100", - TTL: 600, - } - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(false, nil) - - result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) - require.Error(t, err) - assert.Nil(t, result) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("record not found", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -468,17 +324,13 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { TTL: 600, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) require.Error(t, err) assert.Nil(t, result) }) t.Run("update creates duplicate", func(t *testing.T) { - manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -498,10 +350,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) { TTL: 300, } - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). - Return(true, nil) - result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) require.Error(t, err) assert.Nil(t, result) @@ -513,7 +361,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { - manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() @@ -521,10 +369,6 @@ func TestManagerImpl_DeleteRecord(t *testing.T) { err := testStore.CreateDNSRecord(ctx, record) require.NoError(t, err) - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(true, nil) - storeEventCalled := false mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { storeEventCalled = true @@ -542,31 +386,11 @@ func TestManagerImpl_DeleteRecord(t *testing.T) { require.Error(t, err) }) - t.Run("permission denied", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) - defer cleanup() - defer ctrl.Finish() - - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(false, nil) - - err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID) - require.Error(t, err) - s, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, status.PermissionDenied, s.Type()) - }) - t.Run("record not found", func(t *testing.T) { - manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + manager, _, zone, _, ctrl, cleanup := setupTest(t) defer cleanup() defer ctrl.Finish() - mockPermissionsManager.EXPECT(). - ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). - Return(true, nil) - err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record") require.Error(t, err) }) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 2b40c0aad9c..0279a913182 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -233,7 +233,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { func (s *BaseServer) AccessLogsManager() accesslogs.Manager { return Create(s, func() accesslogs.Manager { - accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) + accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager()) accessLogManager.StartPeriodicCleanup( context.Background(), s.Config.ReverseProxy.AccessLogRetentionDays, diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 9b2ec298952..8dc736611ee 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -27,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/users" ) @@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager { idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled } - return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig) + return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig) }) } func (s *BaseServer) PeersManager() peers.Manager { return Create(s, func() peers.Manager { - manager := peers.NewManager(s.Store(), s.PermissionsManager()) + manager := peers.NewManager(s.Store()) s.AfterInit(func(s *BaseServer) { manager.SetNetworkMapController(s.NetworkMapController()) manager.SetIntegratedPeerValidator(s.IntegratedValidator()) @@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { func (s *BaseServer) GroupsManager() groups.Manager { return Create(s, func() groups.Manager { - return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager()) + return groups.NewManager(s.Store(), s.AccountManager()) }) } func (s *BaseServer) ResourcesManager() resources.Manager { return Create(s, func() resources.Manager { - return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) + return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) }) } func (s *BaseServer) RoutesManager() routers.Manager { return Create(s, func() routers.Manager { - return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager()) + return routers.NewManager(s.Store(), s.AccountManager()) }) } func (s *BaseServer) NetworksManager() networks.Manager { return Create(s, func() networks.Manager { - return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager()) + return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager()) }) } func (s *BaseServer) ZonesManager() zones.Manager { return Create(s, func() zones.Manager { - return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain()) + return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain()) }) } func (s *BaseServer) RecordsManager() records.Manager { return Create(s, func() records.Manager { - return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager()) + return recordsManager.NewManager(s.Store(), s.AccountManager()) }) } func (s *BaseServer) ServiceManager() service.Manager { return Create(s, func() service.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) }) } @@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { - m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) + m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager()) return &m }) } diff --git a/management/server/account.go b/management/server/account.go index 7d53cef03a0..93da07e207f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/shared/auth" @@ -39,9 +40,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -282,22 +280,14 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // User that performs the update has to belong to the account. // Returns an updated Settings func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var oldSettings *types.Settings var updateAccountPeers bool var groupChangesAffectPeers bool var reloadReverseProxy bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var groupsUpdated bool + var err error oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { @@ -725,15 +715,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete) - if err != nil { - return fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") - } - userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users)) if err != nil { return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) @@ -976,6 +957,10 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s return nil, err } + if user.AccountID != accountID { + return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID) + } + key := user.IntegrationReference.CacheKey(accountID, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { @@ -1287,41 +1272,16 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin // GetAccountByID returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccount(ctx, accountID) } // GetAccountMeta returns the account metadata associated with this account ID. func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID) } // GetAccountOnboarding retrieves the onboarding information for a specific account. func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) @@ -1338,15 +1298,6 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou } func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { return nil, fmt.Errorf("failed to get account onboarding: %w", err) @@ -1405,9 +1356,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u return accountID, user.Id, nil } - if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { - return "", "", err - } + // Permission checks are now handled by the HTTP middleware via WithPermission wrapper + // User account association is already validated above by GetUserByUserID if !user.IsServiceUser && userAuth.Invited { err = am.redeemInvite(ctx, accountID, user.Id) @@ -1849,13 +1799,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction } func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) } @@ -2197,14 +2140,6 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee } func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) - if err != nil { - return fmt.Errorf("validate user permissions: %w", err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP) if err != nil { return fmt.Errorf("update peer IP transaction: %w", err) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b4516d51282..54d39cf7dd0 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -60,7 +60,7 @@ type Manager interface { GetUserByID(ctx context.Context, id string) (*types.User, error) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) - GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) + GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) @@ -75,7 +75,7 @@ type Manager interface { GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 36e5fe39f9f..501560c0102 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte } // GetGroupByName mocks base method. -func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { +func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID) + ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID) ret0, _ := ret[0].(*types.Group) ret1, _ := ret[1].(error) return ret0, ret1 } // GetGroupByName indicates an expected call of GetGroupByName. -func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID) } // GetIdentityProvider mocks base method. @@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo } // GetPeers mocks base method. -func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) { +func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter) + ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all) ret0, _ := ret[0].([]*peer.Peer) ret1, _ := ret[1].(error) return ret0, ret1 } // GetPeers indicates an expected call of GetPeers. -func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all) } // GetPolicy mocks base method. diff --git a/management/server/account_test.go b/management/server/account_test.go index bcc73d52f0c..37afae4ad17 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -22,9 +22,11 @@ import ( "go.opentelemetry.io/otel/metric/noop" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -49,7 +51,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -3147,7 +3148,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU AnyTimes() permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) proxyManager := proxy.NewMockManager(ctrl) proxyManager.EXPECT(). @@ -3164,7 +3165,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err @@ -3175,7 +3176,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU if err != nil { return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil)) return manager, updateManager, nil } diff --git a/management/server/dns.go b/management/server/dns.go index baf6debc337..dd3cecb716b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -8,8 +8,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" @@ -22,14 +20,6 @@ const ( // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) } @@ -39,18 +29,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { return err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 0e37a3b22e5..3b303e6a9b0 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -14,11 +14,11 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -28,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -79,16 +78,6 @@ func TestGetDNSSettings(t *testing.T) { if len(dnsSettings.DisabledManagementGroups) != 1 { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - - _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) - if err == nil { - t.Errorf("An error should be returned when getting the DNS settings with a regular user") - } - - s, ok := status.FromError(err) - if !ok && s.Type() != status.PermissionDenied { - t.Errorf("returned error should be Permission Denied, got err: %s", err) - } } func TestSaveDNSSettings(t *testing.T) { @@ -223,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) ctx := context.Background() @@ -234,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } diff --git a/management/server/event.go b/management/server/event.go index d26c569ae92..96a0ba27916 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -9,11 +9,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/status" ) func isEnabled() bool { @@ -23,14 +20,6 @@ func isEnabled() bool { // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) if err != nil { return nil, err diff --git a/management/server/group.go b/management/server/group.go index 7b5b9b86c1a..06d21de2015 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -12,8 +12,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" @@ -32,13 +30,24 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read) + // Permission checks are now handled by the HTTP middleware via WithPermission wrapper + // This method is called from authenticated/authorized handlers, so we just validate + // that the user exists and is part of the account + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return err } - if !allowed { - return status.NewPermissionDeniedError() + if user == nil { + return status.NewUserNotFoundError(userID) + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsBlocked() { + return status.NewUserBlockedError() } return nil @@ -61,27 +70,17 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { - if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { - return nil, err - } +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) } // CreateGroup object of the peers func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var eventsToStore []func() var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } @@ -125,19 +124,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use // UpdateGroup object of the peers func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var eventsToStore []func() var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } @@ -196,33 +187,24 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use // It is the caller's responsibility to ensure proper locking is in place before invoking this method. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var eventsToStore []func() var updateAccountPeers bool var globalErr error groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - if err = transaction.CreateGroup(ctx, newGroup); err != nil { + if err := transaction.CreateGroup(ctx, newGroup); err != nil { return err } - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { + if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } @@ -243,6 +225,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } } + var err error updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err @@ -264,21 +247,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us // It is the caller's responsibility to ensure proper locking is in place before invoking this method. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var eventsToStore []func() var updateAccountPeers bool var globalErr error groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } @@ -311,6 +287,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } } + var err error updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err @@ -416,14 +393,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var allErrors error var groupIDsToDelete []string var deletedGroups []*types.Group diff --git a/management/server/group_test.go b/management/server/group_test.go index fa818e53296..d34d24794e9 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -26,7 +26,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" peer2 "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -764,11 +763,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to network router should update account peers and send peer update t.Run("saving group linked to network router", func(t *testing.T) { - permissionsManager := permissions.NewManager(manager.Store) - groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) - resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager) - routersManager := routers.NewManager(manager.Store, permissionsManager, manager) - networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) + groupsManager := groups.NewManager(manager.Store, manager) + resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager) + routersManager := routers.NewManager(manager.Store, manager) + networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager) network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{ ID: "network_test", diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index d110ab564cb..40fe0355057 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -6,9 +6,6 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" @@ -25,31 +22,21 @@ type Manager interface { } type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - accountManager account.Manager + store store.Store + accountManager account.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager { +func NewManager(store store.Store, accountManager account.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - accountManager: accountManager, + store: store, + accountManager: accountManager, } } func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read) - if err != nil { - return nil, err - } - if !ok { - return nil, err - } - groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("error getting account groups: %w", err) @@ -73,14 +60,6 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str } func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) - if err != nil { - return err - } - if !ok { - return err - } - event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource) if err != nil { return fmt.Errorf("error adding resource to group: %w", err) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 56b2d820354..a2398ea58c6 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -10,6 +10,7 @@ import ( "github.com/rs/cors" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/server/types" @@ -31,10 +32,8 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/http/handlers/proxy" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" @@ -124,25 +123,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("failed to create instance manager: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, router) + accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager) peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) - users.AddEndpoints(accountManager, router) - users.AddInvitesEndpoints(accountManager, router) + users.AddEndpoints(accountManager, router, permissionsManager) + users.AddInvitesEndpoints(accountManager, router, permissionsManager) users.AddPublicInvitesEndpoints(accountManager, router) - setup_keys.AddEndpoints(accountManager, router) - policies.AddEndpoints(accountManager, LocationManager, router) - policies.AddPostureCheckEndpoints(accountManager, LocationManager, router) + setup_keys.AddEndpoints(accountManager, router, permissionsManager) + policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager) + policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager) policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router) - groups.AddEndpoints(accountManager, router) - routes.AddEndpoints(accountManager, router) - dns.AddEndpoints(accountManager, router) - events.AddEndpoints(accountManager, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) - zonesManager.RegisterEndpoints(router, zManager) - recordsManager.RegisterEndpoints(router, rManager) - idp.AddEndpoints(accountManager, router) + groups.AddEndpoints(accountManager, router, permissionsManager) + routes.AddEndpoints(accountManager, router, permissionsManager) + dns.AddEndpoints(accountManager, router, permissionsManager) + events.AddEndpoints(accountManager, router, permissionsManager) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router) + zonesManager.RegisterEndpoints(router, zManager, permissionsManager) + recordsManager.RegisterEndpoints(router, rManager, permissionsManager) + idp.AddEndpoints(accountManager, router, permissionsManager) instance.AddEndpoints(instanceManager, router) - instance.AddVersionEndpoint(instanceManager, router) + instance.AddVersionEndpoint(instanceManager, router, permissionsManager) if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index cc5567e3db6..d4d68795ba5 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -12,10 +12,13 @@ import ( goversion "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -40,11 +43,11 @@ type handler struct { settingsManager settings.Manager } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) { accountsHandler := newHandler(accountManager, settingsManager) - router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") - router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") - router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler @@ -99,7 +102,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st } func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error { - peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "") + peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true) if err != nil { return status.Errorf(status.Internal, "get peer count: %v", err) } @@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 { } // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toAccountResponse(accountID, settings, meta, onboarding) + resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -233,24 +228,15 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS } // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - _, userID := userAuth.AccountId, userAuth.UserId - - vars := mux.Vars(r) - accountID := vars["accountId"] - if len(accountID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + accountID := mux.Vars(r)["accountId"] + if accountID != userAuth.AccountId { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w) return } var req api.PutApiAccountsAccountIdJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -267,7 +253,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w) return } - if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil { + if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil { util.WriteError(r.Context(), err, w) return } @@ -282,19 +268,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { } } - updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) + updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding) if err != nil { util.WriteError(r.Context(), err, w) return } - updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -306,21 +292,14 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { } // deleteAccount is a HTTP DELETE handler to delete an account -func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - vars := mux.Vars(r) - targetAccountID := vars["accountId"] - if len(targetAccountID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w) +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + accountID := mux.Vars(r)["accountId"] + if accountID != userAuth.AccountId { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w) return } - err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId) + err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 739dfe2f655..70d3d210127 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" @@ -290,8 +291,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 67638aea5ad..bd1573c6cb1 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -5,11 +5,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,15 +21,15 @@ type dnsSettingsHandler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - addDNSSettingEndpoint(accountManager, router) - addDNSNameserversEndpoint(accountManager, router) +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { + addDNSSettingEndpoint(accountManager, router, permissionsManager) + addDNSNameserversEndpoint(accountManager, router, permissionsManager) } -func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) { +func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { dnsSettingsHandler := newDNSSettingsHandler(accountManager) - router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS") } // newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler @@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler { } // getDNSSettings returns the DNS settings for the account -func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque } // updateDNSSettings handles update to DNS settings of an account -func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PutApiDnsSettingsJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index a027c067e36..a97dfb91778 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -17,6 +17,7 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth" @@ -115,8 +116,8 @@ func TestDNSSettingsHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET") + router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index bce1c4b7848..ff2a6ef551a 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -6,11 +6,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +23,13 @@ type nameserversHandler struct { accountManager account.Manager } -func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) { +func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { nameserversHandler := newNameserversHandler(accountManager) - router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS") } // newNameserversHandler returns a new instance of nameserversHandler handler @@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler { } // getAllNameservers returns the list of nameserver groups for the account -func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re } // createNameserverGroup handles nameserver group creation request -func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiDnsNameserversJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt } // updateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt } var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt } // deleteNameserverGroup handles nameserver group deletion request -func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) + err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt } // getNameserverGroup handles a nameserver group Get request identified by ID -func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index 4716782f3fa..d1844b65c5a 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth" @@ -201,10 +202,10 @@ func TestNameserversHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET") + router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go index ae1e64e5cf7..ee931f069c3 100644 --- a/management/server/http/handlers/events/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -5,11 +5,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,10 +21,10 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { eventsHandler := newHandler(accountManager) - router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") - router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") + router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS") + router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS") } // newHandler creates a new events handler @@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllEvents list of the given account -func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 923a24e31e5..fa6127f859e 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth" @@ -196,7 +197,7 @@ func TestEvents_GetEvents(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") + router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index f8d161a8783..e935d61afe7 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -7,11 +7,13 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -19,46 +21,45 @@ import ( // handler is a handler that returns groups of the account type handler struct { - accountManager account.Manager + accountManager account.Manager + permissionsManager permissions.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - groupsHandler := newHandler(accountManager) - router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") - router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { + groupsHandler := newHandler(accountManager, permissionsManager) + router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS") } // newHandler creates a new groups handler -func newHandler(accountManager account.Manager) *handler { +func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler { return &handler{ - accountManager: accountManager, + accountManager: accountManager, + permissionsManager: permissionsManager, } } -// getAllGroups list for the account -func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool { + allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read) + return err == nil && allowed +} +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { // Check if filtering by name groupName := r.URL.Query().Get("name") if groupName != "" { // Get single group by name - group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID) + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -71,13 +72,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } // Get all groups - groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -92,15 +93,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } // updateGroup handles update to a group identified by a given ID -func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { @@ -112,13 +105,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { return } - existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID) + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -166,13 +159,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) + if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err) util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -182,17 +175,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { } // createGroup handles group creation request -func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiGroupsJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -226,13 +211,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { Issued: types.GroupIssuedAPI, } - err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group) + err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -242,22 +227,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { } // deleteGroup handles group deletion request -func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) + err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID) if err != nil { wrappedErr, ok := err.(interface{ Unwrap() []error }) if ok && len(wrappedErr.Unwrap()) > 0 { @@ -273,34 +250,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { } // getGroup returns a group -func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return } util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) - } func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index c7b4cbcdde2..9ed8ea5cd3a 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -13,10 +13,14 @@ import ( "strings" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" @@ -33,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*types.Group) *handler { +func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler { + t.Helper() + + ctrl := gomock.NewController(t) + permissionsManagerMock := permissions.NewMockManager(ctrl) + permissionsManagerMock.EXPECT(). + ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)). + Return(true, nil). + AnyTimes() + return &handler{ + permissionsManager: permissionsManagerMock, accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error { if !strings.HasPrefix(group.ID, "id-") { @@ -71,14 +85,14 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return groups, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, status.Errorf(status.NotFound, "unknown group name") }, - GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { + GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) { return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { @@ -128,7 +142,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - p := initGroupTestData(group) + p := initGroupTestData(t, group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -141,7 +155,7 @@ func TestGetGroup(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -254,7 +268,7 @@ func TestWriteGroup(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -267,8 +281,8 @@ func TestWriteGroup(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.createGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") + router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -332,7 +346,7 @@ func TestGetAllGroups(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -345,7 +359,7 @@ func TestGetAllGroups(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") + router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -414,7 +428,7 @@ func TestDeleteGroup(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -426,7 +440,7 @@ func TestDeleteGroup(t *testing.T) { AccountId: "test_id", }) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/idp/idp_handler.go b/management/server/http/handlers/idp/idp_handler.go index 077507b898c..c478ed1f55c 100644 --- a/management/server/http/handlers/idp/idp_handler.go +++ b/management/server/http/handlers/idp/idp_handler.go @@ -6,9 +6,12 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -20,13 +23,13 @@ type handler struct { } // AddEndpoints registers identity provider endpoints -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := newHandler(accountManager) - router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS") - router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS") + router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS") } func newHandler(accountManager account.Manager) *handler { @@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllIdentityProviders returns all identity providers for the account -func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID) +func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request } // getIdentityProvider returns a specific identity provider -func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { return } - provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID) + provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { } // createIdentityProvider creates a new identity provider -func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.IdentityProviderRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) idp := fromAPIRequest(&req) - created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp) + created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp) if err != nil { util.WriteError(r.Context(), err, w) return @@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) } // updateIdentityProvider updates an existing identity provider -func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) idp := fromAPIRequest(&req) - updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp) + updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp) if err != nil { util.WriteError(r.Context(), err, w) return @@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) } // deleteIdentityProvider deletes an identity provider -func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) return } - if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil { + if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/handlers/idp/idp_handler_test.go b/management/server/http/handlers/idp/idp_handler_test.go index 74b20404812..f22ed2b202c 100644 --- a/management/server/http/handlers/idp/idp_handler_test.go +++ b/management/server/http/handlers/idp/idp_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" @@ -120,7 +121,7 @@ func TestGetAllIdentityProviders(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET") + router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -180,7 +181,7 @@ func TestGetIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET") + router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -242,7 +243,7 @@ func TestCreateIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST") + router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -328,7 +329,7 @@ func TestUpdateIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT") + router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -388,7 +389,7 @@ func TestDeleteIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE") + router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index cd9fae6b83f..9e22b779667 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -7,7 +7,11 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { } // AddVersionEndpoint registers the authenticated version endpoint. -func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) { +func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := &handler{ instanceManager: instanceManager, } - router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS") + router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS") } // getInstanceStatus returns the instance status including whether setup is required. @@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) { // getVersionInfo returns version information for NetBird components. // This endpoint requires authentication. -func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) { +func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { versionInfo, err := h.instanceManager.GetVersionInfo(r.Context()) if err != nil { log.WithContext(r.Context()).Errorf("failed to get version info: %v", err) diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 470079c85a1..645b2b0bc82 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -10,12 +10,17 @@ import ( "net/mail" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/idp" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -295,8 +300,15 @@ func TestSetup_ManagerError(t *testing.T) { func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} + ctrl := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl) + permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, &auth.UserAuth{}) + } + }).AnyTimes() router := mux.NewRouter() - AddVersionEndpoint(manager, router) + AddVersionEndpoint(manager, router, permissionsManager) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) rec := httptest.NewRecorder() @@ -323,8 +335,15 @@ func TestGetVersionInfo_Error(t *testing.T) { return nil, errors.New("failed to fetch versions") }, } + ctrl := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl) + permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, &auth.UserAuth{}) + } + }).AnyTimes() router := mux.NewRouter() - AddVersionEndpoint(manager, router) + AddVersionEndpoint(manager, router, permissionsManager) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) rec := httptest.NewRecorder() diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index f99eca7941f..ada3102ba00 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -9,8 +9,10 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" @@ -18,6 +20,7 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -33,16 +36,16 @@ type handler struct { groupsManager groups.Manager } -func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) { - addRouterEndpoints(routerManager, router) - addResourceEndpoints(resourceManager, groupsManager, router) +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) { + addRouterEndpoints(routerManager, permissionsManager, router) + addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router) networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager) - router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") - router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS") } func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler { @@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana } } -func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) } -func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { network := &types.Network{} network.FromAPIRequest(&req) - network.AccountID = accountID - network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + network.AccountID = userAuth.AccountId + network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) } -func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { return } - network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } -func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { } var req api.NetworkRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { network.FromAPIRequest(&req) network.ID = networkID - network.AccountID = accountID - network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + network.AccountID = userAuth.AccountId + network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network) if err != nil { util.WriteError(r.Context(), err, w) return } - routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } -func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { return } - err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index c31729a39c7..e57bbb4206c 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -6,10 +6,13 @@ import ( "github.com/gorilla/mux" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,14 +22,14 @@ type resourceHandler struct { groupsManager groups.Manager } -func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) { +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) { resourceHandler := newResourceHandler(resourcesManager, groupsManager) - router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS") } func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler { @@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups. } } -func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] - resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } -func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } -func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkResourceRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) resource.FromAPIRequest(&req) resource.NetworkID = mux.Vars(r)["networkId"] - resource.AccountID = accountID - resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + resource.AccountID = userAuth.AccountId + resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] - resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkResourceRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) resource.ID = mux.Vars(r)["resourceId"] resource.NetworkID = mux.Vars(r)["networkId"] - resource.AccountID = accountID - resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + resource.AccountID = userAuth.AccountId + resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] - err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index ce9efb78d96..40a465fc37c 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -6,9 +6,12 @@ import ( "github.com/gorilla/mux" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -17,14 +20,14 @@ type routersHandler struct { routersManager routers.Manager } -func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) { +func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) { routersHandler := newRoutersHandler(routersManager) - router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS") } func newRoutersHandler(routersManager routers.Manager) *routersHandler { @@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler { } } -func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID) +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routersResponse) } -func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] - routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, routersResponse) } -func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] var req api.NetworkRouterRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -103,7 +82,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.FromAPIRequest(&req) router.NetworkID = networkID - router.AccountID = accountID + router.AccountID = userAuth.AccountId router.Enabled = true if err := router.Validate(); err != nil { @@ -111,7 +90,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { return } - router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router) if err != nil { util.WriteError(r.Context(), err, w) return @@ -120,18 +99,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] - router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -140,17 +111,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkRouterRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -161,14 +124,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = mux.Vars(r)["networkId"] router.ID = mux.Vars(r)["routerId"] - router.AccountID = accountID + router.AccountID = userAuth.AccountId if err := router.Validate(); err != nil { util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) return } - router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router) if err != nil { util.WriteError(r.Context(), err, w) return @@ -177,17 +140,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] - err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6b9a69f04b9..2b7b697d619 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,7 +1,6 @@ package peers import ( - "context" "encoding/json" "fmt" "net/http" @@ -12,15 +11,15 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -35,14 +34,15 @@ type Handler struct { func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) - router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). - Methods("GET", "PUT", "DELETE", "OPTIONS") - router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS") + router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS") } // NewHandler creates a new peers Handler @@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map } } -func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] @@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) { job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil { - util.WriteError(ctx, err, w) + if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil { + util.WriteError(r.Context(), err, w) return } resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(ctx, w, resp) + util.WriteJSONObject(r.Context(), w, resp) } -func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] - jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID) + jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } @@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) { for _, job := range jobs { resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } respBody = append(respBody, resp) } - util.WriteJSONObject(ctx, w, respBody) + util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] jobID := vars["jobId"] - job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID) + job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(ctx, w, resp) + util.WriteJSONObject(r.Context(), w, resp) } -func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) +// GetPeer handles GET request for a single peer +func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } if peer.ProxyMeta.Embedded { - util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w) return } - settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { - log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) - util.WriteError(ctx, fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] reason := invalidPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } -func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { +// UpdatePeer handles PUT request to update a peer +func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri } update := &nbpeer.Peer{ - ID: peerID, - SSHEnabled: req.SshEnabled, - Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled, - + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, InactivityExpirationEnabled: req.InactivityExpirationEnabled, } @@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri if req.Ip != nil { addr, err := netip.ParseAddr(*req.Ip) if err != nil { - util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w) return } - if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil { - util.WriteError(ctx, err, w) + if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil { + util.WriteError(r.Context(), err, w) return } } - peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) + peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { - log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) - util.WriteError(ctx, fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } @@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } -func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { - err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer: %v", err) - util.WriteError(ctx, err, w) - return - } - util.WriteJSONObject(ctx, w, util.EmptyObject{}) -} - -// HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +// DeletePeer handles DELETE request to delete a peer +func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { return } - switch r.Method { - case http.MethodDelete: - h.deletePeer(r.Context(), accountID, userID, peerID, w) - return - case http.MethodGet: - h.getPeer(r.Context(), accountID, peerID, userID, w) - return - case http.MethodPut: - h.updatePeer(r.Context(), accountID, userID, peerID, w, r) - return - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) - } -} - -// GetAllPeers returns a list of all peers associated with a provided account -func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) if err != nil { + log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err) util.WriteError(r.Context(), err, w) return } + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} +// GetAllPeers returns a list of all peers associated with a provided account +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nameFilter := r.URL.Query().Get("name") ipFilter := r.URL.Query().Get("ip") - accountID, userID := userAuth.AccountId, userAuth.UserId - - peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter) + peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true) if err != nil { util.WriteError(r.Context(), err, w) return } - settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) @@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -372,25 +327,22 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - user, err := h.accountManager.GetUserByID(r.Context(), userID) + user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read) - if err != nil { - util.WriteError(r.Context(), status.NewPermissionValidationError(err), w) - return - } - - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) + account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return } - if !allowed && !userAuth.IsChild { + // Check if user is an admin/service user through their role + isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner + + if !isAdmin && !user.IsServiceUser && !userAuth.IsChild { if account.Settings.RegularUsersViewBlocked { util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{}) return @@ -408,7 +360,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -422,13 +374,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -437,7 +383,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) } var req api.PeerTemporaryAccessRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 6b36165978e..da3ef46fe1c 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -19,11 +19,11 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" @@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { return nil, fmt.Errorf("user not found") } }, - GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { + GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) { return peers, nil }, GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { @@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") - router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET") - router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT") + router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET") + router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET") + router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") + router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { rr := httptest.NewRecorder() router := mux.NewRouter() - router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT") + router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT") router.ServeHTTP(rr, req) diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index 094a36e38f3..4ae54c42a31 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -14,12 +14,12 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" @@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index a2d656a4716..dd9ea7c14b1 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -6,12 +6,12 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -30,8 +30,8 @@ type geolocationsHandler struct { func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager) - router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") - router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler @@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa } // getAllCountries retrieves a list of all countries -func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) @@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req } // getCitiesByCountry retrieves a list of cities based on the given country code -func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { @@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - ctx := r.Context() - - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - return err - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return status.NewPermissionValidationError(err) - } - - if !allowed { - return status.NewPermissionDeniedError() - } - return nil -} - func toCountryResponse(country geolocation.Country) api.Country { return api.Country{ CountryName: country.CountryName, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index e4d1d73dfd8..a156df78e12 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -7,10 +7,13 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) { policiesHandler := newHandler(accountManager) - router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") - router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS") } // newHandler creates a new policies handler @@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler { } // getAllPolicies list for the account -func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { } // updatePolicy handles update to a policy identified by a given ID -func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + _, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, accountID, userID, policyID, false) + h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false) } // createPolicy handles policy creation request -func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - h.savePolicy(w, r, accountID, userID, "", true) +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true) } // savePolicy handles policy creation and update @@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s } // deletePolicy handles policy deletion request -func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { + if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } @@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { } // getPolicy handles a group Get request identified by ID -func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { return } - policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index ca5a0a6abfb..353ca456201 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" @@ -111,7 +112,7 @@ func TestPoliciesGetPolicy(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -275,8 +276,8 @@ func TestPoliciesWritePolicy(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 744cde10b0e..c4b86288ec8 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -6,10 +6,13 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type postureChecksHandler struct { geolocationManager geolocation.Geolocation } -func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { +func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager) - router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") - router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS") } // newPostureChecksHandler creates a new PostureChecks handler @@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager } // getAllPostureChecks list for the account -func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt } // updatePostureCheck handles update to a posture check identified by a given ID -func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http return } - _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + _, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, accountID, userID, postureChecksID, false) + p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false) } // createPostureCheck handles posture check creation request -func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - p.savePostureChecks(w, r, accountID, userID, "", true) +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true) } // getPostureCheck handles a posture check Get request identified by ID -func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re } // deletePostureCheck handles posture check deletion request -func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { + if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index a5999f6c7c2..27334c8afec 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/mock_server" @@ -183,7 +184,7 @@ func TestGetPostureCheck(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -841,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 7bb6f2372e7..a6a8ba21086 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -8,9 +8,12 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -26,13 +29,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { routesHandler := newHandler(accountManager) - router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") - router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") + router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS") } // newHandler returns a new instance of routes handler @@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllRoutes returns the list of routes for the account -func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { } // createRoute handles route creation request -func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiRoutesJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { skipAutoApply = false } - newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply) + newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply) if err != nil { util.WriteError(r.Context(), err, w) @@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer * } // updateRoute handles update to a route identified by a given ID -func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { @@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) + _, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { newRoute.AccessControlGroups = *req.AccessControlGroups } - err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) + err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { } // deleteRoute handles route deletion request -func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) + err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { } // getRoute handles a route Get request identified by ID -func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) + foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index a44d81e3ec8..320b845a89f 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/util" @@ -501,10 +502,10 @@ func TestRoutesHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.createRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE") + router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index d267b6eea2a..cb30893448e 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -8,9 +8,12 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { keysHandler := newHandler(accountManager) - router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS") } // newHandler creates a new setup key handler @@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler { } // createSetupKey is a POST requests that creates a new SetupKey -func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { req := &api.PostApiSetupKeysJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { allowExtraDNSLabels = *req.AllowExtraDnsLabels } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels) if err != nil { util.WriteError(r.Context(), err, w) return @@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { } // getSetupKey is a GET request to get a SetupKey by ID -func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { } // updateSetupKey is a PUT request to update server.SetupKey -func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { } req := &api.PutApiSetupKeysKeyIdJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { newKey.Revoked = req.Revoked newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { } // getAllSetupKeys is a GET request that returns a list of SetupKey -func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) + err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index b137b6dd1e5..28595231c0a 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" @@ -171,11 +172,11 @@ func TestSetupKeysHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/users/invites_handler.go b/management/server/http/handlers/users/invites_handler.go index 0f0f57c2953..81585c28e35 100644 --- a/management/server/http/handlers/users/invites_handler.go +++ b/management/server/http/handlers/users/invites_handler.go @@ -9,10 +9,13 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -55,14 +58,14 @@ type invitesHandler struct { } // AddInvitesEndpoints registers invite-related endpoints -func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) { +func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := &invitesHandler{accountManager: accountManager} // Authenticated endpoints (require admin) - router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS") - router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS") - router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS") } // AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting @@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route } // listInvites handles GET /api/users/invites -func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { } // createInvite handles POST /api/users/invites -func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.UserInviteCreateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) { } // regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate -func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) { +func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - vars := mux.Vars(r) inviteID := vars["inviteId"] if inviteID == "" { @@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request } // deleteInvite handles DELETE /api/users/invites/{inviteId} -func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) inviteID := vars["inviteId"] if inviteID == "" { @@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) + err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go index 529ea24d649..5017c49b5cd 100644 --- a/management/server/http/handlers/users/invites_handler_test.go +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) { }) rr := httptest.NewRecorder() - handler.listInvites(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.listInvites(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) { }) rr := httptest.NewRecorder() - handler.createInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.createInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) { } rr := httptest.NewRecorder() - handler.regenerateInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.regenerateInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) { } rr := httptest.NewRecorder() - handler.deleteInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.deleteInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) }) diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 867db3ca9f7..3a978a4d5de 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -6,9 +6,12 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -19,12 +22,12 @@ type patHandler struct { accountManager account.Manager } -func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) { +func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { tokenHandler := newPATsHandler(accountManager) - router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS") } // newPATsHandler creates a new patHandler HTTP handler @@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler { } // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] - if len(userID) == 0 { + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) + pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { } // getToken is HTTP GET handler that returns a personal access token for the given user -func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { } // createToken is HTTP POST handler that creates a personal access token for the given user -func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { } var req api.PostApiUsersUserIdTokensJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } - pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { } // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) + err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 7cda144686c..95275806fe2 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/management/internals/modules/permissions" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" @@ -181,10 +182,10 @@ func TestTokenHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 40ad585d289..c79528c65c2 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -8,14 +8,16 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - - nbcontext "github.com/netbirdio/netbird/management/server/context" ) // handler is a handler that returns users of the account @@ -23,18 +25,18 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { userHandler := newHandler(accountManager) - router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") - router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") - router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS") - addUsersTokensEndpoint(accountManager, router) + router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser, userHandler.getCurrentUserFallback)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS") + addUsersTokensEndpoint(accountManager, router, permissionsManager) } // newHandler creates a new UsersHandler HTTP handler @@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler { } // updateUser is a PUT requests to update User data -func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -71,6 +66,11 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } + if existingUser.AccountID != userAuth.AccountId { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w) + return + } + req := &api.PutApiUsersUserIdJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -102,23 +102,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId)) } // deleteUser is a DELETE request to delete a user -func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -126,7 +119,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) + err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -136,21 +129,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { } // createUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - req := &api.PostApiUsersJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -171,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -183,25 +169,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId)) } // getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -215,7 +194,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { continue } if serviceUser == "" { - users = append(users, toUserResponse(d, userID)) + users = append(users, toUserResponse(d, userAuth.UserId)) continue } @@ -226,7 +205,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } if includeServiceUser == d.IsServiceUser { - users = append(users, toUserResponse(d, userID)) + users = append(users, toUserResponse(d, userAuth.UserId)) } } @@ -235,19 +214,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { // inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -255,7 +227,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) + err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -264,19 +236,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth) + user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth) if err != nil { util.WriteError(r.Context(), err, w) return @@ -356,7 +322,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } // approveUser is a POST request to approve a user that is pending approval -func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -369,11 +335,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) @@ -385,7 +346,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { } // rejectUser is a DELETE request to reject a user that is pending approval -func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -398,12 +359,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -421,7 +377,7 @@ type passwordChangeRequest struct { // changePassword is a PUT request to change user's password. // Only available when embedded IDP is enabled. // Users can only change their own password. -func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { +func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -434,19 +390,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - var req passwordChangeRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } - err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword) + err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword) if err != nil { util.WriteError(r.Context(), err, w) return @@ -454,3 +404,39 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } + +func (h *handler) getCurrentUserFallback(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool { + s, ok := status.FromError(err) + if !ok || s.ErrorType != status.PermissionDenied { + return false + } + + user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth) + if userErr != nil { + util.WriteError(r.Context(), userErr, w) + return true + } + + util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId)) + return true +} + +func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool { + s, ok := status.FromError(err) + if !ok || s.ErrorType != status.PermissionDenied { + return false + } + + if r.URL.Query().Get("service_user") != "" { + return false + } + + user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth) + if userErr != nil { + util.WriteError(r.Context(), userErr, w) + return true + } + + util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)}) + return true +} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index aa77dd8436b..9cb37a1c46b 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -15,10 +15,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/auth" @@ -38,6 +39,7 @@ var usersTestAccount = &types.Account{ Users: map[string]*types.User{ existingUserID: { Id: existingUserID, + AccountID: existingAccountID, Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, @@ -45,6 +47,7 @@ var usersTestAccount = &types.Account{ }, regularUserID: { Id: regularUserID, + AccountID: existingAccountID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, @@ -52,6 +55,7 @@ var usersTestAccount = &types.Account{ }, serviceUserID: { Id: serviceUserID, + AccountID: existingAccountID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, @@ -59,6 +63,7 @@ var usersTestAccount = &types.Account{ }, nonDeletableServiceUserID: { Id: nonDeletableServiceUserID, + AccountID: existingAccountID, Role: "admin", IsServiceUser: true, NonDeletable: true, @@ -151,7 +156,7 @@ func initUsersTestData() *handler { NonDeletable: false, Issued: "api", }, - Permissions: mergeRolePermissions(roles.Owner), + Permissions: mergeRolePermissions(roles2.Owner), }, nil case "regular-user": return &users.UserInfoWithPermissions{ @@ -165,7 +170,7 @@ func initUsersTestData() *handler { NonDeletable: false, Issued: "api", }, - Permissions: mergeRolePermissions(roles.User), + Permissions: mergeRolePermissions(roles2.User), }, nil case "admin-user": @@ -181,7 +186,7 @@ func initUsersTestData() *handler { LastLogin: time.Time{}, Issued: "api", }, - Permissions: mergeRolePermissions(roles.Admin), + Permissions: mergeRolePermissions(roles2.Admin), }, nil case "restricted-user": return &users.UserInfoWithPermissions{ @@ -196,7 +201,7 @@ func initUsersTestData() *handler { LastLogin: time.Time{}, Issued: "api", }, - Permissions: mergeRolePermissions(roles.User), + Permissions: mergeRolePermissions(roles2.User), Restricted: true, }, nil } @@ -232,7 +237,11 @@ func TestGetUsers(t *testing.T) { AccountId: existingAccountID, }) - userHandler.getAllUsers(recorder, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.getAllUsers(recorder, req, userAuth) res := recorder.Result() defer res.Body.Close() @@ -343,7 +352,7 @@ func TestUpdateUser(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -439,7 +448,11 @@ func TestCreateUser(t *testing.T) { AccountId: existingAccountID, }) - userHandler.createUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.createUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -490,7 +503,11 @@ func TestInviteUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.inviteUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.inviteUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -549,7 +566,11 @@ func TestDeleteUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.deleteUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.deleteUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -608,7 +629,7 @@ func TestCurrentUser(t *testing.T) { Issued: ptr("api"), LastLogin: ptr(time.Time{}), Permissions: &api.UserPermissions{ - Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)), + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)), }, }, }, @@ -627,7 +648,7 @@ func TestCurrentUser(t *testing.T) { Issued: ptr("api"), LastLogin: ptr(time.Time{}), Permissions: &api.UserPermissions{ - Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)), }, }, }, @@ -646,7 +667,7 @@ func TestCurrentUser(t *testing.T) { Issued: ptr("api"), LastLogin: ptr(time.Time{}), Permissions: &api.UserPermissions{ - Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)), + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)), }, }, }, @@ -666,7 +687,7 @@ func TestCurrentUser(t *testing.T) { LastLogin: ptr(time.Time{}), Permissions: &api.UserPermissions{ IsRestricted: true, - Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)), }, }, }, @@ -682,7 +703,11 @@ func TestCurrentUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.getCurrentUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: tc.requestAuth.UserId, + AccountId: existingAccountID, + } + userHandler.getCurrentUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -702,8 +727,8 @@ func ptr[T any, PT *T](x T) PT { return &x } -func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { - permissions := roles.Permissions{} +func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions { + permissions := roles2.Permissions{} for k := range modules.All { if rolePermissions, ok := role.Permissions[k]; ok { @@ -716,7 +741,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { return permissions } -func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool { +func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool { modules := make(map[string]map[string]bool) for module, operations := range permissions { modules[string(module)] = make(map[string]bool) @@ -779,7 +804,7 @@ func TestApproveUserEndpoint(t *testing.T) { handler := newHandler(am) router := mux.NewRouter() - router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") + router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST") req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) @@ -837,7 +862,7 @@ func TestRejectUserEndpoint(t *testing.T) { handler := newHandler(am) router := mux.NewRouter() - router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") + router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE") req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) require.NoError(t, err) @@ -928,7 +953,7 @@ func TestChangePasswordEndpoint(t *testing.T) { handler := newHandler(am) router := mux.NewRouter() - router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT") + router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT") reqPath := "/users/" + tc.targetUserID + "/password" req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody)) @@ -967,7 +992,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) { req = nbcontext.SetUserAuthInRequest(req, userAuth) rr := httptest.NewRecorder() - handler.changePassword(rr, req) + handler.changePassword(rr, req, &userAuth) assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) } diff --git a/management/server/http/testing/integration/accounts_handler_integration_test.go b/management/server/http/testing/integration/accounts_handler_integration_test.go index 511730ee54d..eae5915e3c3 100644 --- a/management/server/http/testing/integration/accounts_handler_integration_test.go +++ b/management/server/http/testing/integration/accounts_handler_integration_test.go @@ -27,7 +27,7 @@ func Test_Accounts_GetAll(t *testing.T) { {"Regular user", testing_tools.TestUserId, false}, {"Admin user", testing_tools.TestAdminId, true}, {"Owner user", testing_tools.TestOwnerId, true}, - {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, {"Admin service user", testing_tools.TestServiceAdminId, true}, {"Blocked user", testing_tools.BlockedUserId, false}, {"Other user", testing_tools.OtherUserId, false}, @@ -233,6 +233,71 @@ func Test_Accounts_Update(t *testing.T) { } } +func Test_Accounts_Update_CrossAccountAttack(t *testing.T) { + t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + body, err := json.Marshal(&api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: false, + PeerLoginExpiration: 86400, + }, + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + // OtherUserId belongs to otherAccountId, but we target testAccountId in URL + req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected") + }) +} + +func Test_Accounts_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, false}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, false}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Delete account", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + }) + } +} + +func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) { + t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + // OtherUserId belongs to otherAccountId, but we target testAccountId in URL + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected") + }) +} + func stringPointer(s string) *string { return &s } diff --git a/management/server/http/testing/integration/dns_records_handler_integration_test.go b/management/server/http/testing/integration/dns_records_handler_integration_test.go new file mode 100644 index 00000000000..0f3d7c5fe76 --- /dev/null +++ b/management/server/http/testing/integration/dns_records_handler_integration_test.go @@ -0,0 +1,445 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Records_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all records", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.DNSRecord{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "sub.example.com", got[0].Name) + assert.Equal(t, api.DNSRecordTypeA, got[0].Type) + assert.Equal(t, "1.2.3.4", got[0].Content) + assert.Equal(t, 300, got[0].Ttl) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Records_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + recordId string + expectedStatus int + expectRecord bool + }{ + { + name: "Get existing record", + zoneId: "testZoneId", + recordId: "testRecordId", + expectedStatus: http.StatusOK, + expectRecord: true, + }, + { + name: "Get non-existing record", + zoneId: "testZoneId", + recordId: "nonExistingRecordId", + expectedStatus: http.StatusNotFound, + expectRecord: false, + }, + { + name: "Get record from non-existing zone", + zoneId: "nonExistingZoneId", + recordId: "testRecordId", + expectedStatus: http.StatusNotFound, + expectRecord: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true) + + path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1) + path = strings.Replace(path, "{recordId}", tc.recordId, 1) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectRecord { + got := &api.DNSRecord{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testRecordId", got.Id) + assert.Equal(t, "sub.example.com", got.Name) + assert.Equal(t, api.DNSRecordTypeA, got.Type) + assert.Equal(t, "1.2.3.4", got.Content) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Records_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, record *api.DNSRecord) + }{ + { + name: "Create A record", + zoneId: "testZoneId", + requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "new.example.com", + Type: api.DNSRecordTypeA, + Content: "5.6.7.8", + Ttl: 600, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, record *api.DNSRecord) { + t.Helper() + assert.NotEmpty(t, record.Id) + assert.Equal(t, "new.example.com", record.Name) + assert.Equal(t, api.DNSRecordTypeA, record.Type) + assert.Equal(t, "5.6.7.8", record.Content) + assert.Equal(t, 600, record.Ttl) + }, + }, + { + name: "Create CNAME record", + zoneId: "testZoneId", + requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "alias.example.com", + Type: api.DNSRecordTypeCNAME, + Content: "target.example.com", + Ttl: 300, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, record *api.DNSRecord) { + t.Helper() + assert.NotEmpty(t, record.Id) + assert.Equal(t, "alias.example.com", record.Name) + assert.Equal(t, api.DNSRecordTypeCNAME, record.Type) + assert.Equal(t, "target.example.com", record.Content) + }, + }, + { + name: "Create record with invalid content for A type", + zoneId: "testZoneId", + requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "bad.example.com", + Type: api.DNSRecordTypeA, + Content: "not-an-ip", + Ttl: 300, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create record in non-existing zone", + zoneId: "nonExistingZoneId", + requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "new.example.com", + Type: api.DNSRecordTypeA, + Content: "5.6.7.8", + Ttl: 600, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1) + req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.DNSRecord{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the created record directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbRecord.Name) + assert.Equal(t, got.Content, dbRecord.Content) + assert.Equal(t, got.Ttl, dbRecord.TTL) + } + }) + } + } +} + +func Test_Records_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + recordId string + requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, record *api.DNSRecord) + }{ + { + name: "Update record content and TTL", + zoneId: "testZoneId", + recordId: "testRecordId", + requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{ + Name: "sub.example.com", + Type: api.DNSRecordTypeA, + Content: "10.20.30.40", + Ttl: 600, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, record *api.DNSRecord) { + t.Helper() + assert.Equal(t, "sub.example.com", record.Name) + assert.Equal(t, "10.20.30.40", record.Content) + assert.Equal(t, 600, record.Ttl) + }, + }, + { + name: "Update non-existing record", + zoneId: "testZoneId", + recordId: "nonExistingRecordId", + requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{ + Name: "sub.example.com", + Type: api.DNSRecordTypeA, + Content: "10.20.30.40", + Ttl: 600, + }, + expectedStatus: http.StatusNotFound, + }, + { + name: "Update record in non-existing zone", + zoneId: "nonExistingZoneId", + recordId: "testRecordId", + requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{ + Name: "sub.example.com", + Type: api.DNSRecordTypeA, + Content: "10.20.30.40", + Ttl: 600, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1) + path = strings.Replace(path, "{recordId}", tc.recordId, 1) + req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.DNSRecord{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the updated record directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId) + assert.Equal(t, "10.20.30.40", dbRecord.Content) + assert.Equal(t, 600, dbRecord.TTL) + } + }) + } + } +} + +func Test_Records_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + recordId string + expectedStatus int + }{ + { + name: "Delete existing record", + zoneId: "testZoneId", + recordId: "testRecordId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing record", + zoneId: "testZoneId", + recordId: "nonExistingRecordId", + expectedStatus: http.StatusNotFound, + }, + { + name: "Delete record from non-existing zone", + zoneId: "nonExistingZoneId", + recordId: "testRecordId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1) + path = strings.Replace(path, "{recordId}", tc.recordId, 1) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify deletion in DB for successful deletes by privileged users + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyRecordNotInDB(t, db, tc.recordId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/dns_zones_handler_integration_test.go b/management/server/http/testing/integration/dns_zones_handler_integration_test.go new file mode 100644 index 00000000000..14221c92c36 --- /dev/null +++ b/management/server/http/testing/integration/dns_zones_handler_integration_test.go @@ -0,0 +1,416 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Zones_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all zones", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Zone{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "Test Zone", got[0].Name) + assert.Equal(t, "example.com", got[0].Domain) + assert.Equal(t, true, got[0].Enabled) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Zones_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + expectedStatus int + expectZone bool + }{ + { + name: "Get existing zone", + zoneId: "testZoneId", + expectedStatus: http.StatusOK, + expectZone: true, + }, + { + name: "Get non-existing zone", + zoneId: "nonExistingZoneId", + expectedStatus: http.StatusNotFound, + expectZone: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectZone { + got := &api.Zone{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testZoneId", got.Id) + assert.Equal(t, "Test Zone", got.Name) + assert.Equal(t, "example.com", got.Domain) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Zones_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + enabled := true + disabled := false + + tt := []struct { + name string + requestBody *api.PostApiDnsZonesJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, zone *api.Zone) + }{ + { + name: "Create zone with valid data", + requestBody: &api.PostApiDnsZonesJSONRequestBody{ + Name: "New Zone", + Domain: "newzone.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, zone *api.Zone) { + t.Helper() + assert.NotEmpty(t, zone.Id) + assert.Equal(t, "New Zone", zone.Name) + assert.Equal(t, "newzone.com", zone.Domain) + assert.Equal(t, true, zone.Enabled) + assert.Equal(t, false, zone.EnableSearchDomain) + assert.Equal(t, 1, len(zone.DistributionGroups)) + }, + }, + { + name: "Create zone with search domain enabled", + requestBody: &api.PostApiDnsZonesJSONRequestBody{ + Name: "Search Zone", + Domain: "search.example.com", + Enabled: &enabled, + EnableSearchDomain: true, + DistributionGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, zone *api.Zone) { + t.Helper() + assert.NotEmpty(t, zone.Id) + assert.Equal(t, "Search Zone", zone.Name) + assert.Equal(t, true, zone.EnableSearchDomain) + }, + }, + { + name: "Create disabled zone", + requestBody: &api.PostApiDnsZonesJSONRequestBody{ + Name: "Disabled Zone", + Domain: "disabled.example.com", + Enabled: &disabled, + EnableSearchDomain: false, + DistributionGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, zone *api.Zone) { + t.Helper() + assert.NotEmpty(t, zone.Id) + assert.Equal(t, false, zone.Enabled) + }, + }, + { + name: "Create zone with empty distribution groups", + requestBody: &api.PostApiDnsZonesJSONRequestBody{ + Name: "No Groups Zone", + Domain: "nogroups.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Zone{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the created zone directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbZone.Name) + assert.Equal(t, got.Domain, dbZone.Domain) + assert.Equal(t, got.Enabled, dbZone.Enabled) + assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain) + } + }) + } + } +} + +func Test_Zones_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + enabled := true + + tt := []struct { + name string + zoneId string + requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, zone *api.Zone) + }{ + { + name: "Update zone name and settings", + zoneId: "testZoneId", + requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{ + Name: "Updated Zone", + Domain: "example.com", + Enabled: &enabled, + EnableSearchDomain: true, + DistributionGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, zone *api.Zone) { + t.Helper() + assert.Equal(t, "Updated Zone", zone.Name) + assert.Equal(t, "example.com", zone.Domain) + assert.Equal(t, true, zone.EnableSearchDomain) + }, + }, + { + name: "Update non-existing zone", + zoneId: "nonExistingZoneId", + requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{ + Name: "Whatever", + Domain: "whatever.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Zone{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the updated zone directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId) + assert.Equal(t, "Updated Zone", dbZone.Name) + assert.Equal(t, "example.com", dbZone.Domain) + assert.Equal(t, true, dbZone.Enabled) + assert.Equal(t, true, dbZone.EnableSearchDomain) + } + }) + } + } +} + +func Test_Zones_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + zoneId string + expectedStatus int + }{ + { + name: "Delete existing zone", + zoneId: "testZoneId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing zone", + zoneId: "nonExistingZoneId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify deletion in DB for successful deletes by privileged users + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/events_handler_integration_test.go b/management/server/http/testing/integration/events_handler_integration_test.go index 6611b60eefc..43242806f2c 100644 --- a/management/server/http/testing/integration/events_handler_integration_test.go +++ b/management/server/http/testing/integration/events_handler_integration_test.go @@ -78,6 +78,68 @@ func Test_Events_GetAll(t *testing.T) { } } +func Test_Events_GetAll_Audit(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all audit events", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false) + + // First, perform a mutation to generate an event (create a group as admin) + groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"}) + if err != nil { + t.Fatalf("Failed to marshal group request: %v", err) + } + createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId) + createRecorder := httptest.NewRecorder() + apiHandler.ServeHTTP(createRecorder, createReq) + assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event") + + // Now query audit events + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Event{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group") + + // Verify the group creation event exists + found := false + for _, event := range got { + if event.ActivityCode == "group.add" { + found = true + assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId) + assert.Equal(t, "Group created", event.Activity) + break + } + } + assert.True(t, found, "Expected to find a group.add event") + }) + } +} + func Test_Events_GetAll_Empty(t *testing.T) { apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true) diff --git a/management/server/http/testing/integration/geolocations_handler_integration_test.go b/management/server/http/testing/integration/geolocations_handler_integration_test.go new file mode 100644 index 00000000000..0164aa10cfd --- /dev/null +++ b/management/server/http/testing/integration/geolocations_handler_integration_test.go @@ -0,0 +1,118 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Geolocations_GetAllCountries(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all countries", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Country{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Geolocations_GetCitiesByCountry(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get cities by country", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.City{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true) +} diff --git a/management/server/http/testing/integration/groups_handler_integration_test.go b/management/server/http/testing/integration/groups_handler_integration_test.go index edb43f3f320..d3932e3abbc 100644 --- a/management/server/http/testing/integration/groups_handler_integration_test.go +++ b/management/server/http/testing/integration/groups_handler_integration_test.go @@ -26,7 +26,7 @@ func Test_Groups_GetAll(t *testing.T) { {"Regular user", testing_tools.TestUserId, false}, {"Admin user", testing_tools.TestAdminId, true}, {"Owner user", testing_tools.TestOwnerId, true}, - {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, {"Admin service user", testing_tools.TestServiceAdminId, true}, {"Blocked user", testing_tools.BlockedUserId, false}, {"Other user", testing_tools.OtherUserId, false}, @@ -71,7 +71,7 @@ func Test_Groups_GetById(t *testing.T) { {"Regular user", testing_tools.TestUserId, false}, {"Admin user", testing_tools.TestAdminId, true}, {"Owner user", testing_tools.TestOwnerId, true}, - {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, {"Admin service user", testing_tools.TestServiceAdminId, true}, {"Blocked user", testing_tools.BlockedUserId, false}, {"Other user", testing_tools.OtherUserId, false}, @@ -216,7 +216,6 @@ func Test_Groups_Create(t *testing.T) { } tc.verifyResponse(t, got) - // Verify group exists in DB db := testing_tools.GetDB(t, am.GetStore()) dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id) assert.Equal(t, tc.requestBody.Name, dbGroup.Name) diff --git a/management/server/http/testing/integration/idp_handler_integration_test.go b/management/server/http/testing/integration/idp_handler_integration_test.go new file mode 100644 index 00000000000..e5b646685a7 --- /dev/null +++ b/management/server/http/testing/integration/idp_handler_integration_test.go @@ -0,0 +1,295 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_IdentityProviders_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all identity providers", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/identity-providers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.IdentityProvider{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + // The embedded IdP manager is not initialized in the test environment, + // so GetIdentityProviders returns an empty list. + assert.Equal(t, 0, len(got)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_IdentityProviders_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + idpId string + expectedStatus int + }{ + { + name: "Get existing identity provider", + idpId: "testIdpId", + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Get non-existing identity provider", + idpId: "nonExistingIdpId", + expectedStatus: http.StatusInternalServerError, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_IdentityProviders_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.PostApiIdentityProvidersJSONRequestBody + expectedStatus int + }{ + { + name: "Create identity provider with valid data", + requestBody: &api.PostApiIdentityProvidersJSONRequestBody{ + Type: api.IdentityProviderTypeGoogle, + Name: "New IDP", + ClientId: "newClientId", + ClientSecret: "newClientSecret", + }, + // Validation passes but the embedded IdP manager is not initialized, + // so the operation returns an internal server error. + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Create identity provider with invalid issuer", + requestBody: &api.PostApiIdentityProvidersJSONRequestBody{ + Type: api.IdentityProviderTypeOidc, + Name: "Invalid IDP", + Issuer: "not-a-url", + ClientId: "clientId", + ClientSecret: "clientSecret", + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/identity-providers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_IdentityProviders_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + idpId string + requestBody *api.PutApiIdentityProvidersIdpIdJSONRequestBody + expectedStatus int + }{ + { + name: "Update existing identity provider", + idpId: "testIdpId", + requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{ + Type: api.IdentityProviderTypeGoogle, + Name: "Updated IDP", + ClientId: "updatedClientId", + ClientSecret: "updatedClientSecret", + }, + // Validation passes but the embedded IdP manager is not initialized, + // so the operation returns an internal server error. + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Update non-existing identity provider", + idpId: "nonExistingIdpId", + requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{ + Type: api.IdentityProviderTypeGoogle, + Name: "Updated IDP", + ClientId: "updatedClientId", + ClientSecret: "updatedClientSecret", + }, + expectedStatus: http.StatusInternalServerError, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_IdentityProviders_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + idpId string + expectedStatus int + }{ + { + name: "Delete existing identity provider", + idpId: "testIdpId", + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Delete non-existing identity provider", + idpId: "nonExistingIdpId", + expectedStatus: http.StatusInternalServerError, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} diff --git a/management/server/http/testing/integration/instance_handler_integration_test.go b/management/server/http/testing/integration/instance_handler_integration_test.go new file mode 100644 index 00000000000..d82afdd6ef4 --- /dev/null +++ b/management/server/http/testing/integration/instance_handler_integration_test.go @@ -0,0 +1,183 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// Test_Instance_GetStatus tests the unauthenticated GET /api/instance endpoint. +// This endpoint bypasses auth middleware. With nil idpManager (no embedded IDP), +// SetupRequired should be false. +func Test_Instance_GetStatus(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + // The /api/instance endpoint is unauthenticated (bypass path). + // We still pass a token via BuildRequest but the bypass middleware skips auth. + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for instance status endpoint, got %d: %s", recorder.Code, string(content)) + + got := &api.InstanceStatus{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + // With nil idpManager (no embedded IDP configured), setup is not required. + assert.Equal(t, false, got.SetupRequired, "Expected SetupRequired to be false when embedded IDP is not configured") +} + +// Test_Instance_GetStatus_Unauthenticated verifies the endpoint works without any +// valid user token, since it is on the bypass path. +func Test_Instance_GetStatus_Unauthenticated(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + // Use an invalid token to confirm the bypass middleware skips auth + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.InvalidToken) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for unauthenticated instance status endpoint, got %d: %s", recorder.Code, string(content)) + + got := &api.InstanceStatus{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, false, got.SetupRequired) +} + +// Test_Instance_GetVersionInfo tests the authenticated GET /api/instance/version endpoint. +func Test_Instance_GetVersionInfo(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get version info", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance/version", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := &api.InstanceVersionInfo{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.NotEmpty(t, got.ManagementCurrentVersion, "Expected non-empty current version") + }) + } +} + +// Test_Instance_Setup tests the unauthenticated POST /api/setup endpoint. +// Since embedded IDP is not configured in the test environment, the setup +// endpoint should return an error (500 Internal Server Error) because the +// instance manager's CreateOwnerUser returns "embedded IDP is not enabled". +func Test_Instance_Setup(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + body, err := json.Marshal(&api.SetupRequest{ + Email: "admin@test.com", + Password: "securepassword123", + Name: "Admin User", + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + // The /api/setup endpoint is unauthenticated (bypass path). + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + // Without embedded IDP, CreateOwnerUser returns a plain error (not a status error), + // which the handler maps to 500 Internal Server Error. + assert.Equal(t, http.StatusInternalServerError, recorder.Code, + "Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content)) +} + +// Test_Instance_Setup_Unauthenticated verifies the setup endpoint works (reaches +// the handler) even with an invalid token, since it is on the bypass path. +func Test_Instance_Setup_Unauthenticated(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + body, err := json.Marshal(&api.SetupRequest{ + Email: "admin@test.com", + Password: "securepassword123", + Name: "Admin User", + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.InvalidToken) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + // Even with an invalid token, the bypass middleware skips auth and the handler runs. + // Without embedded IDP, it returns 500. + assert.Equal(t, http.StatusInternalServerError, recorder.Code, + "Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content)) +} diff --git a/management/server/http/testing/integration/invites_handler_integration_test.go b/management/server/http/testing/integration/invites_handler_integration_test.go new file mode 100644 index 00000000000..358fef667cc --- /dev/null +++ b/management/server/http/testing/integration/invites_handler_integration_test.go @@ -0,0 +1,154 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// Note: The integration test infrastructure does not configure an embedded IDP, +// so actual invite operations will return PreconditionFailed (412) for authorized users. +// These tests verify that the permissions layer correctly denies regular users +// before the handler logic is reached. + +func Test_Invites_List(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - List invites", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/invites", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Create invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + body, err := json.Marshal(&api.UserInviteCreateRequest{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users/invites", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Delete invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/invites/{inviteId}", "{inviteId}", "someInviteId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} + +func Test_Invites_Regenerate(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Regenerate invite", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/invites/{inviteId}/regenerate", "{inviteId}", "someInviteId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get PreconditionFailed (no embedded IDP configured) + // Unauthorized users get rejected by the permissions middleware + testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse) + }) + } +} diff --git a/management/server/http/testing/integration/peers_handler_integration_test.go b/management/server/http/testing/integration/peers_handler_integration_test.go index 17a9e94a67f..b06e6679aaf 100644 --- a/management/server/http/testing/integration/peers_handler_integration_test.go +++ b/management/server/http/testing/integration/peers_handler_integration_test.go @@ -45,7 +45,7 @@ func Test_Peers_GetAll(t *testing.T) { { name: "Regular service user", userId: testing_tools.TestServiceUserId, - expectResponse: true, + expectResponse: false, }, { name: "Admin service user", @@ -123,7 +123,7 @@ func Test_Peers_GetById(t *testing.T) { { name: "Regular service user", userId: testing_tools.TestServiceUserId, - expectResponse: true, + expectResponse: false, }, { name: "Admin service user", diff --git a/management/server/http/testing/integration/posture_checks_handler_integration_test.go b/management/server/http/testing/integration/posture_checks_handler_integration_test.go new file mode 100644 index 00000000000..40ac677bf2c --- /dev/null +++ b/management/server/http/testing/integration/posture_checks_handler_integration_test.go @@ -0,0 +1,372 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_PostureChecks_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all posture checks", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/posture-checks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.PostureCheck{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testPostureCheckId", got[0].Id) + assert.Equal(t, "NetBird Version Check", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_PostureChecks_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + postureCheckId string + expectedStatus int + expectCheck bool + }{ + { + name: "Get existing posture check", + postureCheckId: "testPostureCheckId", + expectedStatus: http.StatusOK, + expectCheck: true, + }, + { + name: "Get non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + expectedStatus: http.StatusNotFound, + expectCheck: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectCheck { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testPostureCheckId", got.Id) + assert.Equal(t, "NetBird Version Check", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PostureChecks_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + minVersion := "0.32.0" + tt := []struct { + name string + requestBody *api.PostureCheckUpdate + expectedStatus int + verifyResponse func(t *testing.T, check *api.PostureCheck) + }{ + { + name: "Create posture check with NB version", + requestBody: &api.PostureCheckUpdate{ + Name: "New Version Check", + Description: "check for new version", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: minVersion, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, check *api.PostureCheck) { + t.Helper() + assert.NotEmpty(t, check.Id) + assert.Equal(t, "New Version Check", check.Name) + assert.NotNil(t, check.Checks.NbVersionCheck) + assert.Equal(t, minVersion, check.Checks.NbVersionCheck.MinVersion) + }, + }, + { + name: "Create posture check with empty name", + requestBody: &api.PostureCheckUpdate{ + Name: "", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: "0.32.0", + }, + }, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/posture-checks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbCheck.Name) + } + }) + } + } +} + +func Test_PostureChecks_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + minVersion := "0.33.0" + tt := []struct { + name string + postureCheckId string + requestBody *api.PostureCheckUpdate + expectedStatus int + verifyResponse func(t *testing.T, check *api.PostureCheck) + }{ + { + name: "Update posture check name and version", + postureCheckId: "testPostureCheckId", + requestBody: &api.PostureCheckUpdate{ + Name: "Updated Version Check", + Description: "updated description", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: minVersion, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, check *api.PostureCheck) { + t.Helper() + assert.Equal(t, "testPostureCheckId", check.Id) + assert.Equal(t, "Updated Version Check", check.Name) + }, + }, + { + name: "Update non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + requestBody: &api.PostureCheckUpdate{ + Name: "whatever", + Checks: &api.Checks{ + NbVersionCheck: &api.NBVersionCheck{ + MinVersion: "0.33.0", + }, + }, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PostureCheck{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, tc.postureCheckId) + assert.Equal(t, "Updated Version Check", dbCheck.Name) + } + }) + } + } +} + +func Test_PostureChecks_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + postureCheckId string + expectedStatus int + }{ + { + name: "Delete existing posture check", + postureCheckId: "testPostureCheckId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing posture check", + postureCheckId: "nonExistingPostureCheckId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPostureCheckNotInDB(t, db, tc.postureCheckId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/reverse_proxy_handler_integration_test.go b/management/server/http/testing/integration/reverse_proxy_handler_integration_test.go new file mode 100644 index 00000000000..84b32826c60 --- /dev/null +++ b/management/server/http/testing/integration/reverse_proxy_handler_integration_test.go @@ -0,0 +1,306 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_ReverseProxy_GetClusters(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get clusters", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/clusters", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.ProxyCluster{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got), "Expected empty clusters list when no proxy is connected") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_ReverseProxy_GetAllServices(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all services", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/services", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.Service{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got), "Expected empty services list with no services in DB") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_ReverseProxy_CreateService(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Create service", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + // Creating a service requires a valid domain entry. Without one, authorized + // users will receive a validation error. This test verifies that the + // permissions layer correctly rejects unauthorized users before handler + // logic runs, and that authorized users reach the handler (even if the + // handler returns an error due to missing domain). + body, err := json.Marshal(&api.ServiceRequest{ + Name: "test-service", + Domain: "nonexistent.example.com", + Enabled: true, + }) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/reverse-proxies/services", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users reach the handler but get a validation error (no valid domain). + // Unauthorized users are rejected by the permissions middleware. + testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, user.expectResponse) + }) + } +} + +func Test_ReverseProxy_GetServiceById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get service by ID (non-existing)", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get NotFound for a non-existing service. + // Unauthorized users are rejected by the permissions middleware. + testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse) + }) + } +} + +func Test_ReverseProxy_DeleteService(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Delete service (non-existing)", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + // Authorized users get NotFound for a non-existing service. + // Unauthorized users are rejected by the permissions middleware. + testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse) + }) + } +} + +func Test_ReverseProxy_GetAllDomains(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all domains", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/domains", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.ReverseProxyDomain{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got), "Expected empty domains list with no domains in DB") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_ReverseProxy_GetAccessLogs(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get proxy access logs", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/proxy", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := &api.ProxyAccessLogsResponse{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got.Data), "Expected empty access logs data") + assert.Equal(t, 0, got.TotalRecords, "Expected zero total records") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} diff --git a/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go b/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go new file mode 100644 index 00000000000..52bf8983d64 --- /dev/null +++ b/management/server/http/testing/integration/users_approve_reject_handler_integration_test.go @@ -0,0 +1,124 @@ +//go:build integration + +package integration + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" +) + +func Test_Users_Approve(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Approve pending user", + targetUserId: "pendingUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Approve non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + { + name: "Approve already active user", + targetUserId: testing_tools.TestUserId, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/{userId}/approve", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_Users_Reject(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Reject pending user", + targetUserId: "pendingUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Reject non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + { + name: "Reject non-pending user", + targetUserId: testing_tools.TestUserId, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}/reject", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/users_current_handler_integration_test.go b/management/server/http/testing/integration/users_current_handler_integration_test.go new file mode 100644 index 00000000000..3f32dc2c7b1 --- /dev/null +++ b/management/server/http/testing/integration/users_current_handler_integration_test.go @@ -0,0 +1,64 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Users_GetCurrent(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, true}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, false}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get current user", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/current", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := &api.User{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, user.userId, got.Id) + assert.NotNil(t, got.IsCurrent) + assert.True(t, *got.IsCurrent) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} diff --git a/management/server/http/testing/integration/users_handler_integration_test.go b/management/server/http/testing/integration/users_handler_integration_test.go index eae3b4ad5a0..a1c9e48d924 100644 --- a/management/server/http/testing/integration/users_handler_integration_test.go +++ b/management/server/http/testing/integration/users_handler_integration_test.go @@ -26,7 +26,7 @@ func Test_Users_GetAll(t *testing.T) { {"Regular user", testing_tools.TestUserId, true}, {"Admin user", testing_tools.TestAdminId, true}, {"Owner user", testing_tools.TestOwnerId, true}, - {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, {"Admin service user", testing_tools.TestServiceAdminId, true}, {"Blocked user", testing_tools.BlockedUserId, false}, {"Other user", testing_tools.OtherUserId, false}, @@ -637,6 +637,38 @@ func Test_PATs_Create(t *testing.T) { } } +func Test_Users_Update_CrossAccountAttack(t *testing.T) { + t.Run("Admin attempts to update user from other account", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + body, _ := json.Marshal(&api.UserRequest{ + Role: "user", + AutoGroups: []string{}, + IsBlocked: true, + }) + + // TestAdminId belongs to testAccountId, but targets otherUserId which belongs to otherAccountId + req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/users/otherUserId", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user update must be rejected") + }) +} + +func Test_Users_Delete_CrossAccountAttack(t *testing.T) { + t.Run("Admin attempts to delete service user from other account", func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + // TestAdminId belongs to testAccountId, but targets otherServiceUserId which belongs to otherAccountId + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/users/otherServiceUserId", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user delete must be rejected") + }) +} + func Test_PATs_Delete(t *testing.T) { users := []struct { name string diff --git a/management/server/http/testing/testdata/accounts.sql b/management/server/http/testing/testdata/accounts.sql index 35f00d41977..bfffd13fead 100644 --- a/management/server/http/testing/testdata/accounts.sql +++ b/management/server/http/testing/testdata/accounts.sql @@ -15,4 +15,4 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); diff --git a/management/server/http/testing/testdata/dns.sql b/management/server/http/testing/testdata/dns.sql index 9ed4daf7ed7..69d85da7ca1 100644 --- a/management/server/http/testing/testdata/dns.sql +++ b/management/server/http/testing/testdata/dns.sql @@ -15,7 +15,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/dns_zones.sql b/management/server/http/testing/testdata/dns_zones.sql new file mode 100644 index 00000000000..acc182ea621 --- /dev/null +++ b/management/server/http/testing/testdata/dns_zones.sql @@ -0,0 +1,26 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `zones` (`id` text,`account_id` text,`name` text,`domain` text,`enabled` numeric,`enable_search_domain` numeric,`distribution_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_zones` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `records` (`id` text,`account_id` text,`zone_id` text,`name` text,`type` text,`content` text,`ttl` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_zones_records` FOREIGN KEY (`zone_id`) REFERENCES `zones`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0); + +INSERT INTO zones VALUES('testZoneId','testAccountId','Test Zone','example.com',1,0,'["testGroupId"]'); +INSERT INTO records VALUES('testRecordId','testAccountId','testZoneId','sub.example.com','A','1.2.3.4',300); \ No newline at end of file diff --git a/management/server/http/testing/testdata/events.sql b/management/server/http/testing/testdata/events.sql index 27fd01aea81..8dfe44faa63 100644 --- a/management/server/http/testing/testdata/events.sql +++ b/management/server/http/testing/testdata/events.sql @@ -14,5 +14,5 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/groups.sql b/management/server/http/testing/testdata/groups.sql index eb874f0366f..4f83a080657 100644 --- a/management/server/http/testing/testdata/groups.sql +++ b/management/server/http/testing/testdata/groups.sql @@ -15,5 +15,5 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/identity_providers.sql b/management/server/http/testing/testdata/identity_providers.sql new file mode 100644 index 00000000000..9dc04241829 --- /dev/null +++ b/management/server/http/testing/testdata/identity_providers.sql @@ -0,0 +1,20 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `identity_providers` (`id` text,`account_id` text,`type` text,`name` text,`issuer` text,`client_id` text,`client_secret` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_identity_providers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO identity_providers VALUES('testIdpId','testAccountId','oidc','Test IDP','https://issuer.example.com','client123','secret456'); \ No newline at end of file diff --git a/management/server/http/testing/testdata/peers.sql b/management/server/http/testing/testdata/peers.sql index 863eda5205d..3593222a7ed 100644 --- a/management/server/http/testing/testdata/peers.sql +++ b/management/server/http/testing/testdata/peers.sql @@ -13,7 +13,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); diff --git a/management/server/http/testing/testdata/peers_integration.sql b/management/server/http/testing/testdata/peers_integration.sql index 62a7760e7fa..eb6094f1f91 100644 --- a/management/server/http/testing/testdata/peers_integration.sql +++ b/management/server/http/testing/testdata/peers_integration.sql @@ -14,7 +14,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId","testPeerId2"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','test-host-1','linux','Linux','','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-1','test-peer-1','2023-03-02 09:21:02.189035775+01:00',0,0,0,'testUserId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('testPeerId2','testAccountId','6rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYBg=','82546A29-6BC8-4311-BCFC-9CDBF33F1A49','"100.64.114.32"','test-host-2','linux','Linux','','unknown','Ubuntu','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-2','test-peer-2','2023-03-02 09:21:02.189035775+01:00',1,0,0,'testAdminId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/policies.sql b/management/server/http/testing/testdata/policies.sql index 7e6cc883b6e..7374112fef5 100644 --- a/management/server/http/testing/testdata/policies.sql +++ b/management/server/http/testing/testdata/policies.sql @@ -16,7 +16,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO policies VALUES('testPolicyId','testAccountId','testPolicy','test policy description',1,NULL); diff --git a/management/server/http/testing/testdata/posture_checks.sql b/management/server/http/testing/testdata/posture_checks.sql new file mode 100644 index 00000000000..4b2f11f6c9c --- /dev/null +++ b/management/server/http/testing/testdata/posture_checks.sql @@ -0,0 +1,21 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO posture_checks VALUES('testPostureCheckId','NetBird Version Check','Require minimum NetBird version','testAccountId','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); \ No newline at end of file diff --git a/management/server/http/testing/testdata/routes.sql b/management/server/http/testing/testdata/routes.sql index 48aa02052a7..c8d3b880ea2 100644 --- a/management/server/http/testing/testdata/routes.sql +++ b/management/server/http/testing/testdata/routes.sql @@ -16,7 +16,7 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('peerGroupId','testAccountId','peerGroupName','api','["testPeerId"]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO routes VALUES('testRouteId','testAccountId','"10.0.0.0/24"',NULL,0,'testNet','Test Network Route','testPeerId',NULL,1,1,100,1,'["testGroupId"]',NULL,0); diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql index 6d30fb5fef8..77bec776118 100644 --- a/management/server/http/testing/testdata/setup_keys.sql +++ b/management/server/http/testing/testdata/setup_keys.sql @@ -18,7 +18,7 @@ INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[ CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime DEFAULT NULL,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); diff --git a/management/server/http/testing/testdata/users.sql b/management/server/http/testing/testdata/users.sql index 346f7b7ac5b..f1c2305cbd8 100644 --- a/management/server/http/testing/testdata/users.sql +++ b/management/server/http/testing/testdata/users.sql @@ -6,7 +6,7 @@ CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`i INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); diff --git a/management/server/http/testing/testdata/users_approve_reject.sql b/management/server/http/testing/testdata/users_approve_reject.sql new file mode 100644 index 00000000000..4c7a306aad7 --- /dev/null +++ b/management/server/http/testing/testdata/users_approve_reject.sql @@ -0,0 +1,19 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`pending_approval` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('pendingUserId','testAccountId','user',0,0,'','[]',1,1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/users_integration.sql b/management/server/http/testing/testdata/users_integration.sql index 57df73e8c0c..90ce450e3ac 100644 --- a/management/server/http/testing/testdata/users_integration.sql +++ b/management/server/http/testing/testdata/users_integration.sql @@ -15,9 +15,10 @@ INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testS INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO users VALUES('deletableServiceUserId','testAccountId','user',1,0,'deletableServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherServiceUserId','otherAccountId','user',1,0,'otherServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO personal_access_tokens VALUES('testTokenId','testUserId','testToken','hashedTokenValue123','2325-10-02 16:01:38.000000000+00:00','testUserId','2024-10-02 16:01:38.000000000+00:00',NULL); diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1a8b83c7eed..273ad51d49c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -12,7 +12,7 @@ import ( "go.opentelemetry.io/otel/metric/noop" "github.com/netbirdio/management-integrations/integrations" - + "github.com/netbirdio/netbird/management/internals/modules/permissions" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -43,7 +43,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -82,8 +81,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) - peersManager := peers.NewManager(store, permissionsManager) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), settings.IdpConfig{}) + peersManager := peers.NewManager(store) jobManager := job.NewJobManager(nil, store, peersManager) @@ -101,7 +100,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create manager: %v", err) } - accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) + accessLogsManager := accesslogsmanager.NewManager(store, nil) proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") @@ -110,12 +109,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create proxy manager: %v", err) } proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) - domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) + domainManager := manager.NewManager(store, proxyMgr, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) @@ -128,14 +127,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee GetPATInfoFunc: authManager.GetPATInfo, } - groupsManager := groups.NewManager(store, permissionsManager, am) - routersManager := routers.NewManager(store, permissionsManager, am) - resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) - networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) - customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") - zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) + groupsManager := groups.NewManager(store, am) + routersManager := routers.NewManager(store, am) + resourcesManager := resources.NewManager(store, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, resourcesManager, routersManager, am) + customZonesManager := zonesManager.NewManager(store, am, "") + zoneRecordsManager := recordsManager.NewManager(store, am) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, &domainManager, accessLogsManager, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -211,8 +210,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) - peersManager := peers.NewManager(store, permissionsManager) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), settings.IdpConfig{}) + peersManager := peers.NewManager(store) jobManager := job.NewJobManager(nil, store, peersManager) @@ -230,7 +229,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin t.Fatalf("Failed to create manager: %v", err) } - accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) + accessLogsManager := accesslogsmanager.NewManager(store, nil) proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") @@ -239,12 +238,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin t.Fatalf("Failed to create proxy manager: %v", err) } proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) - domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) + domainManager := manager.NewManager(store, proxyMgr, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) @@ -257,14 +256,14 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin GetPATInfoFunc: authManager.GetPATInfo, } - groupsManager := groups.NewManager(store, permissionsManager, am) - routersManager := routers.NewManager(store, permissionsManager, am) - resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) - networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) - customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") - zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) + groupsManager := groups.NewManager(store, am) + routersManager := routers.NewManager(store, am) + resourcesManager := resources.NewManager(store, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, resourcesManager, routersManager, am) + customZonesManager := zonesManager.NewManager(store, am, "") + zoneRecordsManager := recordsManager.NewManager(store, am) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, &domainManager, accessLogsManager, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/http/testing/testing_tools/db_verify.go b/management/server/http/testing/testing_tools/db_verify.go index f8af6a41f15..18a00f3d792 100644 --- a/management/server/http/testing/testing_tools/db_verify.go +++ b/management/server/http/testing/testing_tools/db_verify.go @@ -8,10 +8,13 @@ import ( "gorm.io/gorm" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" @@ -220,3 +223,54 @@ func VerifyNetworkRouterNotInDB(t *testing.T, db *gorm.DB, routerID string) { db.Model(&routerTypes.NetworkRouter{}).Where("id = ? AND account_id = ?", routerID, TestAccountId).Count(&count) assert.Equal(t, int64(0), count, "Expected network router %s to NOT exist in DB", routerID) } + +// VerifyPostureCheckInDB reads a posture check directly from the DB and returns it. +func VerifyPostureCheckInDB(t *testing.T, db *gorm.DB, postureCheckID string) *posture.Checks { + t.Helper() + var check posture.Checks + err := db.Where("id = ? AND account_id = ?", postureCheckID, TestAccountId).First(&check).Error + require.NoError(t, err, "Expected posture check %s to exist in DB", postureCheckID) + return &check +} + +// VerifyPostureCheckNotInDB verifies that a posture check does not exist in the DB. +func VerifyPostureCheckNotInDB(t *testing.T, db *gorm.DB, postureCheckID string) { + t.Helper() + var count int64 + db.Model(&posture.Checks{}).Where("id = ? AND account_id = ?", postureCheckID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected posture check %s to NOT exist in DB", postureCheckID) +} + +// VerifyZoneInDB reads a zone directly from the DB and returns it. +func VerifyZoneInDB(t *testing.T, db *gorm.DB, zoneID string) *zones.Zone { + t.Helper() + var zone zones.Zone + err := db.Where("id = ? AND account_id = ?", zoneID, TestAccountId).First(&zone).Error + require.NoError(t, err, "Expected zone %s to exist in DB", zoneID) + return &zone +} + +// VerifyZoneNotInDB verifies that a zone does not exist in the DB. +func VerifyZoneNotInDB(t *testing.T, db *gorm.DB, zoneID string) { + t.Helper() + var count int64 + db.Model(&zones.Zone{}).Where("id = ? AND account_id = ?", zoneID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected zone %s to NOT exist in DB", zoneID) +} + +// VerifyRecordInDB reads a record directly from the DB and returns it. +func VerifyRecordInDB(t *testing.T, db *gorm.DB, recordID string) *records.Record { + t.Helper() + var record records.Record + err := db.Where("id = ? AND account_id = ?", recordID, TestAccountId).First(&record).Error + require.NoError(t, err, "Expected record %s to exist in DB", recordID) + return &record +} + +// VerifyRecordNotInDB verifies that a record does not exist in the DB. +func VerifyRecordNotInDB(t *testing.T, db *gorm.DB, recordID string) { + t.Helper() + var count int64 + db.Model(&records.Record{}).Where("id = ? AND account_id = ?", recordID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected record %s to NOT exist in DB", recordID) +} diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index b7a63b104c2..e3b187622a6 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -106,6 +106,10 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta } if !expectResponse { + if recorder.Code == http.StatusOK || recorder.Code == http.StatusCreated { + t.Fatalf("expected unauthorized/error status code but got %d, content: %s", + recorder.Code, string(content)) + } return nil, false } diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go index 8fd96c23859..0533ed7c580 100644 --- a/management/server/identity_provider.go +++ b/management/server/identity_provider.go @@ -17,8 +17,6 @@ import ( "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -88,14 +86,6 @@ func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.Identi // GetIdentityProviders returns all identity providers for an account func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { - ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) if !ok { log.Warn("identity provider management requires embedded IdP") @@ -117,14 +107,6 @@ func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accou // GetIdentityProvider returns a specific identity provider by ID func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) { - ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) if !ok { return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP") @@ -143,14 +125,6 @@ func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accoun // CreateIdentityProvider creates a new identity provider func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) { - ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil { return nil, err } @@ -168,7 +142,7 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc connCfg := identityProviderToConnectorConfig(idpConfig) - _, err = embeddedManager.CreateConnector(ctx, connCfg) + _, err := embeddedManager.CreateConnector(ctx, connCfg) if err != nil { return nil, status.Errorf(status.Internal, "failed to create identity provider: %v", err) } @@ -180,14 +154,6 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc // UpdateIdentityProvider updates an existing identity provider func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) { - ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil { return nil, err } @@ -213,14 +179,6 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc // DeleteIdentityProvider deletes an identity provider func (am *DefaultAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error { - ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) if !ok { return status.Errorf(status.Internal, "identity provider management requires embedded IdP") diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index d51254c557f..cc27670b9d9 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -18,13 +18,13 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -83,7 +83,7 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update AnyTimes() permissionsManager := permissions.NewManager(testStore) - peersManager := peers.NewManager(testStore, permissionsManager) + peersManager := peers.NewManager(testStore) cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go index e3be9cfead7..6c1c92fdf7e 100644 --- a/management/server/instance/manager_test.go +++ b/management/server/instance/manager_test.go @@ -17,9 +17,9 @@ import ( ) type mockIdP struct { - mu sync.Mutex - createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) - users map[string][]*idp.UserData + mu sync.Mutex + createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) + users map[string][]*idp.UserData getAllAccountsErr error } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 18d85315d39..47f73cd0293 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" @@ -34,7 +35,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -363,12 +363,12 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config AnyTimes() permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) jobManager := job.NewJobManager(nil, store, peersManager) updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store)) cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { diff --git a/management/server/management_test.go b/management/server/management_test.go index 3ac28cd4ab5..2571a8aaa2e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" @@ -32,7 +33,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -204,7 +204,7 @@ func startServer( AnyTimes() permissionsManager := permissions.NewManager(str) - peersManager := peers.NewManager(str, permissionsManager) + peersManager := peers.NewManager(str) jobManager := job.NewJobManager(nil, str, peersManager) ctx := context.Background() @@ -216,7 +216,7 @@ func startServer( updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str)), config) accountManager, err := server.BuildManager( context.Background(), @@ -240,7 +240,7 @@ func startServer( t.Fatalf("failed creating an account manager: %v", err) } - groupsManager := groups.NewManager(str, permissionsManager, accountManager) + groupsManager := groups.NewManager(str, accountManager) secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) if err != nil { t.Fatalf("failed creating secrets manager: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ff369355eda..a7fcddec3bc 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,7 +37,7 @@ type MockAccountManager struct { GetAccountIDByUserIdFunc func(ctx context.Context, userAuth auth.UserAuth) (string, error) GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) - GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) + GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error @@ -46,7 +46,7 @@ type MockAccountManager struct { AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, groupName, accountID string) (*types.Group, error) SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error @@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { if am.GetGroupByNameFunc != nil { - return am.GetGroupByNameFunc(ctx, groupName, accountID, userID) + return am.GetGroupByNameFunc(ctx, groupName, accountID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } @@ -773,9 +773,9 @@ func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, user } // GetPeers mocks GetPeers of the AccountManager interface -func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { +func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error) { if am.GetPeersFunc != nil { - return am.GetPeersFunc(ctx, accountID, userID, nameFilter, ipFilter) + return am.GetPeersFunc(ctx, accountID, userID, nameFilter, ipFilter, all) } return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 3d8c78912b7..3f4d97c32dd 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,8 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" nbdomain "github.com/netbirdio/netbird/shared/management/domain" @@ -23,27 +21,11 @@ var errInvalidDomainName = errors.New("invalid domain name") // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), AccountID: accountID, @@ -59,7 +41,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } @@ -94,17 +77,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID) if err != nil { return err @@ -141,18 +116,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var nsGroup *nbdns.NameServerGroup var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) if err != nil { return err @@ -184,14 +152,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d10d4464fd2..20e74d502ca 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -15,13 +15,13 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -792,7 +792,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) ctx := context.Background() @@ -803,7 +803,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca4511..95b96ea1358 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -11,11 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -27,50 +23,32 @@ type Manager interface { } type managerImpl struct { - store store.Store - accountManager account.Manager - permissionsManager permissions.Manager - resourcesManager resources.Manager - routersManager routers.Manager + store store.Store + accountManager account.Manager + resourcesManager resources.Manager + routersManager routers.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager account.Manager) Manager { +func NewManager(store store.Store, resourceManager resources.Manager, routersManager routers.Manager, accountManager account.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - resourcesManager: resourceManager, - routersManager: routersManager, - accountManager: accountManager, + store: store, + resourcesManager: resourceManager, + routersManager: routersManager, + accountManager: accountManager, } } func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID) } func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - network.ID = xid.New().String() - err = m.store.SaveNetwork(ctx, network) + err := m.store.SaveNetwork(ctx, network) if err != nil { return nil, fmt.Errorf("failed to save network: %w", err) } @@ -81,27 +59,11 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network } func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) + _, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) if err != nil { return nil, fmt.Errorf("failed to get network: %w", err) } @@ -112,14 +74,6 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network } func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) if err != nil { return fmt.Errorf("failed to get network: %w", err) diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go index 6fb19d157a7..38bec1f0a1f 100644 --- a/management/server/networks/manager_test.go +++ b/management/server/networks/manager_test.go @@ -11,7 +11,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" ) @@ -26,11 +25,10 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) { } t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) groupsManager := groups.NewManagerMock() routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + resourcesManager := resources.NewManager(s, groupsManager, &am, nil) + manager := NewManager(s, resourcesManager, routerManager, &am) networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.NoError(t, err) @@ -38,28 +36,6 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) { require.Equal(t, "testNetworkId", networks[0].ID) } -func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) - groupsManager := groups.NewManagerMock() - routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) - - networks, err := manager.GetAllNetworks(ctx, accountID, userID) - require.Error(t, err) - require.Nil(t, networks) -} - func Test_GetNetworkReturnsNetwork(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -72,40 +48,16 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) { } t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) groupsManager := groups.NewManagerMock() routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + resourcesManager := resources.NewManager(s, groupsManager, &am, nil) + manager := NewManager(s, resourcesManager, routerManager, &am) networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) require.Equal(t, "testNetworkId", networks.ID) } -func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) - groupsManager := groups.NewManagerMock() - routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) - - network, err := manager.GetNetwork(ctx, accountID, userID, networkID) - require.Error(t, err) - require.Nil(t, network) -} - func Test_CreateNetworkSuccessfully(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -120,42 +72,16 @@ func Test_CreateNetworkSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) groupsManager := groups.NewManagerMock() routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + resourcesManager := resources.NewManager(s, groupsManager, &am, nil) + manager := NewManager(s, resourcesManager, routerManager, &am) createdNetwork, err := manager.CreateNetwork(ctx, userID, network) require.NoError(t, err) require.Equal(t, network.Name, createdNetwork.Name) } -func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - userID := "testUserId" - network := &types.Network{ - AccountID: "testAccountId", - Name: "new-network", - } - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) - groupsManager := groups.NewManagerMock() - routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) - - createdNetwork, err := manager.CreateNetwork(ctx, userID, network) - require.Error(t, err) - require.Nil(t, createdNetwork) -} - func Test_DeleteNetworkSuccessfully(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -168,38 +94,15 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) groupsManager := groups.NewManagerMock() routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + resourcesManager := resources.NewManager(s, groupsManager, &am, nil) + manager := NewManager(s, resourcesManager, routerManager, &am) err = manager.DeleteNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) } -func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) - groupsManager := groups.NewManagerMock() - routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) - - err = manager.DeleteNetwork(ctx, accountID, userID, networkID) - require.Error(t, err) -} - func Test_UpdateNetworkSuccessfully(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -215,40 +118,12 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) groupsManager := groups.NewManagerMock() routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + resourcesManager := resources.NewManager(s, groupsManager, &am, nil) + manager := NewManager(s, resourcesManager, routerManager, &am) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) require.NoError(t, err) require.Equal(t, network.Name, updatedNetwork.Name) } - -func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - userID := "testUserId" - network := &types.Network{ - AccountID: "testAccountId", - ID: "testNetworkId", - Name: "new-network", - } - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - am := mock_server.MockAccountManager{} - permissionsManager := permissions.NewManager(s) - groupsManager := groups.NewManagerMock() - routerManager := routers.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil) - manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) - - updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) - require.Error(t, err) - require.Nil(t, updatedNetwork) -} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 86f9b657940..c5fab812f5e 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -12,9 +12,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks/resources/types" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" nbtypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" @@ -33,59 +30,33 @@ type Manager interface { } type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - groupsManager groups.Manager - accountManager account.Manager - serviceManager service.Manager + store store.Store + groupsManager groups.Manager + accountManager account.Manager + serviceManager service.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager { +func NewManager(store store.Store, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - groupsManager: groupsManager, - accountManager: accountManager, - serviceManager: reverseproxyManager, + store: store, + groupsManager: groupsManager, + accountManager: accountManager, + serviceManager: reverseproxyManager, } } func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID) } func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get network resources: %w", err) @@ -100,15 +71,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, } func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - - resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs, resource.Enabled) + resource, err := types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs, resource.Enabled) if err != nil { return nil, fmt.Errorf("failed to create new network resource: %w", err) } @@ -168,14 +131,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc } func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { return nil, fmt.Errorf("failed to get network resource: %w", err) @@ -189,14 +144,6 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ } func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - resourceType, domain, prefix, err := types.GetResourceType(resource.Address) if err != nil { return nil, fmt.Errorf("failed to get resource type: %w", err) @@ -314,14 +261,6 @@ func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction stor } func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID) if err != nil { return fmt.Errorf("failed to check if resource is used by service: %w", err) diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index c6d8e7bcc4d..cb8784ad00a 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -11,9 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { @@ -27,41 +25,17 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) require.Len(t, resources, 2) } -func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, resources) -} func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -72,41 +46,17 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.NoError(t, err) require.Len(t, resources, 2) } -func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, resources) -} - func Test_GetResourceInNetworkReturnsResources(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -119,43 +69,17 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) require.Equal(t, resourceID, resource.ID) } -func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - resourceID := "testResourceId" - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, resources) -} - func Test_CreateResourceSuccessfully(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -172,48 +96,18 @@ func Test_CreateResourceSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.NoError(t, err) require.Equal(t, resource.Name, createdResource.Name) } -func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - userID := "testUserId" - resource := &types.NetworkResource{ - AccountID: "testAccountId", - NetworkID: "testNetworkId", - Name: "testResourceId", - Description: "description", - Address: "192.168.1.1", - } - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - createdResource, err := manager.CreateResource(ctx, userID, resource) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, createdResource) -} - func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -230,12 +124,11 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -258,12 +151,11 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -290,13 +182,12 @@ func Test_UpdateResourceSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.NoError(t, err) @@ -325,12 +216,11 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -357,43 +247,11 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - updatedResource, err := manager.UpdateResource(ctx, userID, resource) - require.Error(t, err) - require.Nil(t, updatedResource) -} - -func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - resourceID := "testResourceId" - resource := &types.NetworkResource{ - AccountID: accountID, - NetworkID: networkID, - Name: resourceID, - Description: "new-description", - Address: "1.2.3.0/24", - } - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -412,37 +270,13 @@ func Test_DeleteResourceSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) serviceManager := reverseproxy.NewMockManager(ctrl) serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) + manager := NewManager(store, groupsManager, &am, serviceManager) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) } - -func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - resourceID := "testResourceId" - - store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(store) - am := mock_server.MockAccountManager{} - groupsManager := groups.NewManagerMock() - ctrl := gomock.NewController(t) - serviceManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) - - err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) - require.Error(t, err) -} diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a2c..d861855c158 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -11,9 +11,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) @@ -29,43 +26,25 @@ type Manager interface { } type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - accountManager account.Manager + store store.Store + accountManager account.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager { +func NewManager(store store.Store, accountManager account.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - accountManager: accountManager, + store: store, + accountManager: accountManager, } } func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get network routers: %w", err) @@ -80,16 +59,9 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use } func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - var network *networkTypes.Network - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) if err != nil { return fmt.Errorf("failed to get network: %w", err) @@ -125,14 +97,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t } func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID) if err != nil { return nil, fmt.Errorf("failed to get network router: %w", err) @@ -146,16 +110,9 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI } func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - var network *networkTypes.Network - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) if err != nil { return fmt.Errorf("failed to get network: %w", err) @@ -189,16 +146,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t } func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !ok { - return status.NewPermissionDeniedError() - } - var event func() - err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) if err != nil { return fmt.Errorf("failed to delete network router: %w", err) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 6be90baa7a9..54622114fbd 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -8,9 +8,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { @@ -24,9 +22,8 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + manager := NewManager(s, &am) routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -34,27 +31,6 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { require.Equal(t, "testRouterId", routers[0].ID) } -func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) - am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) - - routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, routers) -} - func Test_GetRouterReturnsRouter(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -67,37 +43,14 @@ func Test_GetRouterReturnsRouter(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + manager := NewManager(s, &am) router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) require.Equal(t, "testRouterId", router.ID) } -func Test_GetRouterReturnsPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - resourceID := "testRouterId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) - am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) - - router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, router) -} - func Test_CreateRouterSuccessfully(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -111,9 +64,8 @@ func Test_CreateRouterSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + manager := NewManager(s, &am) createdRouter, err := manager.CreateRouter(ctx, userID, router) require.NoError(t, err) @@ -124,29 +76,6 @@ func Test_CreateRouterSuccessfully(t *testing.T) { require.Equal(t, router.Masquerade, createdRouter.Masquerade) } -func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - userID := "testUserId" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true) - if err != nil { - require.NoError(t, err) - } - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) - am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) - - createdRouter, err := manager.CreateRouter(ctx, userID, router) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, createdRouter) -} - func Test_DeleteRouterSuccessfully(t *testing.T) { ctx := context.Background() accountID := "testAccountId" @@ -159,35 +88,13 @@ func Test_DeleteRouterSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + manager := NewManager(s, &am) err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) require.NoError(t, err) } -func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - accountID := "testAccountId" - userID := "testUserId" - networkID := "testNetworkId" - routerID := "testRouterId" - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) - am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) - - err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) -} - func Test_UpdateRouterSuccessfully(t *testing.T) { ctx := context.Background() userID := "testAdminId" @@ -201,34 +108,10 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + manager := NewManager(s, &am) updatedRouter, err := manager.UpdateRouter(ctx, userID, router) require.NoError(t, err) require.Equal(t, router.Metric, updatedRouter.Metric) } - -func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { - ctx := context.Background() - userID := "testUserId" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) - if err != nil { - require.NoError(t, err) - } - - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - permissionsManager := permissions.NewManager(s) - am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) - - updatedRouter, err := manager.UpdateRouter(ctx, userID, router) - require.Error(t, err) - require.Equal(t, status.NewPermissionDeniedError(), err) - require.Nil(t, updatedRouter) -} diff --git a/management/server/peer.go b/management/server/peer.go index a95ae17a356..9d78c93a945 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -18,8 +18,6 @@ import ( "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/posture" @@ -35,24 +33,18 @@ const remoteJobsMinVer = "0.64.0" // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. -func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { +func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) if err != nil { return nil, err } - // @note if the user has permission to read peers it shows all account peers - if allowed { + if all || user.IsAdminOrServiceUser() { return accountPeers, nil } @@ -198,15 +190,8 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var peer *nbpeer.Peer + var err error var settings *types.Settings var peerGroupList []string var peerLabelChanged bool @@ -343,14 +328,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Create) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { return err @@ -418,15 +395,6 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p } func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) { - // todo: Create permissions for job - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) if err != nil { return nil, err @@ -445,14 +413,6 @@ func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, } func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) if err != nil { return nil, err @@ -472,14 +432,6 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) if err != nil { return err @@ -609,16 +561,11 @@ func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accoun if user.PendingApproval { return status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } + if temporary && !user.IsAdminOrServiceUser() { + return status.Errorf(status.PermissionDenied, "only admin or service users can add peers") + } - if temporary { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - } else { + if !temporary { config.AccountID = user.AccountID config.GroupsToAdd = user.AutoGroups } @@ -1237,14 +1184,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if allowed { - return peer, nil - } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6f8d924fd53..28e13c4cb9b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -30,13 +30,13 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" @@ -737,7 +737,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - peers, err := manager.GetPeers(context.Background(), accountID, someUser, "", "") + peers, err := manager.GetPeers(context.Background(), accountID, someUser, "", "", false) if err != nil { t.Fatal(err) return @@ -943,7 +943,7 @@ func BenchmarkGetPeers(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := manager.GetPeers(context.Background(), accountID, userID, "", "") + _, err := manager.GetPeers(context.Background(), accountID, userID, "", "", true) if err != nil { b.Fatalf("GetPeers failed: %v", err) } @@ -1292,7 +1292,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - peersManager := peers.NewManager(s, permissionsManager) + peersManager := peers.NewManager(s) ctx := context.Background() @@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) @@ -1382,7 +1382,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { Return(&types.ExtraSettings{}, nil). AnyTimes() permissionsManager := permissions.NewManager(s) - peersManager := peers.NewManager(s, permissionsManager) + peersManager := peers.NewManager(s) ctx := context.Background() @@ -1391,7 +1391,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) @@ -1540,7 +1540,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - peersManager := peers.NewManager(s, permissionsManager) + peersManager := peers.NewManager(s) ctx := context.Background() @@ -1549,7 +1549,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) @@ -1625,7 +1625,7 @@ func Test_LoginPeer(t *testing.T) { Return(&types.ExtraSettings{}, nil). AnyTimes() permissionsManager := permissions.NewManager(s) - peersManager := peers.NewManager(s, permissionsManager) + peersManager := peers.NewManager(s) ctx := context.Background() @@ -1634,7 +1634,7 @@ func Test_LoginPeer(t *testing.T) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) diff --git a/management/server/policy.go b/management/server/policy.go index 48297ca11e8..888969a02e0 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -7,8 +7,6 @@ import ( "github.com/rs/xid" "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -19,32 +17,13 @@ import ( // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID) } // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { - operation := operations.Create - if !create { - operation = operations.Update - } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var isUpdate = policy.ID != "" + var err error var updateAccountPeers bool var action = activity.PolicyAdded var unchanged bool @@ -55,6 +34,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } + // TODO: split into separate create and update functions to avoid the isUpdate check if isUpdate { if policy.Equal(existingPolicy) { logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID) @@ -104,16 +84,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var policy *types.Policy var updateAccountPeers bool + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) @@ -147,14 +120,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po // ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9562487c024..799cbf86f2d 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -7,40 +7,19 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID) } // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { - operation := operations.Create - if !create { - operation = operations.Update - } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var updateAccountPeers bool + var err error var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated @@ -49,6 +28,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } + // TODO: split into separate create and update functions to avoid the isUpdate check if isUpdate { updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) if err != nil { @@ -84,15 +64,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var postureChecks *posture.Checks + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID) @@ -121,14 +94,6 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun // ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7f0a48dc766..4c5b9bda673 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -32,14 +32,6 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { } t.Run("Generic posture check flow", func(t *testing.T) { - // regular users can not create checks - _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, true) - assert.Error(t, err) - - // regular users cannot list check - _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) - assert.Error(t, err) - // should be possible to create posture check with uniq name postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, @@ -80,10 +72,6 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck, true) assert.NoError(t, err) - // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) - assert.Error(t, err) - // admin should be able to delete posture checks err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) diff --git a/management/server/route.go b/management/server/route.go index 2b4f11d052f..64ff4ee9144 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -10,8 +10,6 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" @@ -21,14 +19,6 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID)) } @@ -134,20 +124,13 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { // CreateRoute creates and saves a new route func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } var newRoute *route.Route var updateAccountPeers bool + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { newRoute = &route.Route{ @@ -199,15 +182,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var oldRoute *route.Route + var err error var oldRouteAffectsPeers bool var newRouteAffectsPeers bool @@ -253,16 +229,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var route *route.Route var updateAccountPeers bool + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) @@ -296,14 +265,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 91b2cf98269..0811f2ca5fd 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -27,7 +28,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -1291,7 +1291,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. Return(&types.ExtraSettings{}, nil) permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) + peersManager := peers.NewManager(store) ctx := context.Background() @@ -1302,7 +1302,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 74af0a3ef4e..ef8b9c95d98 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -6,15 +6,10 @@ import ( "context" "fmt" - "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" - "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -35,16 +30,14 @@ type managerImpl struct { store store.Store extraSettingsManager extra_settings.Manager userManager users.Manager - permissionsManager permissions.Manager idpConfig IdpConfig } -func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager { +func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, idpConfig IdpConfig) Manager { return &managerImpl{ store: store, extraSettingsManager: extraSettingsManager, userManager: userManager, - permissionsManager: permissionsManager, idpConfig: idpConfig, } } @@ -54,16 +47,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { - if userID != activity.SystemInitiator { - ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !ok { - return nil, status.NewPermissionDeniedError() - } - } - extraSettings, err := m.extraSettingsManager.GetExtraSettings(ctx, accountID) if err != nil { return nil, fmt.Errorf("get extra settings: %w", err) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 8d05098717f..bde39cff2ea 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -8,8 +8,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" @@ -56,19 +54,12 @@ type SetupKeyUpdateOperation struct { func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var setupKey *types.SetupKey var plainKey string var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } @@ -105,19 +96,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - var oldKey *types.SetupKey var newKey *types.SetupKey var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } @@ -162,27 +146,11 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID) if err != nil { return nil, err @@ -198,17 +166,10 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - var deletedSetupKey *types.SetupKey - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyID) if err != nil { return err diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0ff57b75219..ee180c8ac4d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -943,8 +943,8 @@ func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength Locki } // DeleteUserInvite deletes a user invite by its ID -func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error { - result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID) +func (s *SqlStore) DeleteUserInvite(ctx context.Context, accountID, inviteID string) error { + result := s.db.Delete(&types.UserInviteRecord{}, accountAndIDQueryCondition, accountID, inviteID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete user invite from store") @@ -2917,6 +2917,10 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, } } + if err = copyZonesAndRecords(sqliteStore, store); err != nil { + return nil, err + } + return store, nil } @@ -2983,9 +2987,28 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } } + if err = copyZonesAndRecords(sqliteStore, store); err != nil { + return nil, err + } + return store, nil } +func copyZonesAndRecords(src, dst *SqlStore) error { + var srcZones []*zones.Zone + if err := src.db.Preload("Records").Find(&srcZones).Error; err != nil { + return fmt.Errorf("failed to read zones from source store: %w", err) + } + + for _, zone := range srcZones { + if err := dst.db.Create(zone).Error; err != nil { + return fmt.Errorf("failed to copy zone %s: %w", zone.ID, err) + } + } + + return nil +} + func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/sql_store_user_invite_test.go b/management/server/store/sql_store_user_invite_test.go index fb6934a2e36..206f7214a7a 100644 --- a/management/server/store/sql_store_user_invite_test.go +++ b/management/server/store/sql_store_user_invite_test.go @@ -298,7 +298,7 @@ func TestSqlStore_DeleteUserInvite(t *testing.T) { require.NoError(t, err) // Delete the invite - err = store.DeleteUserInvite(ctx, invite.ID) + err = store.DeleteUserInvite(ctx, invite.AccountID, invite.ID) require.NoError(t, err) // Verify invite is deleted @@ -346,7 +346,7 @@ func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) { ctx := context.Background() // Deleting a non-existent invite should not return an error - err := store.DeleteUserInvite(ctx, "non-existent-invite-id") + err := store.DeleteUserInvite(ctx, "non-existent-account", "non-existent-invite-id") require.NoError(t, err) }) } diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a99..0207ea55572 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -103,7 +103,7 @@ type Store interface { GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) - DeleteUserInvite(ctx context.Context, inviteID string) error + DeleteUserInvite(ctx context.Context, accountID, inviteID string) error GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d9631..976e3088d8b 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) } + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -673,17 +674,17 @@ func (mr *MockStoreMockRecorder) DeleteUser(ctx, accountID, userID interface{}) } // DeleteUserInvite mocks base method. -func (m *MockStore) DeleteUserInvite(ctx context.Context, inviteID string) error { +func (m *MockStore) DeleteUserInvite(ctx context.Context, accountID, inviteID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUserInvite", ctx, inviteID) + ret := m.ctrl.Call(m, "DeleteUserInvite", ctx, accountID, inviteID) ret0, _ := ret[0].(error) return ret0 } // DeleteUserInvite indicates an expected call of DeleteUserInvite. -func (mr *MockStoreMockRecorder) DeleteUserInvite(ctx, inviteID interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteUserInvite(ctx, accountID, inviteID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserInvite", reflect.TypeOf((*MockStore)(nil).DeleteUserInvite), ctx, inviteID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserInvite", reflect.TypeOf((*MockStore)(nil).DeleteUserInvite), ctx, accountID, inviteID) } // DeleteZone mocks base method. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 00ba29b7f49..d652b2d567b 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -85,9 +85,9 @@ func setupTestAccount() *Account { }, Groups: map[string]*Group{ "groupAll": { - ID: "groupAll", - Name: "All", - Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, + ID: "groupAll", + Name: "All", + Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, Issued: GroupIssuedAPI, }, "group1": { diff --git a/management/server/user.go b/management/server/user.go index c1f984f2fcf..c971b699e86 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -18,8 +18,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -29,14 +27,6 @@ import ( // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - if role == types.UserRoleOwner { return nil, status.NewServiceUserRoleInvalidError() } @@ -46,7 +36,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) - if err = am.Store.SaveUser(ctx, newUser); err != nil { + if err := am.Store.SaveUser(ctx, newUser); err != nil { return nil, err } @@ -84,14 +74,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Users, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err @@ -270,12 +252,21 @@ func (am *DefaultAccountManager) UpdateUserPassword(ctx context.Context, account return status.Errorf(status.InvalidArgument, "new password is required") } + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return err + } + + if targetUser.AccountID != accountID { + return status.NewUserNotFoundError(targetUserID) + } + embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager) if !ok { return status.Errorf(status.Internal, "failed to get embedded IdP manager") } - err := embeddedIdp.UpdateUserPassword(ctx, currentUserID, targetUserID, oldPassword, newPassword) + err = embeddedIdp.UpdateUserPassword(ctx, currentUserID, targetUserID, oldPassword, newPassword) if err != nil { return status.Errorf(status.InvalidArgument, "failed to update password: %v", err) } @@ -305,19 +296,15 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return err } + if targetUser.AccountID != accountID { + return status.NewUserNotFoundError(targetUserID) + } + if targetUser.Role == types.UserRoleOwner { return status.NewOwnerDeletePermissionError() } @@ -355,14 +342,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - // check if the user is already registered with this ID user, err := am.lookupUserInCache(ctx, targetUserID, accountID) if err != nil { @@ -399,14 +378,6 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err @@ -443,14 +414,6 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err @@ -486,14 +449,6 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err @@ -517,14 +472,6 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err @@ -574,13 +521,6 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, nil //nolint:nilnil } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) // TODO: split by Create and Update - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -977,12 +917,8 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, initiatorUserID string) (map[string]*types.UserInfo, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - var user *types.User + var err error if initiatorUserID != activity.SystemInitiator { result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { @@ -991,9 +927,15 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun user = result } + // Permission checks are now handled by the HTTP middleware via WithPermission wrapper + // This internal method is called from authenticated/authorized handlers accountUsers := []*types.User{} + + // Determine if user has full access based on their role + hasFullAccess := user.HasAdminPower() || user.IsServiceUser + switch { - case allowed: + case hasFullAccess: start := time.Now() accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1194,14 +1136,6 @@ func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUs // If an error occurs while deleting the user, the function skips it and continues deleting other users. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err @@ -1392,9 +1326,8 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return nil, status.NewPermissionDeniedError() } - if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { - return nil, err - } + // Permission checks are now handled by the HTTP middleware via WithPermission wrapper + // User account association is already validated above by GetUserByUserID settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1421,14 +1354,6 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut // ApproveUser approves a user that is pending approval func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return nil, err @@ -1462,14 +1387,6 @@ func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, ini // RejectUser rejects a user that is pending approval by deleting them func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return err @@ -1508,14 +1425,6 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID return nil, err } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - // Check if user already exists in NetBird DB existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1626,14 +1535,6 @@ func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1716,7 +1617,7 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa if err := transaction.SaveUser(ctx, newUser); err != nil { return fmt.Errorf("failed to save user: %w", err) } - if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil { + if err := transaction.DeleteUserInvite(ctx, invite.AccountID, invite.ID); err != nil { return fmt.Errorf("failed to delete invite: %w", err) } return nil @@ -1740,14 +1641,6 @@ func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accou return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - // Get existing invite existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) if err != nil { @@ -1802,20 +1695,12 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) if err != nil { return err } - if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil { + if err := am.Store.DeleteUserInvite(ctx, accountID, inviteID); err != nil { return err } diff --git a/management/server/user_invite_test.go b/management/server/user_invite_test.go index 6256ed44a29..e3a66522123 100644 --- a/management/server/user_invite_test.go +++ b/management/server/user_invite_test.go @@ -9,9 +9,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" @@ -163,22 +163,6 @@ func TestCreateUserInvite_ExistingUserEmail(t *testing.T) { assert.Equal(t, status.UserAlreadyExists, sErr.Type()) } -func TestCreateUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - // Regular user should not be able to create invites - _, err := am.CreateUserInvite(context.Background(), testAccountID, testRegularUserID, invite, 0) - require.Error(t, err) -} - func TestCreateUserInvite_InvalidEmail(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -412,14 +396,6 @@ func TestListUserInvites_Empty(t *testing.T) { assert.Len(t, invites, 0) } -func TestListUserInvites_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - _, err := am.ListUserInvites(context.Background(), testAccountID, testRegularUserID) - require.Error(t, err) -} - func TestRegenerateUserInvite_Success(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -470,26 +446,6 @@ func TestRegenerateUserInvite_NotFound(t *testing.T) { assert.Equal(t, status.NotFound, sErr.Type()) } -func TestRegenerateUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - // Create an invite first - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) - require.NoError(t, err) - - // Regular user should not be able to regenerate - _, err = am.RegenerateUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID, 0) - require.Error(t, err) -} - func TestDeleteUserInvite_Success(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() @@ -531,26 +487,6 @@ func TestDeleteUserInvite_NotFound(t *testing.T) { assert.Equal(t, status.NotFound, sErr.Type()) } -func TestDeleteUserInvite_PermissionDenied(t *testing.T) { - am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) - defer cleanup() - - // Create an invite first - invite := &types.UserInfo{ - Email: "newuser@test.com", - Name: "New User", - Role: "user", - AutoGroups: []string{}, - } - - result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) - require.NoError(t, err) - - // Regular user should not be able to delete - err = am.DeleteUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID) - require.Error(t, err) -} - func TestDeleteUserInvite_WrongAccount(t *testing.T) { am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) defer cleanup() diff --git a/management/server/user_test.go b/management/server/user_test.go index 8fdfbd6339b..4f4fc9fddda 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -13,10 +13,10 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles" nbcache "github.com/netbirdio/netbird/management/server/cache" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/shared/auth" @@ -775,6 +775,52 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } } +func TestUser_DeleteUser_CrossAccountRejected(t *testing.T) { + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + // Create two accounts with users + account1 := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) + targetServiceUser := &types.User{ + Id: mockServiceUserID, + IsServiceUser: true, + ServiceUserName: mockServiceUserName, + } + account1.Users[mockServiceUserID] = targetServiceUser + + otherAccountID := "otherAccountID" + otherUserID := "otherUserID" + account2 := newAccountWithId(context.Background(), otherAccountID, otherUserID, "", "", "", false) + + err = s.SaveAccount(context.Background(), account1) + if err != nil { + t.Fatalf("Error when saving account1: %s", err) + } + err = s.SaveAccount(context.Background(), account2) + if err != nil { + t.Fatalf("Error when saving account2: %s", err) + } + + permissionsManager := permissions.NewManager(s) + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + } + + // otherUserID (from account2) tries to delete mockServiceUserID (from account1) + err = am.DeleteUser(context.Background(), otherAccountID, otherUserID, mockServiceUserID) + assert.Error(t, err, "cross-account user deletion should be rejected") + + // Verify the target user still exists + account, err := s.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err) + assert.NotNil(t, account.Users[mockServiceUserID], "target user should not have been deleted") +} + func TestUser_DeleteUser_SelfDelete(t *testing.T) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) if err != nil { @@ -1426,13 +1472,14 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: ownerUserID, Domain: "netbird.io"}) + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: tc.name, Domain: "netbird.io"}) if err != nil { t.Fatal(err) } // create other users account.Users[regularUserID] = types.NewRegularUser(regularUserID, "", "") + account.Users[ownerUserID] = types.NewOwnerUser(ownerUserID, "", "") account.Users[adminUserID] = types.NewAdminUser(adminUserID) account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) @@ -1705,11 +1752,6 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, - { - name: "not part of account", - userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, - expectedErr: status.NewUserNotPartOfAccountError(), - }, { name: "blocked", userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, @@ -1737,7 +1779,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.Owner), + Permissions: mergeRolePermissions(roles2.Owner), }, }, { @@ -1756,7 +1798,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.User), + Permissions: mergeRolePermissions(roles2.User), }, }, { @@ -1775,7 +1817,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.Admin), + Permissions: mergeRolePermissions(roles2.Admin), }, }, { @@ -1794,7 +1836,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.User), + Permissions: mergeRolePermissions(roles2.User), Restricted: true, }, }, @@ -1815,7 +1857,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.User), + Permissions: mergeRolePermissions(roles2.User), Restricted: false, }, }, @@ -1836,7 +1878,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, }, - Permissions: mergeRolePermissions(roles.Owner), + Permissions: mergeRolePermissions(roles2.Owner), }, }, } @@ -1846,7 +1888,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth) if tc.expectedErr != nil { - assert.Equal(t, err, tc.expectedErr) + assert.Equal(t, tc.expectedErr, err) return } @@ -1856,8 +1898,8 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { } } -func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { - permissions := roles.Permissions{} +func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions { + permissions := roles2.Permissions{} for k := range modules.All { if rolePermissions, ok := role.Permissions[k]; ok { @@ -1911,22 +1953,6 @@ func TestApproveUser(t *testing.T) { _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) require.Error(t, err) assert.Contains(t, err.Error(), "not pending approval") - - // Test approval by non-admin should fail - regularUser := types.NewRegularUser("regular-user", "", "") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - pendingUser2 := types.NewRegularUser("pending-user-2", "", "") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) } func TestRejectUser(t *testing.T) { @@ -1971,17 +1997,6 @@ func TestRejectUser(t *testing.T) { err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) require.Error(t, err) assert.Contains(t, err.Error(), "not pending approval") - - // Test rejection by non-admin should fail - pendingUser2 := types.NewRegularUser("pending-user-2", "", "") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) } func TestUser_Operations_WithEmbeddedIDP(t *testing.T) { diff --git a/management/server/users/user.go b/management/server/users/user.go index 2f278827184..e966a036574 100644 --- a/management/server/users/user.go +++ b/management/server/users/user.go @@ -1,7 +1,7 @@ package users import ( - "github.com/netbirdio/netbird/management/server/permissions/roles" + "github.com/netbirdio/netbird/management/internals/modules/permissions/roles" "github.com/netbirdio/netbird/management/server/types" ) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d9a1a7d6525..84363e6588a 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,7 +18,9 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/permissions" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -35,7 +37,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -92,7 +93,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(true, nil). AnyTimes() - peersManger := peers.NewManager(store, permissionsManagerMock) + peersManger := peers.NewManager(store) settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManger)