diff --git a/flask_appbuilder/menu.py b/flask_appbuilder/menu.py index d133edc25a..b45f5e18fe 100644 --- a/flask_appbuilder/menu.py +++ b/flask_appbuilder/menu.py @@ -1,3 +1,5 @@ +from typing import List + from flask import current_app, url_for from .api import BaseApi, expose @@ -6,22 +8,12 @@ class MenuItem(object): - name = "" - href = "" - icon = "" - label = "" - baseview = None - childs = [] - def __init__(self, name, href="", icon="", label="", childs=None, baseview=None): self.name = name self.href = href self.icon = icon self.label = label - if self.childs: - self.childs = childs - else: - self.childs = [] + self.childs = childs or [] self.baseview = baseview def get_url(self): @@ -29,9 +21,7 @@ def get_url(self): if not self.baseview: return "" else: - return url_for( - "{}.{}".format(self.baseview.endpoint, self.baseview.default_view) - ) + return url_for(f"{self.baseview.endpoint}.{self.baseview.default_view}") else: try: return url_for(self.href) @@ -43,8 +33,6 @@ def __repr__(self): class Menu(object): - menu = None - def __init__(self, reverse=True, extra_classes=""): self.menu = [] if reverse: @@ -58,14 +46,27 @@ def reverse(self): def get_list(self): return self.menu + def get_flat_name_list(self, menu: "Menu" = None, result: List = None) -> List: + menu = menu or self.menu + result = result or [] + for item in menu: + result.append(item.name) + if item.childs: + result.extend(self.get_flat_name_list(menu=item.childs, result=result)) + return result + def get_data(self, menu=None): menu = menu or self.menu ret_list = [] + allowed_menus = current_app.appbuilder.sm.get_user_menu_access( + self.get_flat_name_list() + ) + for i, item in enumerate(menu): if item.name == '-' and not i == len(menu) - 1: ret_list.append('-') - elif not current_app.appbuilder.sm.has_access("menu_access", item.name): + elif item.name not in allowed_menus: continue elif item.childs: ret_list.append({ diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 3d27a9975a..e208cf4e9b 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -3,7 +3,7 @@ import json import logging import re -from typing import Dict, List +from typing import Dict, List, Set from flask import g, session, url_for from flask_babel import lazy_gettext as _ @@ -419,6 +419,13 @@ def openid_providers(self): def oauth_providers(self): return self.appbuilder.get_app.config["OAUTH_PROVIDERS"] + @property + def current_user(self): + if current_user.is_authenticated: + return g.user + elif current_user_jwt: + return current_user_jwt + def oauth_user_info_getter(self, f): """ Decorator function to be the OAuth user info getter @@ -1085,6 +1092,45 @@ def _has_view_access( db_role_ids, ) + def _get_user_permission_view_menus( + self, + user: object, + permission_name: str, + view_menus_name: List[str] + ) -> Set[str]: + """ + Return a set of view menu names with a certain permission name + that a user has access to. Mainly used to fetch all menu permissions + on a single db call, will also check public permissions and builtin roles + """ + db_role_ids = list() + if user is None: + # include public role + roles = [self.get_public_role()] + else: + roles = user.roles + # First check against builtin (statically configured) roles + # because no database query is needed + result = set() + for role in roles: + if role.name in self.builtin_roles: + for view_menu_name in view_menus_name: + if self._has_access_builtin_roles( + role, + permission_name, + view_menu_name + ): + result.add(view_menu_name) + else: + db_role_ids.append(role.id) + # Then check against database-stored roles + pvms_names = [ + pvm.view_menu.name + for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids) + ] + result.update(pvms_names) + return result + def has_access(self, permission_name, view_name): """ Check if current user or public has access to view or menu @@ -1096,6 +1142,17 @@ def has_access(self, permission_name, view_name): else: return self.is_item_public(permission_name, view_name) + def get_user_menu_access(self, menu_names: List[str] = None) -> Set[str]: + if current_user.is_authenticated: + return self._get_user_permission_view_menus( + g.user, "menu_access", view_menus_name=menu_names) + elif current_user_jwt: + return self._get_user_permission_view_menus( + current_user_jwt, "menu_access", view_menus_name=menu_names) + else: + return self._get_user_permission_view_menus( + None, "menu_access", view_menus_name=menu_names) + def add_permissions_view(self, base_permissions, view_menu): """ Adds a permission on a view menu to the backend @@ -1441,6 +1498,12 @@ def get_all_roles(self): ---------------------------- """ + def get_public_role(self): + """ + returns all permissions from public role + """ + raise NotImplementedError + def get_public_permissions(self): """ returns all permissions from public role @@ -1453,6 +1516,13 @@ def find_permission(self, name): """ raise NotImplementedError + def find_roles_permission_view_menus( + self, + permission_name: str, + role_ids: List[int], + ): + raise NotImplementedError + def exist_permission_on_roles( self, view_name: str, diff --git a/flask_appbuilder/security/sqla/manager.py b/flask_appbuilder/security/sqla/manager.py index f329342a9d..7982863120 100644 --- a/flask_appbuilder/security/sqla/manager.py +++ b/flask_appbuilder/security/sqla/manager.py @@ -238,7 +238,6 @@ def add_role(self, name: str) -> Optional[Role]: def update_role(self, pk, name: str) -> Optional[Role]: role = self.get_session.query(self.role_model).get(pk) - print(f"Update role {role} {pk}") if not role: return try: @@ -257,12 +256,15 @@ def find_role(self, name): def get_all_roles(self): return self.get_session.query(self.role_model).all() - def get_public_permissions(self): - role = ( + def get_public_role(self): + return ( self.get_session.query(self.role_model) .filter_by(name=self.auth_role_public) .first() ) + + def get_public_permissions(self): + role = self.get_public_role() if role: return role.permissions return [] @@ -314,6 +316,24 @@ def exist_permission_on_roles( return self.appbuilder.get_session.query(literal(True)).filter(q).scalar() return self.appbuilder.get_session.query(q).scalar() + def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]): + return ( + self.appbuilder.get_session.query(self.permissionview_model) + .join( + assoc_permissionview_role, + and_( + (self.permissionview_model.id == + assoc_permissionview_role.c.permission_view_id), + ), + ) + .join(self.role_model) + .join(self.permission_model) + .join(self.viewmenu_model) + .filter( + self.permission_model.name == permission_name, + self.role_model.id.in_(role_ids)) + ).all() + def add_permission(self, name): """ Adds a permission to the backend, model permission diff --git a/flask_appbuilder/tests/test_menu.py b/flask_appbuilder/tests/test_menu.py index f9c6d3f8ea..5a3945b77c 100644 --- a/flask_appbuilder/tests/test_menu.py +++ b/flask_appbuilder/tests/test_menu.py @@ -5,14 +5,10 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from .base import FABTestCase +from .const import MAX_PAGE_SIZE, PASSWORD_ADMIN, USERNAME_ADMIN from .sqla.models import Model1 -log = logging.getLogger(__name__) - -DEFAULT_ADMIN_USER = "admin" -DEFAULT_ADMIN_PASSWORD = "general" -LIMITED_USER = "user1" -LIMITED_USER_PASSWORD = "user1" +log = logging.getLogger(__name__) class FlaskTestCase(FABTestCase): @@ -23,13 +19,8 @@ def setUp(self): self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) - self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" - self.app.config["CSRF_ENABLED"] = False - self.app.config["SECRET_KEY"] = "thisismyscretkey" - self.app.config["WTF_CSRF_ENABLED"] = False - self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - - logging.basicConfig(level=logging.ERROR) + self.app.config.from_object("flask_appbuilder.tests.config_api") + self.app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) @@ -38,35 +29,6 @@ class Model1View(ModelView): datamodel = SQLAInterface(Model1) self.appbuilder.add_view(Model1View, "Model1") - role_admin = self.appbuilder.sm.find_role("Admin") - self.appbuilder.sm.add_user( - DEFAULT_ADMIN_USER, - "admin", - "user", - "admin@fab.org", - role_admin, - DEFAULT_ADMIN_PASSWORD - ) - - role_limited = self.appbuilder.sm.add_role("LimitedUser") - pvm = self.appbuilder.sm.find_permission_view_menu( - "menu_access", - "Model1" - ) - self.appbuilder.sm.add_permission_role(role_limited, pvm) - pvm = self.appbuilder.sm.find_permission_view_menu( - "can_get", - "MenuApi" - ) - self.appbuilder.sm.add_permission_role(role_limited, pvm) - self.appbuilder.sm.add_user( - LIMITED_USER, - "user1", - "user1", - "user1@fab.org", - role_limited, - LIMITED_USER_PASSWORD - ) def tearDown(self): self.appbuilder = None @@ -74,31 +36,87 @@ def tearDown(self): self.db = None log.debug("TEAR DOWN") - def test_menu_api(self): + def test_menu_access_denied(self): """ - REST Api: Test menu data + REST Api: Test menu logged out access denied + :return: """ - uri = '/api/v1/menu/' + uri = "/api/v1/menu/" client = self.app.test_client() - # as loged out user + # as logged out user rv = client.get(uri) self.assertEqual(rv.status_code, 401) + def test_menu_api(self): + """ + REST Api: Test menu data + """ + uri = "/api/v1/menu/" + client = self.app.test_client() + + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = rv.data.decode("utf-8") + self.assertIn("Security", data) + self.assertIn("Model1", data) + + def test_menu_api_limited(self): + """ + REST Api: Test limited menu data + """ + limited_user = "user1" + limited_password = "user1" + limited_role = "Limited" + + role = self.appbuilder.sm.add_role(limited_role) + pvm = self.appbuilder.sm.find_permission_view_menu("menu_access", "Model1") + self.appbuilder.sm.add_permission_role(role, pvm) + pvm = self.appbuilder.sm.find_permission_view_menu("can_get", "MenuApi") + self.appbuilder.sm.add_permission_role(role, pvm) + self.appbuilder.sm.add_user( + limited_user, "user1", "user1", "user1@fab.org", role, limited_password + ) + + uri = "/api/v1/menu/" + client = self.app.test_client() # as limited user - token = self.login(client, LIMITED_USER, LIMITED_USER_PASSWORD) + token = self.login(client, limited_user, limited_password) rv = self.auth_client_get(client, token, uri) self.assertEqual(rv.status_code, 200) - data = rv.data.decode('utf-8') + data = rv.data.decode("utf-8") self.assertNotIn("Security", data) self.assertIn("Model1", data) self.browser_logout(client) - # as admin - token = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) - rv = self.auth_client_get(client, token, uri) + # Revert test data + self.appbuilder.get_session.delete( + self.appbuilder.sm.find_user(username=limited_user) + ) + self.appbuilder.get_session.delete(self.appbuilder.sm.find_role(limited_role)) + self.appbuilder.get_session.commit() + + def test_menu_api_public(self): + """ + REST Api: Test public menu data + """ + role = self.appbuilder.sm.find_role("Public") + pvm = self.appbuilder.sm.find_permission_view_menu("menu_access", "Model1") + self.appbuilder.sm.add_permission_role(role, pvm) + pvm = self.appbuilder.sm.find_permission_view_menu("can_get", "MenuApi") + self.appbuilder.sm.add_permission_role(role, pvm) + + uri = "/api/v1/menu/" + client = self.app.test_client() + # as limited user + rv = client.get(uri) self.assertEqual(rv.status_code, 200) - data = rv.data.decode('utf-8') - self.assertIn("Security", data) + data = rv.data.decode("utf-8") self.assertIn("Model1", data) + + # Revert test data + role = self.appbuilder.sm.find_role("Public") + role.permissions = [] + self.appbuilder.get_session.commit()