diff --git a/examples/crud_rest_api/app/api.py b/examples/crud_rest_api/app/api.py index bb5076d4b6..afbe49cf3c 100644 --- a/examples/crud_rest_api/app/api.py +++ b/examples/crud_rest_api/app/api.py @@ -1,9 +1,11 @@ from flask_appbuilder import ModelRestApi from flask_appbuilder.api import BaseApi, expose from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_appbuilder.models.filters import BaseFilter +from sqlalchemy import or_ from . import appbuilder, db -from .models import Contact, ContactGroup, Gender, ModelOMChild, ModelOMParent +from .models import Contact, ContactGroup, Gender, ModelOMParent def fill_gender(): @@ -52,11 +54,25 @@ def greeting(self): appbuilder.add_api(GreetingApi) +class CustomFilter(BaseFilter): + name = "Custom Filter" + arg_name = "opr" + + def apply(self, query, value): + return query.filter( + or_( + Contact.name.like(value + "%"), + Contact.address.like(value + "%"), + ) + ) + + class ContactModelApi(ModelRestApi): resource_name = "contact" datamodel = SQLAInterface(Contact) allow_browser_login = True + search_filters = {"name": [CustomFilter]} openapi_spec_methods = { "get_list": { "get": { diff --git a/flask_appbuilder/api/__init__.py b/flask_appbuilder/api/__init__.py index 61eee912a3..20589c4901 100644 --- a/flask_appbuilder/api/__init__.py +++ b/flask_appbuilder/api/__init__.py @@ -750,6 +750,10 @@ class MyView(ModelRestApi): search_columns = ['name', 'address'] """ + search_filters = None + """ + Override default search filters for columns + """ search_exclude_columns = None """ List with columns to exclude from search. Search includes all possible @@ -846,7 +850,6 @@ def _init_properties(self): x for x in search_columns if x not in self.search_exclude_columns ] self._gen_labels_columns(self.datamodel.get_columns_list()) - self._filters = self.datamodel.get_filters(self.search_columns) def _init_titles(self): pass @@ -1114,7 +1117,9 @@ def _init_properties(self): ] self._gen_labels_columns(self.list_columns) self._gen_labels_columns(self.show_columns) - self._filters = self.datamodel.get_filters(self.search_columns) + self._filters = self.datamodel.get_filters( + search_columns=self.search_columns, search_filters=self.search_filters + ) self.edit_query_rel_fields = self.edit_query_rel_fields or dict() self.add_query_rel_fields = self.add_query_rel_fields or dict() diff --git a/flask_appbuilder/models/base.py b/flask_appbuilder/models/base.py index ef90cd838d..1e4c19ec2b 100644 --- a/flask_appbuilder/models/base.py +++ b/flask_appbuilder/models/base.py @@ -83,9 +83,14 @@ def _get_attr_value(item, col): return value.value return value - def get_filters(self, search_columns=None): + def get_filters(self, search_columns=None, search_filters=None): search_columns = search_columns or [] - return Filters(self.filter_converter_class, self, search_columns) + return Filters( + self.filter_converter_class, + self, + search_columns=search_columns, + search_filters=search_filters, + ) def get_values_item(self, item, show_columns): return [self._get_attr_value(item, col) for col in show_columns] diff --git a/flask_appbuilder/models/filters.py b/flask_appbuilder/models/filters.py index 140d18c3c5..36579add5a 100644 --- a/flask_appbuilder/models/filters.py +++ b/flask_appbuilder/models/filters.py @@ -1,5 +1,6 @@ import copy import logging +from typing import Any, Dict, List, Tuple from .._compat import as_unicode from ..exceptions import ( @@ -69,10 +70,14 @@ class FilterRelation(BaseFilter): Base class for all filters for relations """ - pass + def apply(self, query, value): + """ + Override this to implement your own new filters + """ + raise NotImplementedError -class BaseFilterConverter(object): +class BaseFilterConverter: """ Base Filter Converter, all classes responsible for the association of columns and possible filters @@ -113,20 +118,27 @@ def convert(self, col_name): class Filters(object): - filters = [] - """ List of instanciated BaseFilter classes """ - values = [] + filters: List[BaseFilter] = [] + """ List of instantiated BaseFilter classes """ + values: List[Any] = [] """ list of values to apply to filters """ - _search_filters = {} + _search_filters: Dict[str, List[BaseFilter]] = {} """ dict like {'col_name':[BaseFilter1, BaseFilter2, ...], ... } """ - _all_filters = {} - - def __init__(self, filter_converter, datamodel, search_columns=None): + _all_filters: Dict[str, List[BaseFilter]] = {} + + def __init__( + self, + filter_converter: BaseFilterConverter, + datamodel, + search_columns: List[str] = None, + search_filters: Dict[str, List[BaseFilter]] = None, + ): """ :param filter_converter: Accepts BaseFilterConverter class :param search_columns: restricts possible columns, accepts a list of column names + :param search_filters: Add custom defined filters to specific columns :param datamodel: Accepts BaseInterface class """ self.search_columns = search_columns or [] @@ -137,10 +149,14 @@ def __init__(self, filter_converter, datamodel, search_columns=None): self._search_filters = self._get_filters(self.search_columns) self._all_filters = self._get_filters(datamodel.get_columns_list()) + if search_filters: + for k, v in search_filters.items(): + self._search_filters[k] += v + def get_search_filters(self): return self._search_filters - def _get_filters(self, cols): + def _get_filters(self, cols: List[str]): filters = {} for col in cols: _filters = self.filter_converter(self.datamodel).convert(col) @@ -156,10 +172,12 @@ def _add_filter(self, filter_instance, value): self.filters.append(filter_instance) self.values.append(value) - def add_filter_index(self, column_name, filter_instance_index, value): + def add_filter_index( + self, column_name: str, filter_instance_index: int, value: Any + ): self._add_filter(self._all_filters[column_name][filter_instance_index], value) - def rest_add_filters(self, data): + def rest_add_filters(self, data: List[Dict]) -> None: """ Adds list of dicts @@ -174,9 +192,10 @@ def rest_add_filters(self, data): except KeyError: log.warning("Invalid filter") return + # Get filter class from defaults filter_class = map_args_filter.get(opr, None) if filter_class: - if _filter["col"] not in self.search_columns: + if col not in self.search_columns: raise InvalidColumnFilterFABException( f"Filter column: {col} not allowed to filter" ) @@ -184,8 +203,15 @@ def rest_add_filters(self, data): raise InvalidOperationFilterFABException( f"Filter operation: {opr} not allowed on column: {col}" ) - else: - self.add_filter(col, filter_class, value) + self.add_filter(col, filter_class, value) + continue + # Get filter class from custom defined filters + filters = self._search_filters.get(col) + if filters: + for filter in filters: + if filter.arg_name == opr: + self.add_filter(col, filter, value) + break else: raise InvalidOperationFilterFABException( f"Filter operation: {opr} not allowed on column: {col}" @@ -215,10 +241,10 @@ def get_joined_filters(self, filters): """ Creates a new filters class with active filters joined """ - retfilters = Filters(self.filter_converter, self.datamodel) - retfilters.filters = self.filters + filters.filters - retfilters.values = self.values + filters.values - return retfilters + ret_filters = Filters(self.filter_converter, self.datamodel) + ret_filters.filters = self.filters + filters.filters + ret_filters.values = self.values + filters.values + return ret_filters def copy(self): """ @@ -241,13 +267,13 @@ def get_relation_cols(self): retlst.append(flt.column_name) return retlst - def get_filters_values(self): + def get_filters_values(self) -> List[Tuple[BaseFilter, Any]]: """ Returns a list of tuples [(FILTER, value),(...,...),....] """ return [(flt, value) for flt, value in zip(self.filters, self.values)] - def get_filter_value(self, column_name): + def get_filter_value(self, column_name: str) -> Any: """ Returns the filtered value for a certain column @@ -258,7 +284,7 @@ def get_filter_value(self, column_name): if flt.column_name == column_name: return value - def get_filters_values_tojson(self): + def get_filters_values_tojson(self) -> List[Tuple[str, str, Any]]: return [ (flt.column_name, as_unicode(flt.name), value) for flt, value in zip(self.filters, self.values) diff --git a/flask_appbuilder/tests/test_api.py b/flask_appbuilder/tests/test_api.py index 61971ae239..b972fef5c8 100644 --- a/flask_appbuilder/tests/test_api.py +++ b/flask_appbuilder/tests/test_api.py @@ -125,6 +125,7 @@ class APITestCase(FABTestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder + from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.api import ( BaseApi, @@ -182,6 +183,21 @@ class Model1Api(ModelRestApi): self.model1api = Model1Api self.appbuilder.add_api(Model1Api) + class CustomFilter(BaseFilter): + name = "Custom Filter" + arg_name = "custom_filter" + + def apply(self, query, value): + return query.filter( + ~Model1.field_string.like(value + "%"), Model1.field_integer == 1 + ) + + class Model1ApiSearchFilters(ModelRestApi): + datamodel = SQLAInterface(Model1) + search_filters = {"field_string": [CustomFilter]} + + self.appbuilder.add_api(Model1ApiSearchFilters) + class Model1ApiFieldsInfo(Model1Api): datamodel = SQLAInterface(Model1) add_columns = ["field_integer", "field_float", "field_string", "field_date"] @@ -1275,9 +1291,124 @@ def test_get_list_filters_wrong_order(self): rv = self.auth_client_get(client, token, uri) self.assertEqual(rv.status_code, 400) + def test_get_list_multiple_search_filters(self): + """ + REST Api: Test get list multiple search filters + """ + session = self.appbuilder.get_session + model1_1 = Model1(field_string="abc", field_integer=6) + session.add(model1_1) + session.commit() + + arguments = { + API_FILTERS_RIS_KEY: [ + {"col": "field_integer", "opr": "gt", "value": 5}, + {"col": "field_integer", "opr": "lt", "value": 7}, + ] + } + rison_args = prison.dumps(arguments) + uri = f"api/v1/model1apisearchfilters/?{API_URI_RIS_KEY}={rison_args}" + + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 2) + + arguments = { + API_FILTERS_RIS_KEY: [ + {"col": "field_integer", "opr": "gt", "value": 5}, + {"col": "field_integer", "opr": "lt", "value": 7}, + {"col": "field_string", "opr": "sw", "value": "a"}, + ] + } + rison_args = prison.dumps(arguments) + uri = f"api/v1/model1apisearchfilters/?{API_URI_RIS_KEY}={rison_args}" + + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 1) + self.assertEqual(data["result"][0]["field_string"], "abc") + + session.delete(model1_1) + session.commit() + + def test_get_list_custom_search_filters(self): + """ + REST Api: Test get list custom filters + """ + session = self.appbuilder.get_session + model1_1 = Model1(field_string="abc", field_integer=2) + # Custom filter will get this next model (not like 'test' and field_integer=1) + model1_2 = Model1(field_string="abcd", field_integer=1) + session.add(model1_1) + session.add(model1_2) + session.commit() + + filter_value = "test" + arguments = { + API_FILTERS_RIS_KEY: [ + {"col": "field_string", "opr": "custom_filter", "value": filter_value} + ] + } + rison_args = prison.dumps(arguments) + uri = f"api/v1/model1apisearchfilters/?{API_URI_RIS_KEY}={rison_args}" + + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 1) + expected_result = [ + { + "field_date": None, + "field_float": None, + "field_integer": 1, + "field_string": "abcd", + } + ] + self.assertEqual(data[API_RESULT_RES_KEY], expected_result) + + arguments = { + API_FILTERS_RIS_KEY: [ + {"col": "field_string", "opr": "custom_filter", "value": filter_value}, + {"col": "field_integer", "opr": "eq", "value": 3}, + ] + } + rison_args = prison.dumps(arguments) + uri = f"api/v1/model1apisearchfilters/?{API_URI_RIS_KEY}={rison_args}" + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 0) + session.delete(model1_1) + session.delete(model1_2) + session.commit() + + def test_get_info_custom_search_filters(self): + """ + REST Api: Test get info custom filters + """ + arguments = {"keys": ["filters"]} + rison_args = prison.dumps(arguments) + uri = f"api/v1/model1apisearchfilters/_info?{API_URI_RIS_KEY}={rison_args}" + + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + field_string_filters = data["filters"]["field_string"] + self.assertIn( + {"name": "Custom Filter", "operator": "custom_filter"}, field_string_filters + ) + def test_get_list_select_cols(self): """ - REST Api: Test get list with selected columns + REST Api: Test get list with select columns """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)