Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions flask_appbuilder/menu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from flask import current_app, url_for

from .api import BaseApi, expose
Expand All @@ -6,32 +8,20 @@


class MenuItem(object):
name = ""
href = ""
icon = ""
label = ""
baseview = None
childs = []

def __init__(self, name, href="", icon="", label="", childs=None, baseview=None):
self.name = name
self.href = href
self.icon = icon
self.label = label
if self.childs:
self.childs = childs
else:
self.childs = []
self.childs = childs or []
self.baseview = baseview

def get_url(self):
if not self.href:
if not self.baseview:
return ""
else:
return url_for(
"{}.{}".format(self.baseview.endpoint, self.baseview.default_view)
)
return url_for(f"{self.baseview.endpoint}.{self.baseview.default_view}")
else:
try:
return url_for(self.href)
Expand All @@ -43,8 +33,6 @@ def __repr__(self):


class Menu(object):
menu = None

def __init__(self, reverse=True, extra_classes=""):
self.menu = []
if reverse:
Expand All @@ -58,14 +46,27 @@ def reverse(self):
def get_list(self):
return self.menu

def get_flat_name_list(self, menu: "Menu" = None, result: List = None) -> List:
menu = menu or self.menu
result = result or []
for item in menu:
result.append(item.name)
if item.childs:
result.extend(self.get_flat_name_list(menu=item.childs, result=result))
return result

def get_data(self, menu=None):
menu = menu or self.menu
ret_list = []

allowed_menus = current_app.appbuilder.sm.get_user_menu_access(
self.get_flat_name_list()
)

for i, item in enumerate(menu):
if item.name == '-' and not i == len(menu) - 1:
ret_list.append('-')
elif not current_app.appbuilder.sm.has_access("menu_access", item.name):
elif item.name not in allowed_menus:
continue
elif item.childs:
ret_list.append({
Expand Down
72 changes: 71 additions & 1 deletion flask_appbuilder/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import re
from typing import Dict, List
from typing import Dict, List, Set

from flask import g, session, url_for
from flask_babel import lazy_gettext as _
Expand Down Expand Up @@ -419,6 +419,13 @@ def openid_providers(self):
def oauth_providers(self):
return self.appbuilder.get_app.config["OAUTH_PROVIDERS"]

@property
def current_user(self):
if current_user.is_authenticated:
return g.user
elif current_user_jwt:
return current_user_jwt

def oauth_user_info_getter(self, f):
"""
Decorator function to be the OAuth user info getter
Expand Down Expand Up @@ -1085,6 +1092,45 @@ def _has_view_access(
db_role_ids,
)

def _get_user_permission_view_menus(
self,
user: object,
permission_name: str,
view_menus_name: List[str]
) -> Set[str]:
"""
Return a set of view menu names with a certain permission name
that a user has access to. Mainly used to fetch all menu permissions
on a single db call, will also check public permissions and builtin roles
"""
db_role_ids = list()
if user is None:
# include public role
roles = [self.get_public_role()]
else:
roles = user.roles
# First check against builtin (statically configured) roles
# because no database query is needed
result = set()
for role in roles:
if role.name in self.builtin_roles:
for view_menu_name in view_menus_name:
if self._has_access_builtin_roles(
role,
permission_name,
view_menu_name
):
result.add(view_menu_name)
else:
db_role_ids.append(role.id)
# Then check against database-stored roles
pvms_names = [
pvm.view_menu.name
for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids)
]
result.update(pvms_names)
return result

def has_access(self, permission_name, view_name):
"""
Check if current user or public has access to view or menu
Expand All @@ -1096,6 +1142,17 @@ def has_access(self, permission_name, view_name):
else:
return self.is_item_public(permission_name, view_name)

def get_user_menu_access(self, menu_names: List[str] = None) -> Set[str]:
if current_user.is_authenticated:
return self._get_user_permission_view_menus(
g.user, "menu_access", view_menus_name=menu_names)
elif current_user_jwt:
return self._get_user_permission_view_menus(
current_user_jwt, "menu_access", view_menus_name=menu_names)
else:
return self._get_user_permission_view_menus(
None, "menu_access", view_menus_name=menu_names)

def add_permissions_view(self, base_permissions, view_menu):
"""
Adds a permission on a view menu to the backend
Expand Down Expand Up @@ -1441,6 +1498,12 @@ def get_all_roles(self):
----------------------------
"""

def get_public_role(self):
"""
returns all permissions from public role
"""
raise NotImplementedError

def get_public_permissions(self):
"""
returns all permissions from public role
Expand All @@ -1453,6 +1516,13 @@ def find_permission(self, name):
"""
raise NotImplementedError

def find_roles_permission_view_menus(
self,
permission_name: str,
role_ids: List[int],
):
raise NotImplementedError

def exist_permission_on_roles(
self,
view_name: str,
Expand Down
26 changes: 23 additions & 3 deletions flask_appbuilder/security/sqla/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def add_role(self, name: str) -> Optional[Role]:

def update_role(self, pk, name: str) -> Optional[Role]:
role = self.get_session.query(self.role_model).get(pk)
print(f"Update role {role} {pk}")
if not role:
return
try:
Expand All @@ -257,12 +256,15 @@ def find_role(self, name):
def get_all_roles(self):
return self.get_session.query(self.role_model).all()

def get_public_permissions(self):
role = (
def get_public_role(self):
return (
self.get_session.query(self.role_model)
.filter_by(name=self.auth_role_public)
.first()
)

def get_public_permissions(self):
role = self.get_public_role()
if role:
return role.permissions
return []
Expand Down Expand Up @@ -314,6 +316,24 @@ def exist_permission_on_roles(
return self.appbuilder.get_session.query(literal(True)).filter(q).scalar()
return self.appbuilder.get_session.query(q).scalar()

def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]):
return (
self.appbuilder.get_session.query(self.permissionview_model)
.join(
assoc_permissionview_role,
and_(
(self.permissionview_model.id ==
assoc_permissionview_role.c.permission_view_id),
),
)
.join(self.role_model)
.join(self.permission_model)
.join(self.viewmenu_model)
.filter(
self.permission_model.name == permission_name,
self.role_model.id.in_(role_ids))
).all()

def add_permission(self, name):
"""
Adds a permission to the backend, model permission
Expand Down
Loading