diff --git a/requirements.txt b/requirements.txt index a3f336b610ed..578b442e6701 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,7 +52,7 @@ marshmallow==2.19.5 # via flask-appbuilder, marshmallow-enum, marshmallow- more-itertools==8.1.0 # via zipp msgpack==0.6.2 # via apache-superset (setup.py) numpy==1.18.1 # via pandas, pyarrow -pandas==0.25.3 # via apache-superset (setup.py) +pandas==1.0.3 # via apache-superset (setup.py) parsedatetime==2.5 # via apache-superset (setup.py) pathlib2==2.3.5 # via apache-superset (setup.py) polyline==1.4.0 # via apache-superset (setup.py) diff --git a/setup.py b/setup.py index eb225d39cfe7..ad8f4b92e912 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ def get_git_sha(): "isodate", "markdown>=3.0", "msgpack>=0.6.1, <0.7.0", - "pandas>=0.25.3, <1.0", + "pandas>=1.0.3, <1.1", "parsedatetime", "pathlib2", "polyline", diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 6377fb19767e..656988508db8 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -51,7 +51,7 @@ class QueryContext: custom_cache_timeout: Optional[int] # TODO: Type datasource and query_object dictionary with TypedDict when it becomes - # a vanilla python type https://github.com/python/mypy/issues/5288 + # a vanilla python type https://github.com/python/mypy/issues/5288 def __init__( self, datasource: Dict[str, Any], @@ -70,8 +70,8 @@ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: """Returns a pandas dataframe based on the query object""" # Here, we assume that all the queries will use the same datasource, which is - # is a valid assumption for current setting. In a long term, we may or maynot - # support multiple queries from different data source. + # a valid assumption for current setting. In the long term, we may + # support multiple queries from different data sources. 
timestamp_format = None if self.datasource.type == "table": @@ -105,6 +105,9 @@ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: self.df_metrics_to_num(df, query_object) df.replace([np.inf, -np.inf], np.nan) + + df = query_object.exec_post_processing(df) + return { "query": result.query, "status": result.status, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index f5133857583f..72e9dfef8f51 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -20,13 +20,16 @@ from typing import Any, Dict, List, Optional, Union import simplejson as json +from flask_babel import gettext as _ +from pandas import DataFrame from superset import app -from superset.utils import core as utils +from superset.exceptions import QueryObjectValidationError +from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type -# https://github.com/python/mypy/issues/5288 +# https://github.com/python/mypy/issues/5288 class QueryObject: @@ -50,6 +53,7 @@ class QueryObject: extras: Dict columns: List[str] orderby: List[List] + post_processing: List[Dict[str, Any]] def __init__( self, @@ -67,6 +71,7 @@ def __init__( extras: Optional[Dict] = None, columns: Optional[List[str]] = None, orderby: Optional[List[List]] = None, + post_processing: Optional[List[Dict[str, Any]]] = None, relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"], relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"], ): @@ -81,8 +86,9 @@ def __init__( self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) self.groupby = groupby or [] + self.post_processing = post_processing or [] - # Temporal solution for backward compatability issue due the new format of + # Temporary solution for backward compatibility issue due to the new format of # non-ad-hoc metric which needs to adhere to superset-ui per # https://git.io/Jvm7P. self.metrics = [ @@ -138,9 +144,37 @@ def cache_key(self, **extra: Any) -> str: if self.time_range: cache_dict["time_range"] = self.time_range + if self.post_processing: + cache_dict["post_processing"] = self.post_processing json_data = self.json_dumps(cache_dict, sort_keys=True) return hashlib.md5(json_data.encode("utf-8")).hexdigest() def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: return json.dumps( obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) + + def exec_post_processing(self, df: DataFrame) -> DataFrame: + """ + Perform post processing operations on DataFrame. + + :param df: DataFrame returned from database model.
+ :return: new DataFrame to which all post processing operations have been + applied + :raises QueryObjectValidationError: If the post processing operation is incorrect + """ + for post_process in self.post_processing: + operation = post_process.get("operation") + if not operation: + raise QueryObjectValidationError( + _("`operation` property of post processing object undefined") + ) + if not hasattr(pandas_postprocessing, operation): + raise QueryObjectValidationError( + _( + "Unsupported post processing operation: %(operation)s", + operation=operation, + ) + ) + options = post_process.get("options", {}) + df = getattr(pandas_postprocessing, operation)(df, **options) + return df diff --git a/superset/exceptions.py b/superset/exceptions.py index 7564decb232c..e7f2e2d400ef 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -68,3 +68,7 @@ class CertificateException(SupersetException): class DatabaseNotFound(SupersetException): status = 400 + + +class QueryObjectValidationError(SupersetException): + status = 400 diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py new file mode 100644 index 000000000000..2800ee178e87 --- /dev/null +++ b/superset/utils/pandas_postprocessing.py @@ -0,0 +1,389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +from flask_babel import gettext as _ +from pandas import DataFrame, NamedAgg + +from superset.exceptions import QueryObjectValidationError + +WHITELIST_NUMPY_FUNCTIONS = ( + "average", + "argmin", + "argmax", + "cumsum", + "cumprod", + "max", + "mean", + "median", + "nansum", + "nanmin", + "nanmax", + "nanmean", + "nanmedian", + "min", + "percentile", + "prod", + "product", + "std", + "sum", + "var", +) + +WHITELIST_ROLLING_FUNCTIONS = ( + "count", + "corr", + "cov", + "kurt", + "max", + "mean", + "median", + "min", + "std", + "skew", + "sum", + "var", + "quantile", +) + +WHITELIST_CUMULATIVE_FUNCTIONS = ( + "cummax", + "cummin", + "cumprod", + "cumsum", +) + + +def validate_column_args(*argnames: str) -> Callable: + def wrapper(func): + def wrapped(df, **options): + columns = df.columns.tolist() + for name in argnames: + if name in options and not all( + elem in columns for elem in options[name] + ): + raise QueryObjectValidationError( + _("Referenced columns not available in DataFrame.") + ) + return func(df, **options) + + return wrapped + + return wrapper + + +def _get_aggregate_funcs( + df: DataFrame, aggregates: Dict[str, Dict[str, Any]], +) -> Dict[str, NamedAgg]: + """ + Converts a set of aggregate config objects into functions that pandas can use as + aggregators. Currently only numpy aggregators are supported.
+ + :param df: DataFrame on which to perform aggregate operation. + :param aggregates: Mapping from column name to aggregate config. + :return: Mapping from metric name to function that takes a single input argument. + """ + agg_funcs: Dict[str, NamedAgg] = {} + for name, agg_obj in aggregates.items(): + column = agg_obj.get("column", name) + if column not in df: + raise QueryObjectValidationError( + _( + "Column referenced by aggregate is undefined: %(column)s", + column=column, + ) + ) + if "operator" not in agg_obj: + raise QueryObjectValidationError( + _("Operator undefined for aggregator: %(name)s", name=name,) + ) + operator = agg_obj["operator"] + if operator not in WHITELIST_NUMPY_FUNCTIONS or not hasattr(np, operator): + raise QueryObjectValidationError( + _("Invalid numpy function: %(operator)s", operator=operator,) + ) + func = getattr(np, operator) + options = agg_obj.get("options", {}) + agg_funcs[name] = NamedAgg(column=column, aggfunc=partial(func, **options)) + + return agg_funcs + + +def _append_columns( + base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str] +) -> DataFrame: + """ + Function for adding columns from one DataFrame to another DataFrame. Calls the + assign method, which overwrites the original column in `base_df` if the column + already exists, and appends the column if the name is not defined. + + :param base_df: DataFrame which to use as the base + :param append_df: DataFrame from which to select data. + :param columns: columns on which to append, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the values in + column `y` in `base_df` with the values in `y` in `append_df`, + while `{'y': 'y2'}` will add a column `y2` to `base_df` based + on values in column `y` in `append_df`, leaving the original column `y` + in `base_df` unchanged. + :return: new DataFrame with combined data from `base_df` and `append_df` + """ + return base_df.assign( + **{ + target: append_df[append_df.columns[idx]] + for idx, target in enumerate(columns.values()) + } + ) + + +@validate_column_args("index", "columns") +def pivot( # pylint: disable=too-many-arguments + df: DataFrame, + index: List[str], + columns: List[str], + aggregates: Dict[str, Dict[str, Any]], + metric_fill_value: Optional[Any] = None, + column_fill_value: Optional[str] = None, + drop_missing_columns: Optional[bool] = True, + combine_value_with_metric=False, + marginal_distributions: Optional[bool] = None, + marginal_distribution_name: Optional[str] = None, +) -> DataFrame: + """ + Perform a pivot operation on a DataFrame. + + :param df: Object on which pivot operation will be performed + :param index: Columns to group by on the table index (=rows) + :param columns: Columns to group by on the table columns + :param metric_fill_value: Value to replace missing values with + :param column_fill_value: Value to replace missing pivot columns with + :param drop_missing_columns: Do not include columns whose entries are all missing + :param combine_value_with_metric: Display metrics side by side within each column, + as opposed to each column being displayed side by side for each metric. + :param aggregates: A mapping from aggregate column name to the aggregate + config. + :param marginal_distributions: Add totals for row/column. Defaults to False + :param marginal_distribution_name: Name of row/column with marginal distribution. + Defaults to 'All'.
+ :return: A pivot table + :raises QueryObjectValidationError: If the request is incorrect + """ + if not index: + raise QueryObjectValidationError( + _("Pivot operation requires at least one index") + ) + if not columns: + raise QueryObjectValidationError( + _("Pivot operation requires at least one column") + ) + if not aggregates: + raise QueryObjectValidationError( + _("Pivot operation must include at least one aggregate") + ) + + if column_fill_value: + df[columns] = df[columns].fillna(value=column_fill_value) + + aggregate_funcs = _get_aggregate_funcs(df, aggregates) + + # TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table. + # Remove once/if support is added. + aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()} + + df = df.pivot_table( + values=aggfunc.keys(), + index=index, + columns=columns, + aggfunc=aggfunc, + fill_value=metric_fill_value, + dropna=drop_missing_columns, + margins=marginal_distributions, + margins_name=marginal_distribution_name, + ) + + if combine_value_with_metric: + df = df.stack(0).unstack() + + return df + + +@validate_column_args("groupby") +def aggregate( + df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] +) -> DataFrame: + """ + Apply aggregations to a DataFrame. + + :param df: Object to aggregate. + :param groupby: columns to group by + :param aggregates: A mapping from metric column to the function used to + aggregate values. + :raises QueryObjectValidationError: If the request is incorrect + """ + aggregates = aggregates or {} + aggregate_funcs = _get_aggregate_funcs(df, aggregates) + return df.groupby(by=groupby).agg(**aggregate_funcs).reset_index() + + +@validate_column_args("columns") +def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: + """ + Sort a DataFrame. + + :param df: DataFrame to sort. + :param columns: columns by which to sort. The key specifies the column name, + the value specifies whether to sort in ascending order. + :return: Sorted DataFrame + :raises QueryObjectValidationError: If the request is incorrect + """ + return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) + + +@validate_column_args("columns") +def rolling( # pylint: disable=too-many-arguments + df: DataFrame, + columns: Dict[str, str], + rolling_type: str, + window: int, + rolling_type_options: Optional[Dict[str, Any]] = None, + center: bool = False, + win_type: Optional[str] = None, + min_periods: Optional[int] = None, +) -> DataFrame: + """ + Apply a rolling window on the dataset. See the Pandas docs for further details: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html + + :param df: DataFrame on which the rolling period will be based. + :param columns: columns on which to perform rolling, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the column `y` with + the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based + on rolling values calculated from `y`, leaving the original column `y` + unchanged. + :param rolling_type: Type of rolling window aggregation. Any whitelisted pandas + rolling method (e.g. `sum`, `mean`, `quantile`) will work. + :param rolling_type_options: Optional options to pass to rolling method. Needed + for e.g. quantile operation. + :param center: Should the label be at the center of the window. + :param win_type: Type of window function. + :param window: Size of the window.
+ :param min_periods: Minimum number of observations in the window required to + have a value; the first `min_periods` rows are dropped from the result. + :return: DataFrame with the rolling columns + :raises QueryObjectValidationError: If the request is incorrect + """ + rolling_type_options = rolling_type_options or {} + df_rolling = df[columns.keys()] + kwargs: Dict[str, Union[str, int]] = {} + if not window: + raise QueryObjectValidationError(_("Undefined window for rolling operation")) + + kwargs["window"] = window + if min_periods is not None: + kwargs["min_periods"] = min_periods + if center is not None: + kwargs["center"] = center + if win_type is not None: + kwargs["win_type"] = win_type + + df_rolling = df_rolling.rolling(**kwargs) + if rolling_type not in WHITELIST_ROLLING_FUNCTIONS or not hasattr( + df_rolling, rolling_type + ): + raise QueryObjectValidationError( + _("Invalid rolling_type: %(type)s", type=rolling_type) + ) + try: + df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options) + except TypeError: + raise QueryObjectValidationError( + _( + "Invalid options for %(rolling_type)s: %(options)s", + rolling_type=rolling_type, + options=rolling_type_options, + ) + ) + df = _append_columns(df, df_rolling, columns) + if min_periods: + df = df[min_periods:] + return df + + +@validate_column_args("columns", "rename") +def select( + df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None +) -> DataFrame: + """ + Only select a subset of columns in the original dataset. Can be useful for + removing unnecessary intermediate results, renaming and reordering columns. + + :param df: DataFrame from which to select columns. + :param columns: Columns which to select from the DataFrame, in the desired order. + If columns are renamed, the new column name should be referenced + here. + :param rename: columns which to rename, mapping source column to target column. + For instance, `{'y': 'y2'}` will rename the column `y` to + `y2`. + :return: Subset of columns in original DataFrame + :raises QueryObjectValidationError: If the request is incorrect + """ + df_select = df[columns] + if rename is not None: + df_select = df_select.rename(columns=rename) + return df_select + + +@validate_column_args("columns") +def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame: + """ + Calculate row-by-row difference for the selected columns. + + :param df: DataFrame on which the diff will be based. + :param columns: columns on which to perform diff, mapping source column to + target column. For instance, `{'y': 'y'}` will replace the column `y` with + the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based + on diff values calculated from `y`, leaving the original column `y` + unchanged. + :param periods: periods to shift for calculating difference. + :return: DataFrame with diffed columns + :raises QueryObjectValidationError: If the request is incorrect + """ + df_diff = df[columns.keys()] + df_diff = df_diff.diff(periods=periods) + return _append_columns(df, df_diff, columns) + + +@validate_column_args("columns") +def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame: + """ + Perform a cumulative operation (sum, prod, min or max) on the selected columns. + + :param df: DataFrame on which the cumulative operation will be based. + :param columns: columns on which to perform a cumulative operation, mapping source + column to target column. For instance, `{'y': 'y'}` will replace the column + `y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column + `y2` based on cumulative values calculated from `y`, leaving the original + column `y` unchanged. + :param operator: cumulative operator, e.g.
`sum`, `prod`, `min`, `max` + :return: + """ + df_cum = df[columns.keys()] + operation = "cum" + operator + if operation not in WHITELIST_CUMULATIVE_FUNCTIONS or not hasattr( + df_cum, operation + ): + raise QueryObjectValidationError( + _("Invalid cumulative operator: %(operator)s", operator=operator) + ) + return _append_columns(df, getattr(df_cum, operation)(), columns) diff --git a/tests/core_tests.py b/tests/core_tests.py index eb3e2f7ba84e..73711d9301d3 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -111,7 +111,7 @@ def test_slice_endpoint(self): resp = self.client.get("/superset/slice/-1/") assert resp.status_code == 404 - def _get_query_context_dict(self) -> Dict[str, Any]: + def _get_query_context(self) -> Dict[str, Any]: self.login(username="admin") slc = self.get_slice("Girl Name Cloud", db.session) return { @@ -127,6 +127,45 @@ def _get_query_context_dict(self) -> Dict[str, Any]: ], } + def _get_query_context_with_post_processing(self) -> Dict[str, Any]: + self.login(username="admin") + slc = self.get_slice("Girl Name Cloud", db.session) + return { + "datasource": {"id": slc.datasource_id, "type": slc.datasource_type}, + "queries": [ + { + "granularity": "ds", + "groupby": ["name", "state"], + "metrics": [{"label": "sum__num"}], + "filters": [], + "row_limit": 100, + "post_processing": [ + { + "operation": "aggregate", + "options": { + "groupby": ["state"], + "aggregates": { + "q1": { + "operator": "percentile", + "column": "sum__num", + "options": {"q": 25}, + }, + "median": { + "operator": "median", + "column": "sum__num", + }, + }, + }, + }, + { + "operation": "sort", + "options": {"columns": {"q1": False, "state": True},}, + }, + ], + } + ], + } + def test_viz_cache_key(self): self.login(username="admin") slc = self.get_slice("Girls", db.session) @@ -140,7 +179,7 @@ def test_viz_cache_key(self): self.assertNotEqual(cache_key, viz.cache_key(qobj)) def test_cache_key_changes_when_datasource_is_updated(self): - qc_dict = self._get_query_context_dict() + qc_dict = self._get_query_context() # construct baseline cache_key query_context = QueryContext(**qc_dict) @@ -168,7 +207,7 @@ def test_cache_key_changes_when_datasource_is_updated(self): self.assertNotEqual(cache_key_original, cache_key_new) def test_query_context_time_range_endpoints(self): - query_context = QueryContext(**self._get_query_context_dict()) + query_context = QueryContext(**self._get_query_context()) query_object = query_context.queries[0] extras = query_object.to_dict()["extras"] self.assertTrue("time_range_endpoints" in extras) @@ -217,11 +256,18 @@ def test_get_superset_tables_not_found(self): def test_api_v1_query_endpoint(self): self.login(username="admin") - qc_dict = self._get_query_context_dict() + qc_dict = self._get_query_context() data = json.dumps(qc_dict) resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data})) self.assertEqual(resp[0]["rowcount"], 100) + def test_api_v1_query_endpoint_with_post_processing(self): + self.login(username="admin") + qc_dict = self._get_query_context_with_post_processing() + data = json.dumps(qc_dict) + resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data})) + self.assertEqual(resp[0]["rowcount"], 6) + def test_old_slice_json_endpoint(self): self.login(username="admin") slc = self.get_slice("Girls", db.session) diff --git a/tests/fixtures/dataframes.py b/tests/fixtures/dataframes.py new file mode 100644 index 000000000000..e565dc40a002 --- /dev/null +++ b/tests/fixtures/dataframes.py @@ -0,0 +1,121 @@ +# Licensed to 
the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import date + +from pandas import DataFrame, to_datetime + +names_df = DataFrame( + [ + { + "dt": date(2020, 1, 2), + "name": "John", + "country": "United Kingdom", + "cars": 3, + "bikes": 1, + "seconds": 30, + }, + { + "dt": date(2020, 1, 2), + "name": "Peter", + "country": "Sweden", + "cars": 4, + "bikes": 2, + "seconds": 1, + }, + { + "dt": date(2020, 1, 3), + "name": "Mary", + "country": "Finland", + "cars": 5, + "bikes": 3, + "seconds": None, + }, + { + "dt": date(2020, 1, 3), + "name": "Peter", + "country": "India", + "cars": 6, + "bikes": 4, + "seconds": 12, + }, + { + "dt": date(2020, 1, 4), + "name": "John", + "country": "Portugal", + "cars": 7, + "bikes": None, + "seconds": 75, + }, + { + "dt": date(2020, 1, 4), + "name": "Peter", + "country": "Italy", + "cars": None, + "bikes": 5, + "seconds": 600, + }, + { + "dt": date(2020, 1, 4), + "name": "Mary", + "country": None, + "cars": 9, + "bikes": 6, + "seconds": 2, + }, + { + "dt": date(2020, 1, 4), + "name": None, + "country": "Australia", + "cars": 10, + "bikes": 7, + "seconds": 99, + }, + { + "dt": date(2020, 1, 1), + "name": "John", + "country": "USA", + "cars": 1, + "bikes": 8, + "seconds": None, + }, + { + "dt": date(2020, 1, 1), + "name": "Mary", + "country": "Fiji", + "cars": 2, + "bikes": 9, + "seconds": 50, + }, + ] +) + +categories_df = DataFrame( + { + "constant": ["dummy" for _ in range(0, 101)], + "category": [f"cat{i%3}" for i in range(0, 101)], + "dept": [f"dept{i%5}" for i in range(0, 101)], + "name": [f"person{i}" for i in range(0, 101)], + "asc_idx": [i for i in range(0, 101)], + "desc_idx": [i for i in range(100, -1, -1)], + "idx_nulls": [i if i % 5 == 0 else None for i in range(0, 101)], + } +) + +timeseries_df = DataFrame( + index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]), + data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]}, +) diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py new file mode 100644 index 000000000000..a981477d0bb4 --- /dev/null +++ b/tests/pandas_postprocessing_tests.py @@ -0,0 +1,290 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +import math +from typing import Any, List + +from pandas import Series + +from superset.exceptions import QueryObjectValidationError +from superset.utils import pandas_postprocessing as proc + +from .base_tests import SupersetTestCase +from .fixtures.dataframes import categories_df, timeseries_df + + +def series_to_list(series: Series) -> List[Any]: + """ + Converts a `Series` to a regular list, and replaces non-numeric values to + Nones. + + :param series: Series to convert + :return: list without nan or inf + """ + return [ + None + if not isinstance(val, str) and (math.isnan(val) or math.isinf(val)) + else val + for val in series.tolist() + ] + + +class PostProcessingTestCase(SupersetTestCase): + def test_pivot(self): + aggregates = {"idx_nulls": {"operator": "sum"}} + + # regular pivot + df = proc.pivot( + df=categories_df, + index=["name"], + columns=["category"], + aggregates=aggregates, + ) + self.assertListEqual( + df.columns.tolist(), + [("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")], + ) + self.assertEqual(len(df), 101) + self.assertEqual(df.sum()[0], 315) + + # regular pivot + df = proc.pivot( + df=categories_df, + index=["dept"], + columns=["category"], + aggregates=aggregates, + ) + self.assertEqual(len(df), 5) + + # fill value + df = proc.pivot( + df=categories_df, + index=["name"], + columns=["category"], + metric_fill_value=1, + aggregates={"idx_nulls": {"operator": "sum"}}, + ) + self.assertEqual(df.sum()[0], 382) + + # invalid index reference + self.assertRaises( + QueryObjectValidationError, + proc.pivot, + df=categories_df, + index=["abc"], + columns=["dept"], + aggregates=aggregates, + ) + + # invalid column reference + self.assertRaises( + QueryObjectValidationError, + proc.pivot, + df=categories_df, + index=["dept"], + columns=["abc"], + aggregates=aggregates, + ) + + # invalid aggregate options + self.assertRaises( + QueryObjectValidationError, + proc.pivot, + df=categories_df, + index=["name"], + columns=["category"], + aggregates={"idx_nulls": {}}, + ) + + def test_aggregate(self): + aggregates = { + "asc sum": {"column": "asc_idx", "operator": "sum"}, + "asc q2": { + "column": "asc_idx", + "operator": "percentile", + "options": {"q": 75}, + }, + "desc q1": { + "column": "desc_idx", + "operator": "percentile", + "options": {"q": 25}, + }, + } + df = proc.aggregate( + df=categories_df, groupby=["constant"], aggregates=aggregates + ) + self.assertListEqual( + df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"] + ) + self.assertEqual(series_to_list(df["asc sum"])[0], 5050) + self.assertEqual(series_to_list(df["asc q2"])[0], 75) + self.assertEqual(series_to_list(df["desc q1"])[0], 25) + + def test_sort(self): + df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False}) + self.assertEqual(96, series_to_list(df["asc_idx"])[1]) + + self.assertRaises( + QueryObjectValidationError, proc.sort, df=df, columns={"abc": True} + ) + + def test_rolling(self): + # sum rolling type + post_df = proc.rolling( + df=timeseries_df, + columns={"y": "y"}, + rolling_type="sum", + window=2, + min_periods=0, 
+ ) + + self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0]) + + # mean rolling type with alias + post_df = proc.rolling( + df=timeseries_df, + rolling_type="mean", + columns={"y": "y_mean"}, + window=10, + min_periods=0, + ) + self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"]) + self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5]) + + # count rolling type + post_df = proc.rolling( + df=timeseries_df, + rolling_type="count", + columns={"y": "y"}, + window=10, + min_periods=0, + ) + self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0]) + + # quantile rolling type + post_df = proc.rolling( + df=timeseries_df, + columns={"y": "q1"}, + rolling_type="quantile", + rolling_type_options={"quantile": 0.25}, + window=10, + min_periods=0, + ) + self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"]) + self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75]) + + # incorrect rolling type + self.assertRaises( + QueryObjectValidationError, + proc.rolling, + df=timeseries_df, + columns={"y": "y"}, + rolling_type="abc", + window=2, + ) + + # incorrect rolling type options + self.assertRaises( + QueryObjectValidationError, + proc.rolling, + df=timeseries_df, + columns={"y": "y"}, + rolling_type="quantile", + rolling_type_options={"abc": 123}, + window=2, + ) + + def test_select(self): + # reorder columns + post_df = proc.select(df=timeseries_df, columns=["y", "label"]) + self.assertListEqual(post_df.columns.tolist(), ["y", "label"]) + + # one column + post_df = proc.select(df=timeseries_df, columns=["label"]) + self.assertListEqual(post_df.columns.tolist(), ["label"]) + + # rename one column + post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"}) + self.assertListEqual(post_df.columns.tolist(), ["y1"]) + + # rename one and leave one unchanged + post_df = proc.select( + df=timeseries_df, columns=["label", "y"], rename={"y": "y1"} + ) + self.assertListEqual(post_df.columns.tolist(), ["label", "y1"]) + + # invalid columns + self.assertRaises( + QueryObjectValidationError, + proc.select, + df=timeseries_df, + columns=["qwerty"], + rename={"abc": "qwerty"}, + ) + + def test_diff(self): + # overwrite column + post_df = proc.diff(df=timeseries_df, columns={"y": "y"}) + self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) + self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0]) + + # add column + post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}) + self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0]) + self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0]) + + # look ahead + post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1) + self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None]) + + # invalid column reference + self.assertRaises( + QueryObjectValidationError, + proc.diff, + df=timeseries_df, + columns={"abc": "abc"}, + ) + + def test_cum(self): + # create new column (cumsum) + post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",) + self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"]) + self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 
3.0, 4.0]) + self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0]) + + # overwrite column (cumprod) + post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",) + self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0]) + + # overwrite column (cummin) + post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",) + self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) + self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0]) + + # invalid operator + self.assertRaises( + QueryObjectValidationError, + proc.cum, + df=timeseries_df, + columns={"y": "y"}, + operator="abc", + )
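
For reference, a minimal usage sketch (not part of the diff) of how the new `post_processing` list is dispatched: `QueryObject.exec_post_processing` looks each `operation` up on `superset.utils.pandas_postprocessing` and calls it with the DataFrame plus the entry's `options`. The example data below is made up for illustration; the operation configs mirror the ones used in `_get_query_context_with_post_processing`.

# Minimal sketch of the post_processing dispatch added in this diff.
# Assumes superset (with the new pandas_postprocessing module) and pandas>=1.0.3
# are importable; the DataFrame contents are hypothetical.
from pandas import DataFrame

from superset.utils import pandas_postprocessing

df = DataFrame(
    {
        "state": ["NY", "NY", "CA", "CA"],
        "name": ["Olivia", "Emma", "Olivia", "Ava"],
        "sum__num": [10, 20, 30, 40],
    }
)

post_processing = [
    {
        "operation": "aggregate",
        "options": {
            "groupby": ["state"],
            "aggregates": {
                "q1": {
                    "operator": "percentile",
                    "column": "sum__num",
                    "options": {"q": 25},
                },
                "median": {"operator": "median", "column": "sum__num"},
            },
        },
    },
    {"operation": "sort", "options": {"columns": {"q1": False, "state": True}}},
]

# Same dispatch loop as QueryObject.exec_post_processing: each step names a
# function in pandas_postprocessing and passes its options as keyword arguments.
for step in post_processing:
    df = getattr(pandas_postprocessing, step["operation"])(df, **step["options"])

print(df)  # one row per state with q1/median of sum__num, sorted by q1 descending

Operations are applied strictly in list order, so a `sort` or `select` step can reference columns (such as `q1` above) that only exist after an earlier `aggregate` or `pivot` step.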