From e32fb691cfc7d35bc71c14fc9cfa036dfce89051 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 20 Mar 2020 16:23:28 +0200 Subject: [PATCH 01/14] Deprecate groupby from query_obj --- superset/common/query_object.py | 12 ++- superset/connectors/sqla/models.py | 15 ++- superset/viz.py | 163 +++++++++++------------------ 3 files changed, 77 insertions(+), 113 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 72e9dfef8f51..2cb8728af77d 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=R import hashlib +import logging from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Union @@ -28,6 +29,8 @@ from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints +logger = logging.getLogger(__name__) + # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type # https://github.com/python/mypy/issues/5288 @@ -85,7 +88,6 @@ def __init__( self.is_timeseries = is_timeseries self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) - self.groupby = groupby or [] self.post_processing = post_processing or [] # Temporary solution for backward compatibility issue due the new format of @@ -106,6 +108,14 @@ def __init__( if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras: self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) + self.columns = columns + if groupby: + self.columns += groupby + logger.warning( + f"The field groupby is deprecated. Viz plugins should " + f"pass all selectables via the columns field" + ) + self.columns = columns or [] self.orderby = orderby or [] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 642997a10554..93db4b50f4ce 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -696,7 +696,6 @@ def _get_sqla_row_level_filters(self, template_processor) -> List[str]: def get_sqla_query( # sqla self, - groupby, metrics, granularity, from_dttm, @@ -716,7 +715,6 @@ def get_sqla_query( # sqla """Querying any sqla table from this common interface""" template_kwargs = { "from_dttm": from_dttm, - "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "to_dttm": to_dttm, @@ -749,7 +747,7 @@ def get_sqla_query( # sqla "and is required by this type of chart" ) ) - if not groupby and not metrics and not columns: + if not metrics and not columns: raise Exception(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] for m in metrics: @@ -768,9 +766,9 @@ def get_sqla_query( # sqla select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() - if groupby: + if metrics and columns: # dedup columns while preserving order - groupby = list(dict.fromkeys(groupby)) + groupby = list(dict.fromkeys(columns)) select_exprs = [] for s in groupby: @@ -829,7 +827,7 @@ def get_sqla_query( # sqla tbl = self.get_from_clause(template_processor) - if not columns: + if metrics: qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] @@ -892,7 +890,7 @@ def get_sqla_query( # sqla qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) - if not orderby and not columns: + if not orderby and metrics: orderby = [(main_metric_expr, not order_desc)] # To ensure correct handling of the ORDER BY labeling we need to reference 
the @@ -914,7 +912,7 @@ def get_sqla_query( # sqla if row_limit: qry = qry.limit(row_limit) - if is_timeseries and timeseries_limit and groupby and not time_groupby_inline: + if is_timeseries and timeseries_limit and columns and not time_groupby_inline: if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. vertica, @@ -972,7 +970,6 @@ def get_sqla_query( # sqla prequery_obj = { "is_timeseries": False, "row_limit": timeseries_limit, - "groupby": groupby, "metrics": metrics, "granularity": granularity, "from_dttm": inner_from_dttm or from_dttm, diff --git a/superset/viz.py b/superset/viz.py index 848e3b35c9e4..de71c53172e3 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -78,6 +78,22 @@ "size", ] +COLUMN_FORM_DATA_PARAMS = [ + "all_columns", + "all_columns_x", + "all_columns_y", + "columns", + "dimension", + "entity", + "groupby", + "order_by_cols", + "series", + "line_column", + "js_columns", +] + +SPATIAL_COLUMN_FORM_DATA_PARAMS = ["spatial", "start_spatial", "end_spatial"] + class BaseViz: @@ -107,7 +123,34 @@ def __init__( self.query = "" self.token = self.form_data.get("token", "token_" + uuid.uuid4().hex[:8]) - self.groupby = self.form_data.get("groupby") or [] + # merge all selectable columns into `columns` property + self.columns: List[str] = [] + for key in COLUMN_FORM_DATA_PARAMS: + logger.warning( + f"The form field %s is deprecated. Viz plugins should " + f"pass all selectables via the columns field", + key, + ) + value = self.form_data.get(key) or [] + value_list = value if isinstance(value, list) else [value] + self.columns += value_list + + for key in SPATIAL_COLUMN_FORM_DATA_PARAMS: + spatial = self.form_data.get(key) + if not isinstance(spatial, dict): + continue + logger.warning( + f"The form field %s is deprecated. 
Viz plugins should " + f"pass all selectables via the columns field", + key, + ) + if spatial.get("type") == "latlong": + self.columns += [spatial["lonCol"], spatial["latCol"]] + elif spatial.get("type") == "delimited": + self.columns.append(spatial["lonlatCol"]) + elif spatial.get("type") == "geohash": + self.columns.append(spatial["geohashCol"]) + self.time_shift = timedelta() self.status: Optional[str] = None @@ -204,7 +247,6 @@ def get_samples(self): query_obj = self.query_obj() query_obj.update( { - "groupby": [], "metrics": [], "row_limit": 1000, "columns": [o.column_name for o in self.datasource.columns], @@ -291,14 +333,12 @@ def query_obj(self) -> Dict[str, Any]: """Building a query object""" form_data = self.form_data self.process_query_filters() - gb = form_data.get("groupby") or [] metrics = self.all_metrics or [] - columns = form_data.get("columns") or [] - groupby = list(set(gb + columns)) + columns = self.columns is_timeseries = self.is_timeseries - if DTTM_ALIAS in groupby: - groupby.remove(DTTM_ALIAS) + if DTTM_ALIAS in columns: + columns.remove(DTTM_ALIAS) is_timeseries = True granularity = form_data.get("granularity") or form_data.get("granularity_sqla") @@ -342,7 +382,7 @@ def query_obj(self) -> Dict[str, Any]: "from_dttm": from_dttm, "to_dttm": to_dttm, "is_timeseries": is_timeseries, - "groupby": groupby, + "columns": columns, "metrics": metrics, "row_limit": row_limit, "filter": self.form_data.get("filters", []), @@ -567,8 +607,6 @@ def query_obj(self): sort_by = fd.get("timeseries_limit_metric") if fd.get("all_columns"): - d["columns"] = fd.get("all_columns") - d["groupby"] = [] order_by_cols = fd.get("order_by_cols") or [] d["orderby"] = [json.loads(t) for t in order_by_cols] elif sort_by: @@ -816,11 +854,6 @@ class WordCloudViz(BaseViz): verbose_name = _("Word Cloud") is_timeseries = False - def query_obj(self): - d = super().query_obj() - d["groupby"] = [self.form_data.get("series")] - return d - class TreemapViz(BaseViz): @@ -1024,12 +1057,6 @@ class BubbleViz(NVD3Viz): def query_obj(self): form_data = self.form_data d = super().query_obj() - d["groupby"] = [form_data.get("entity")] - if form_data.get("series"): - d["groupby"].append(form_data.get("series")) - - # dedup groupby if it happens to be the same - d["groupby"] = list(dict.fromkeys(d["groupby"])) self.x_metric = form_data.get("x") self.y_metric = form_data.get("y") @@ -1240,7 +1267,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: if aggregate: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=self.columns, values=self.metric_labels, fill_value=0, aggfunc=sum, @@ -1248,7 +1275,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: else: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=self.columns, values=self.metric_labels, fill_value=self.pivot_fill_value, ) @@ -1541,7 +1568,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None metric = self.metric_labels[0] - df = df.pivot_table(index=self.groupby, values=[metric]) + df = df.pivot_table(index=self.columns, values=[metric]) df.sort_values(by=metric, ascending=False, inplace=True) df = df.reset_index() df.columns = ["x", "y"] @@ -1560,13 +1587,8 @@ def query_obj(self): """Returns the query object for this visualization""" d = super().query_obj() d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) - numeric_columns = self.form_data.get("all_columns_x") - if numeric_columns is None: + 
if not self.form_data.get("all_columns_x"): raise Exception(_("Must have at least one numeric column specified")) - self.columns = numeric_columns - d["columns"] = numeric_columns + self.groupby - # override groupby entry to avoid aggregation - d["groupby"] = [] return d def labelify(self, keys, column): @@ -1613,13 +1635,9 @@ class DistributionBarViz(DistributionPieViz): def query_obj(self): d = super().query_obj() fd = self.form_data - if len(d["groupby"]) < len(fd.get("groupby") or []) + len( - fd.get("columns") or [] - ): - raise Exception(_("Can't have overlap between Series and Breakdowns")) - if not fd.get("metrics"): + if not self.all_metrics: raise Exception(_("Pick at least one metric")) - if not fd.get("groupby"): + if not self.columns: raise Exception(_("Pick at least one field for [Series]")) return d @@ -1630,22 +1648,22 @@ def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data metrics = self.metric_labels columns = fd.get("columns") or [] + groupby = fd.get("groupby") or [] # pandas will throw away nulls when grouping/pivoting, # so we substitute NULL_STRING for any nulls in the necessary columns - filled_cols = self.groupby + columns - df[filled_cols] = df[filled_cols].fillna(value=NULL_STRING) + df[self.columns] = df[self.columns].fillna(value=NULL_STRING) - row = df.groupby(self.groupby).sum()[metrics[0]].copy() + row = df.groupby(self.columns).sum()[metrics[0]].copy() row.sort_values(ascending=False, inplace=True) - pt = df.pivot_table(index=self.groupby, columns=columns, values=metrics) + pt = df.pivot_table(index=groupby, columns=columns, values=metrics) if fd.get("contribution"): pt = pt.T pt = (pt / pt.sum()).T pt = pt.reindex(row.index) chart_data = [] for name, ys in pt.items(): - if pt[name].dtype.kind not in "biufc" or name in self.groupby: + if pt[name].dtype.kind not in "biufc" or name in groupby: continue if isinstance(name, str): series_title = name @@ -1714,13 +1732,6 @@ class SankeyViz(BaseViz): is_timeseries = False credits = 'd3-sankey on npm' - def query_obj(self): - qry = super().query_obj() - if len(qry["groupby"]) != 2: - raise Exception(_("Pick exactly 2 columns as [Source / Target]")) - qry["metrics"] = [self.form_data["metric"]] - return qry - def get_data(self, df: pd.DataFrame) -> VizData: df.columns = ["source", "target", "value"] df["source"] = df["source"].astype(str) @@ -1791,7 +1802,6 @@ class ChordViz(BaseViz): def query_obj(self): qry = super().query_obj() fd = self.form_data - qry["groupby"] = [fd.get("groupby"), fd.get("columns")] qry["metrics"] = [fd.get("metric")] return qry @@ -1821,12 +1831,6 @@ class CountryMapViz(BaseViz): is_timeseries = False credits = "From bl.ocks.org By john-guerra" - def query_obj(self): - qry = super().query_obj() - qry["metrics"] = [self.form_data["metric"]] - qry["groupby"] = [self.form_data["entity"]] - return qry - def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data cols = [fd.get("entity")] @@ -1848,11 +1852,6 @@ class WorldMapViz(BaseViz): is_timeseries = False credits = 'datamaps on npm' - def query_obj(self): - qry = super().query_obj() - qry["groupby"] = [self.form_data["entity"]] - return qry - def get_data(self, df: pd.DataFrame) -> VizData: from superset.examples import countries @@ -1915,7 +1914,7 @@ def run_extra_queries(self): raise Exception( _("Invalid filter configuration, please select a column") ) - qry["groupby"] = [col] + qry["columns"] = [col] metric = flt.get("metric") qry["metrics"] = [metric] if metric else [] df = 
self.get_df_payload(query_obj=qry).get("df") @@ -1981,12 +1980,6 @@ class ParallelCoordinatesViz(BaseViz): ) is_timeseries = False - def query_obj(self): - d = super().query_obj() - fd = self.form_data - d["groupby"] = [fd.get("series")] - return d - def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @@ -2003,13 +1996,6 @@ class HeatmapViz(BaseViz): "bl.ocks.org" ) - def query_obj(self): - d = super().query_obj() - fd = self.form_data - d["metrics"] = [fd.get("metric")] - d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] - return d - def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None @@ -2227,21 +2213,6 @@ def get_metrics(self): self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] - def process_spatial_query_obj(self, key, group_by): - group_by.extend(self.get_spatial_columns(key)) - - def get_spatial_columns(self, key): - spatial = self.form_data.get(key) - if spatial is None: - raise ValueError(_("Bad spatial key")) - - if spatial.get("type") == "latlong": - return [spatial.get("lonCol"), spatial.get("latCol")] - elif spatial.get("type") == "delimited": - return [spatial.get("lonlatCol")] - elif spatial.get("type") == "geohash": - return [spatial.get("geohashCol")] - @staticmethod def parse_coordinates(s): if not s: @@ -2296,9 +2267,6 @@ def process_spatial_data_obj(self, key, df): def add_null_filters(self): fd = self.form_data spatial_columns = set() - for key in self.spatial_control_keys: - for column in self.get_spatial_columns(key): - spatial_columns.add(column) if fd.get("adhoc_filters") is None: fd["adhoc_filters"] = [] @@ -2321,9 +2289,6 @@ def query_obj(self): d = super().query_obj() gb = [] - for key in self.spatial_control_keys: - self.process_spatial_query_obj(key, gb) - if fd.get("dimension"): gb += [fd.get("dimension")] @@ -2332,11 +2297,7 @@ def query_obj(self): metrics = self.get_metrics() gb = list(set(gb)) if metrics: - d["groupby"] = gb d["metrics"] = metrics - d["columns"] = [] - else: - d["columns"] = gb return d def get_js_columns(self, d): @@ -2488,13 +2449,10 @@ def query_obj(self): self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") d = super().query_obj() self.metric = fd.get("metric") - line_col = fd.get("line_column") if d["metrics"]: self.has_metrics = True - d["groupby"].append(line_col) else: self.has_metrics = False - d["columns"].append(line_col) return d def get_properties(self, d): @@ -2573,7 +2531,6 @@ def query_obj(self): d = super().query_obj() d["columns"] += [self.form_data.get("geojson")] d["metrics"] = [] - d["groupby"] = [] return d def get_properties(self, d): From 7c77dc93b671ef52855dc678fd5d343427c798a7 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 27 Mar 2020 09:24:06 +0200 Subject: [PATCH 02/14] Fix query_object bug --- superset/common/query_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 2cb8728af77d..40bef9012a25 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -108,7 +108,7 @@ def __init__( if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras: self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) - self.columns = columns + self.columns = columns or [] if groupby: self.columns += groupby logger.warning( From 9489149ddff02c477b22bdbb5fa80ae2158a7236 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 27 Mar 2020 
12:03:11 +0200 Subject: [PATCH 03/14] Fix histogram --- superset/connectors/druid/models.py | 9 ++++----- superset/viz.py | 8 +++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 73ca6d5f2cf6..7627ea89f5dc 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1214,7 +1214,7 @@ def run_query( # druid order_direction = "descending" if order_desc else "ascending" - if columns: + if not metrics: columns.append("__time") del qry["post_aggregations"] del qry["aggregations"] @@ -1364,7 +1364,7 @@ def run_query( # druid return query_str @staticmethod - def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame: + def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: """Converting all GROUPBY columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns strings in the dataframe. This creates issues downstream related to having mixed types in the dataframe. Here we replace None with NULL_STRING and make the whole series a str instead of an object. """ - df[groupby_cols] = df[groupby_cols].fillna(NULL_STRING).astype("unicode") + df[columns] = df[columns].fillna(NULL_STRING).astype("unicode") return df def query(self, query_obj: Dict) -> QueryResult: @@ -1390,7 +1390,7 @@ def query(self, query_obj: Dict) -> QueryResult: df=df, query=query_str, duration=datetime.now() - qry_start_dttm ) - df = self.homogenize_types(df, query_obj.get("groupby", [])) + df = self.homogenize_types(df, query_obj.get("columns", [])) df.columns = [ DTTM_ALIAS if c in ("timestamp", "__time") else c for c in df.columns ] @@ -1405,7 +1405,6 @@ def query(self, query_obj: Dict) -> QueryResult: cols: List[str] = [] if DTTM_ALIAS in df.columns: cols += [DTTM_ALIAS] - cols += query_obj.get("groupby") or [] cols += query_obj.get("columns") or [] cols += query_obj.get("metrics") or [] diff --git a/superset/viz.py b/superset/viz.py index de71c53172e3..0e0da51e21b3 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -1596,19 +1596,21 @@ def labelify(self, keys, column): keys = (keys,) # removing undesirable characters labels = [re.sub(r"\W+", r"_", k) for k in keys] - if len(self.columns) > 1 or not self.groupby: + if len(self.columns) > 1: # Only show numeric column in label if there are many labels = [column] + labels return "__".join(labels) def get_data(self, df: pd.DataFrame) -> VizData: """Returns the chart data""" + groupby = self.form_data.get("groupby") + if df.empty: return None chart_data = [] - if len(self.groupby) > 0: - groups = df.groupby(self.groupby) + if groupby: + groups = df.groupby(groupby) else: groups = [((), df)] for keys, data in groups: From e9d83cc6927dbe04df88547d261c4eef5bb8fd09 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 27 Mar 2020 14:58:45 +0200 Subject: [PATCH 04/14] Remove groupby from legacy druid connector and fix first batch of unit tests --- superset/connectors/druid/models.py | 21 ++++++++++---------- superset/viz.py | 2 ++ tests/druid_func_tests.py | 30 ++++++++++++++--------------- tests/model_tests.py | 3 +-- tests/security_tests.py | 2 -- tests/sqla_models_tests.py | 4 ++-- 6 files changed, 30 insertions(+), 32 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 7627ea89f5dc..a28863893dfa 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1082,11 +1082,11 @@ def get_aggregations( return
aggregations def get_dimensions( - self, groupby: List[str], columns_dict: Dict[str, DruidColumn] + self, columns: List[str], columns_dict: Dict[str, DruidColumn] ) -> List[Union[str, Dict]]: dimensions = [] - groupby = [gb for gb in groupby if gb in columns_dict] - for column_name in groupby: + columns = [gb for gb in columns if gb in columns_dict] + for column_name in columns: col = columns_dict.get(column_name) dim_spec = col.dimension_spec if col else None dimensions.append(dim_spec or column_name) @@ -1137,7 +1137,7 @@ def sanitize_metric_object(metric: Dict) -> None: def run_query( # druid self, - groupby, + columns, metrics, granularity, from_dttm, @@ -1151,7 +1151,6 @@ def run_query( # druid inner_to_dttm=None, orderby=None, extras=None, - columns=None, phase=2, client=None, order_desc=True, @@ -1188,7 +1187,7 @@ def run_query( # druid ) # the dimensions list with dimensionSpecs expanded - dimensions = self.get_dimensions(groupby, columns_dict) + dimensions = self.get_dimensions(columns, columns_dict) extras = extras or {} qry = dict( datasource=self.datasource_name, @@ -1214,7 +1213,7 @@ def run_query( # druid order_direction = "descending" if order_desc else "ascending" - if not metrics: + if not metrics and "__time" not in columns: columns.append("__time") del qry["post_aggregations"] del qry["aggregations"] @@ -1224,11 +1223,11 @@ def run_query( # druid qry["granularity"] = "all" qry["limit"] = row_limit client.scan(**qry) - elif len(groupby) == 0 and not having_filters: + elif not columns and not having_filters: logger.info("Running timeseries query for no groupby values") del qry["dimensions"] client.timeseries(**qry) - elif not having_filters and len(groupby) == 1 and order_desc: + elif not having_filters and len(columns) == 1 and order_desc: dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) pre_qry = deepcopy(qry) @@ -1279,7 +1278,7 @@ def run_query( # druid qry["metric"] = list(qry["aggregations"].keys())[0] client.topn(**qry) logger.info("Phase 2 Complete") - elif len(groupby) > 0 or having_filters: + elif columns or having_filters: # If grouping on multiple fields or using a having filter # we have to force a groupby query logger.info("Running groupby query for dimensions [{}]".format(dimensions)) @@ -1365,7 +1364,7 @@ def run_query( # druid @staticmethod def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: - """Converting all GROUPBY columns to strings + """Converting all columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns strings in the dataframe. 
This creates issues downstream related diff --git a/superset/viz.py b/superset/viz.py index 0e0da51e21b3..3ac5923f7a45 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -1635,6 +1635,8 @@ class DistributionBarViz(DistributionPieViz): is_timeseries = False def query_obj(self): + # TODO: Refactor this plugin to either perform grouping or assume + # preaggregation of metrics ("numeric columns") d = super().query_obj() fd = self.form_data if not self.all_metrics: diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 1ed7057bc791..5895a1b386b0 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -407,7 +407,7 @@ def test_run_query_no_groupby(self): aggs = [] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = [] + columns = [] metrics = ["metric1"] ds.get_having_filters = Mock(return_value=[]) client.query_builder = Mock() client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, + columns, metrics, None, from_dttm, @@ -456,7 +456,7 @@ def test_run_query_with_adhoc_metric(self): all_metrics = [] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) - groupby = [] + columns = [] metrics = [ { "expressionType": "SIMPLE", @@ -472,7 +472,7 @@ def test_run_query_with_adhoc_metric(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, + columns, metrics, None, from_dttm, @@ -513,13 +513,13 @@ def test_run_query_single_groupby(self): aggs = ["metric1"] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = ["col1"] + columns = ["col1"] metrics = ["metric1"] ds.get_having_filters = Mock(return_value=[]) client.query_builder.last_query.query_dict = {"mock": 0} # client.topn is called twice ds.run_query( - groupby, + columns, metrics, None, from_dttm, @@ -543,7 +543,7 @@ def test_run_query_single_groupby(self): client = Mock() client.query_builder.last_query.query_dict = {"mock": 0} ds.run_query( - groupby, + columns, metrics, None, from_dttm, @@ -611,7 +611,7 @@ def test_run_query_multiple_groupby(self): aggs = [] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = ["col1", "col2"] + columns = ["col1", "col2"] metrics = ["metric1"] ds.get_having_filters = Mock(return_value=[]) client.query_builder = Mock() @@ -619,7 +619,7 @@ def test_run_query_multiple_groupby(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, + columns, metrics, None, from_dttm, @@ -1016,12 +1016,12 @@ def test_run_query_order_by_metrics(self): ds.columns = [dim1, dim2] ds.metrics = list(metrics_dict.values()) - groupby = ["dim1"] + columns = ["dim1"] metrics = ["count1"] granularity = "all" # get the counts of the top 5 'dim1's, order by 'sum1' ds.run_query( - groupby, + columns, metrics, granularity, from_dttm, @@ -1042,7 +1042,7 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 'dim1's, order by 'div1' ds.run_query( - groupby, + columns, metrics, granularity, from_dttm, @@ -1061,10 +1061,10 @@ def test_run_query_order_by_metrics(self): self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) self.assertEqual({"div1"}, set(post_aggregations.keys())) - groupby = ["dim1", "dim2"] + columns = ["dim1",
"dim2"] # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' ds.run_query( - groupby, + columns, metrics, granularity, from_dttm, @@ -1085,7 +1085,7 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' ds.run_query( - groupby, + columns, metrics, granularity, from_dttm, diff --git a/tests/model_tests.py b/tests/model_tests.py index 4a203a96475d..43a708f5cb37 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -234,11 +234,10 @@ def query_with_expr_helper(self, is_timeseries, inner_join=True): label="arbitrary", expressionType="SQL", sqlExpression="COUNT(1)" ) query_obj = dict( - groupby=[arbitrary_gby, "name"], + columns=[arbitrary_gby, "name"], metrics=[arbitrary_metric], filter=[], is_timeseries=is_timeseries, - columns=[], granularity="ds", from_dttm=None, to_dttm=None, diff --git a/tests/security_tests.py b/tests/security_tests.py index 7b8df262fea4..183a059d97ac 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -853,7 +853,6 @@ def test_rls_filter_alters_query(self): ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("birth_names") query_obj = dict( - groupby=[], metrics=[], filter=[], is_timeseries=False, @@ -872,7 +871,6 @@ def test_rls_filter_doesnt_alter_query(self): ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("birth_names") query_obj = dict( - groupby=[], metrics=[], filter=[], is_timeseries=False, diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 700c2e2bb94a..85296d09fe85 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -79,7 +79,7 @@ def test_has_extra_cache_keys(self): "granularity": None, "from_dttm": None, "to_dttm": None, - "groupby": ["user"], + "columns": ["user"], "metrics": [], "is_timeseries": False, "filter": [], @@ -100,7 +100,7 @@ def test_has_no_extra_cache_keys(self): "granularity": None, "from_dttm": None, "to_dttm": None, - "groupby": ["user"], + "columns": ["user"], "metrics": [], "is_timeseries": False, "filter": [], From 104dbaed56620816388323659487b18fa9b72f5d Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 27 Mar 2020 18:38:42 +0200 Subject: [PATCH 05/14] Deprecate some unnecessary tests + fix a few others --- superset/common/query_object.py | 1 - superset/viz.py | 5 +- tests/core_tests.py | 2 +- tests/model_tests.py | 3 +- tests/viz_tests.py | 114 +------------------------------- 5 files changed, 7 insertions(+), 118 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 40bef9012a25..389f8b3f100a 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -125,7 +125,6 @@ def to_dict(self) -> Dict[str, Any]: "from_dttm": self.from_dttm, "to_dttm": self.to_dttm, "is_timeseries": self.is_timeseries, - "groupby": self.groupby, "metrics": self.metrics, "row_limit": self.row_limit, "filter": self.filter, diff --git a/superset/viz.py b/superset/viz.py index 3ac5923f7a45..864f0159485a 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -86,7 +86,6 @@ "dimension", "entity", "groupby", - "order_by_cols", "series", "line_column", "js_columns", @@ -1651,6 +1650,8 @@ def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data metrics = self.metric_labels + # TODO: will require post transformation logic not currently available in + # /api/v1/query endpoint columns = fd.get("columns") or [] groupby = fd.get("groupby") or [] @@ -1658,7 +1659,7 @@ def get_data(self, df: 
pd.DataFrame) -> VizData: # so we substitute NULL_STRING for any nulls in the necessary columns df[self.columns] = df[self.columns].fillna(value=NULL_STRING) - row = df.groupby(self.columns).sum()[metrics[0]].copy() + row = df.groupby(groupby).sum()[metrics[0]].copy() row.sort_values(ascending=False, inplace=True) pt = df.pivot_table(index=groupby, columns=columns, values=metrics) if fd.get("contribution"): diff --git a/tests/core_tests.py b/tests/core_tests.py index 73711d9301d3..ce1668e5f525 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -119,7 +119,7 @@ def _get_query_context(self) -> Dict[str, Any]: "queries": [ { "granularity": "ds", - "groupby": ["name"], + "columns": ["name"], "metrics": [{"label": "sum__num"}], "filters": [], "row_limit": 100, diff --git a/tests/model_tests.py b/tests/model_tests.py index 43a708f5cb37..52378eb2bfe1 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -277,11 +277,10 @@ def test_query_with_expr_groupby(self): def test_sql_mutator(self): tbl = self.get_table_by_name("birth_names") query_obj = dict( - groupby=[], + columns=["name"], metrics=[], filter=[], is_timeseries=False, - columns=["name"], granularity=None, from_dttm=None, to_dttm=None, diff --git a/tests/viz_tests.py b/tests/viz_tests.py index ab10f4970665..3f48696f6e8f 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -373,21 +373,15 @@ def test_query_obj_throws_columns_and_metrics(self, super_query_obj): with self.assertRaises(Exception): test_viz.query_obj() - @patch("superset.viz.BaseViz.query_obj") - def test_query_obj_merges_all_columns(self, super_query_obj): + def test_query_obj_merges_all_columns(self): datasource = self.get_datasource_mock() form_data = { "all_columns": ["colA", "colB", "colC"], "order_by_cols": ['["colA", "colB"]', '["colC"]'], } - super_query_obj.return_value = { - "columns": ["colD", "colC"], - "groupby": ["colA", "colB"], - } test_viz = viz.TableViz(datasource, form_data) query_obj = test_viz.query_obj() self.assertEqual(form_data["all_columns"], query_obj["columns"]) - self.assertEqual([], query_obj["groupby"]) self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"]) @patch("superset.viz.BaseViz.query_obj") @@ -948,18 +942,6 @@ def test_query_obj_throws_metrics_and_groupby(self, super_query_obj): class BaseDeckGLVizTestCase(SupersetTestCase): - def test_get_metrics(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - result = test_viz_deckgl.get_metrics() - assert result == [form_data.get("size")] - - form_data = {} - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - result = test_viz_deckgl.get_metrics() - assert result == [] - def test_scatterviz_get_metrics(self): form_data = load_fixture("deck_path_form_data.json") datasource = self.get_datasource_mock() @@ -996,36 +978,6 @@ def test_get_properties(self): self.assertTrue("" in str(context.exception)) - def test_process_spatial_query_obj(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - mock_key = "spatial_key" - mock_gb = [] - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - - with self.assertRaises(ValueError) as context: - test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb) - - self.assertTrue("Bad spatial key" in str(context.exception)) - - test_form_data = { - "latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"}, - "delimited_key": {"type": 
"delimited", "lonlatCol": "lonlat"}, - "geohash_key": {"type": "geohash", "geohashCol": "geo"}, - } - - datasource = self.get_datasource_mock() - expected_results = { - "latlong_key": ["lon", "lat"], - "delimited_key": ["lonlat"], - "geohash_key": ["geo"], - } - for mock_key in ["latlong_key", "delimited_key", "geohash_key"]: - mock_gb = [] - test_viz_deckgl = viz.BaseDeckGLViz(datasource, test_form_data) - test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb) - assert expected_results.get(mock_key) == mock_gb - def test_geojson_query_obj(self): form_data = load_fixture("deck_geojson_form_data.json") datasource = self.get_datasource_mock() @@ -1033,8 +985,7 @@ def test_geojson_query_obj(self): results = test_viz_deckgl.query_obj() assert results["metrics"] == [] - assert results["groupby"] == [] - assert results["columns"] == ["test_col"] + assert results["columns"] == ["color", "test_col"] def test_parse_coordinates(self): form_data = load_fixture("deck_path_form_data.json") @@ -1062,67 +1013,6 @@ def test_parse_coordinates_raises(self): with self.assertRaises(SpatialException): test_viz_deckgl.parse_coordinates("fldkjsalkj,fdlaskjfjadlksj") - @patch("superset.utils.core.uuid.uuid4") - def test_filter_nulls(self, mock_uuid4): - mock_uuid4.return_value = uuid.UUID("12345678123456781234567812345678") - test_form_data = { - "latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"}, - "delimited_key": {"type": "delimited", "lonlatCol": "lonlat"}, - "geohash_key": {"type": "geohash", "geohashCol": "geo"}, - } - - datasource = self.get_datasource_mock() - expected_results = { - "latlong_key": [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "filterOptionName": "12345678-1234-5678-1234-567812345678", - "comparator": "", - "operator": "IS NOT NULL", - "subject": "lat", - "isExtra": False, - }, - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "filterOptionName": "12345678-1234-5678-1234-567812345678", - "comparator": "", - "operator": "IS NOT NULL", - "subject": "lon", - "isExtra": False, - }, - ], - "delimited_key": [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "filterOptionName": "12345678-1234-5678-1234-567812345678", - "comparator": "", - "operator": "IS NOT NULL", - "subject": "lonlat", - "isExtra": False, - } - ], - "geohash_key": [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "filterOptionName": "12345678-1234-5678-1234-567812345678", - "comparator": "", - "operator": "IS NOT NULL", - "subject": "geo", - "isExtra": False, - } - ], - } - for mock_key in ["latlong_key", "delimited_key", "geohash_key"]: - test_viz_deckgl = viz.BaseDeckGLViz(datasource, test_form_data.copy()) - test_viz_deckgl.spatial_control_keys = [mock_key] - test_viz_deckgl.add_null_filters() - adhoc_filters = test_viz_deckgl.form_data["adhoc_filters"] - assert expected_results.get(mock_key) == adhoc_filters - class TimeSeriesVizTestCase(SupersetTestCase): def test_timeseries_unicode_data(self): From 1f70bab50181439515c4d6e00effe588d5ce2560 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 1 Apr 2020 00:34:56 +0300 Subject: [PATCH 06/14] Address comments --- superset/common/query_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 389f8b3f100a..aac187b1bbb4 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -116,7 +116,6 @@ def __init__( f"pass all selectables via the columns field" ) - self.columns = columns or [] self.orderby = orderby or [] 
def to_dict(self) -> Dict[str, Any]: From cf75829641bb6526282ca6618e4300b90d550287 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 3 Apr 2020 14:32:50 +0300 Subject: [PATCH 07/14] hide SIP-38 changes behind feature flag --- superset/config.py | 1 + superset/connectors/druid/models.py | 59 +- superset/connectors/sqla/models.py | 53 +- superset/models/slice.py | 6 +- superset/views/utils.py | 8 +- superset/viz.py | 174 +- superset/viz_sip38.py | 2857 +++++++++++++++++++++++++++ 7 files changed, 3070 insertions(+), 88 deletions(-) create mode 100644 superset/viz_sip38.py diff --git a/superset/config.py b/superset/config.py index 28beb276a31a..91866fbfe2d0 100644 --- a/superset/config.py +++ b/superset/config.py @@ -287,6 +287,7 @@ def _try_json_readsha(filepath, length): # pylint: disable=unused-argument "PRESTO_EXPAND_DATA": False, "REDUCE_DASHBOARD_BOOTSTRAP_PAYLOAD": False, "SHARE_QUERIES_VIA_KV_STORE": False, + "SIP_38_VIZ_REARCHITECTURE": False, "TAGGING_SYSTEM": False, "SQLLAB_BACKEND_PERSISTENCE": False, "LIST_VIEWS_NEW_UI": False, diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index a28863893dfa..4f36987a4d09 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -48,7 +48,7 @@ from sqlalchemy.orm import backref, relationship, Session from sqlalchemy_utils import EncryptedType -from superset import conf, db, security_manager +from superset import conf, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.exceptions import SupersetException @@ -84,6 +84,7 @@ except ImportError: pass +IS_SIP_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") DRUID_TZ = conf.get("DRUID_TZ") POST_AGG_TYPE = "postagg" metadata = Model.metadata # pylint: disable=no-member @@ -1085,7 +1086,7 @@ def get_dimensions( self, columns: List[str], columns_dict: Dict[str, DruidColumn] ) -> List[Union[str, Dict]]: dimensions = [] - columns = [gb for gb in columns if gb in columns_dict] + columns = [col for col in columns if col in columns_dict] for column_name in columns: col = columns_dict.get(column_name) dim_spec = col.dimension_spec if col else None dimensions.append(dim_spec or column_name) @@ -1137,7 +1138,7 @@ def sanitize_metric_object(metric: Dict) -> None: def run_query( # druid self, - columns, metrics, granularity, from_dttm, to_dttm, + columns=None, + groupby=None, filter=None, is_timeseries=True, timeseries_limit=None, @@ -1187,7 +1189,7 @@ def run_query( # druid ) # the dimensions list with dimensionSpecs expanded - dimensions = self.get_dimensions(columns, columns_dict) + + if IS_SIP_38: + dimensions = self.get_dimensions(columns, columns_dict) + else: + dimensions = self.get_dimensions(groupby, columns_dict) + extras = extras or {} qry = dict( datasource=self.datasource_name, @@ -1213,7 +1220,13 @@ def run_query( # druid order_direction = "descending" if order_desc else "ascending" - if not metrics and "__time" not in columns: + if ( + IS_SIP_38 + and not metrics + and "__time" not in columns + or not IS_SIP_38 + and columns + ): columns.append("__time") del qry["post_aggregations"] del qry["aggregations"] @@ -1223,11 +1236,26 @@ def run_query( # druid qry["granularity"] = "all" qry["limit"] = row_limit client.scan(**qry) - elif not columns and not having_filters: + elif ( + IS_SIP_38 + and not columns + or not IS_SIP_38 + and len(groupby) == 0 + and not having_filters + ): logger.info("Running timeseries query for no groupby
values") del qry["dimensions"] client.timeseries(**qry) - elif not having_filters and len(columns) == 1 and order_desc: + elif ( + IS_SIP_38 + and not having_filters + and len(columns) == 1 + and order_desc + or not IS_SIP_38 + and not having_filters + and len(groupby) == 1 + and order_desc + ): dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) pre_qry = deepcopy(qry) @@ -1278,7 +1306,13 @@ def run_query( # druid qry["metric"] = list(qry["aggregations"].keys())[0] client.topn(**qry) logger.info("Phase 2 Complete") - elif columns or having_filters: + elif ( + having_filters + or IS_SIP_38 + and columns + or not IS_SIP_38 + and len(groupby) > 0 + ): # If grouping on multiple fields or using a having filter # we have to force a groupby query logger.info("Running groupby query for dimensions [{}]".format(dimensions)) @@ -1364,7 +1398,7 @@ def run_query( # druid @staticmethod def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: - """Converting all columns to strings + """Converting all GROUPBY columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns strings in the dataframe. This creates issues downstream related @@ -1389,7 +1423,9 @@ def query(self, query_obj: Dict) -> QueryResult: df=df, query=query_str, duration=datetime.now() - qry_start_dttm ) - df = self.homogenize_types(df, query_obj.get("columns", [])) + df = self.homogenize_types( + df, query_obj.get("columns" if IS_SIP_38 else "groupby", []) + ) df.columns = [ DTTM_ALIAS if c in ("timestamp", "__time") else c for c in df.columns ] @@ -1404,6 +1440,9 @@ def query(self, query_obj: Dict) -> QueryResult: cols: List[str] = [] if DTTM_ALIAS in df.columns: cols += [DTTM_ALIAS] + + if not IS_SIP_38: + cols += query_obj.get("groupby") or [] cols += query_obj.get("columns") or [] cols += query_obj.get("metrics") or [] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 93db4b50f4ce..8303d2f61c02 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -49,7 +49,7 @@ from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.expression import Label, Select, TextAsFrom -from superset import app, db, security_manager +from superset import app, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression @@ -65,6 +65,9 @@ logger = logging.getLogger(__name__) +IS_SIP_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") + + class SqlaQuery(NamedTuple): extra_cache_keys: List[Any] labels_expected: List[str] @@ -700,6 +703,8 @@ def get_sqla_query( # sqla granularity, from_dttm, to_dttm, + columns=None, + groupby=None, filter=None, is_timeseries=True, timeseries_limit=15, @@ -709,12 +714,12 @@ def get_sqla_query( # sqla inner_to_dttm=None, orderby=None, extras=None, - columns=None, order_desc=True, ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { "from_dttm": from_dttm, + "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "to_dttm": to_dttm, @@ -747,7 +752,15 @@ def get_sqla_query( # sqla "and is required by this type of chart" ) ) - if not metrics and not columns: + if ( + IS_SIP_38 + and not metrics + and not columns + or not IS_SIP_38 + and not groupby + and not metrics + and not columns + ): raise 
Exception(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] for m in metrics: @@ -766,9 +779,12 @@ def get_sqla_query( # sqla select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() - if metrics and columns: + if IS_SIP_38 and metrics and columns or not IS_SIP_38 and groupby: # dedup columns while preserving order - groupby = list(dict.fromkeys(columns)) + if IS_SIP_38: + groupby = list(dict.fromkeys(columns)) + else: + groupby = list(dict.fromkeys(groupby)) select_exprs = [] for s in groupby: @@ -827,7 +843,7 @@ def get_sqla_query( # sqla tbl = self.get_from_clause(template_processor) - if metrics: + if IS_SIP_38 and metrics or not IS_SIP_38 and not columns: qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] @@ -890,7 +906,14 @@ def get_sqla_query( # sqla qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) - if not orderby and metrics: + if ( + IS_SIP_38 + and not orderby + and metrics + or not IS_SIP_38 + and not orderby + and not columns + ): orderby = [(main_metric_expr, not order_desc)] # To ensure correct handling of the ORDER BY labeling we need to reference the @@ -912,7 +935,18 @@ def get_sqla_query( # sqla if row_limit: qry = qry.limit(row_limit) - if is_timeseries and timeseries_limit and columns and not time_groupby_inline: + if ( + IS_SIP_38 + and is_timeseries + and timeseries_limit + and columns + and not time_groupby_inline + or not IS_SIP_38 + and is_timeseries + and timeseries_limit + and groupby + and not time_groupby_inline + ): if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. vertica, @@ -980,6 +1014,9 @@ def get_sqla_query( # sqla "columns": columns, "order_desc": True, } + if not IS_SIP_38: + prequery_obj["groupby"] = groupby + result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ diff --git a/superset/models/slice.py b/superset/models/slice.py index 6cfadab07a95..ad0ae5dfbfa4 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -31,7 +31,11 @@ from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.tags import ChartUpdater from superset.utils import core as utils -from superset.viz import BaseViz, viz_types + +if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + from superset.viz_sip38 import BaseViz, viz_types +else: + from superset.viz import BaseViz, viz_types if TYPE_CHECKING: # pylint: disable=unused-import diff --git a/superset/views/utils.py b/superset/views/utils.py index 0f8629df87b8..780dd99af005 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -23,7 +23,7 @@ from flask import request import superset.models.core as models -from superset import app, db, viz +from superset import app, db, is_feature_enabled from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import SupersetException from superset.legacy import update_time_range @@ -31,6 +31,12 @@ from superset.models.slice import Slice from superset.utils.core import QueryStatus, TimeRangeEndpoint +if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + from superset import viz_sip38 as viz +else: + from superset import viz + + FORM_DATA_KEY_BLACKLIST: List[str] = [] if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]: FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"] diff --git a/superset/viz.py b/superset/viz.py index 
864f0159485a..848e3b35c9e4 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -78,21 +78,6 @@ "size", ] -COLUMN_FORM_DATA_PARAMS = [ - "all_columns", - "all_columns_x", - "all_columns_y", - "columns", - "dimension", - "entity", - "groupby", - "series", - "line_column", - "js_columns", -] - -SPATIAL_COLUMN_FORM_DATA_PARAMS = ["spatial", "start_spatial", "end_spatial"] - class BaseViz: @@ -122,34 +107,7 @@ def __init__( self.query = "" self.token = self.form_data.get("token", "token_" + uuid.uuid4().hex[:8]) - # merge all selectable columns into `columns` property - self.columns: List[str] = [] - for key in COLUMN_FORM_DATA_PARAMS: - logger.warning( - f"The form field %s is deprecated. Viz plugins should " - f"pass all selectables via the columns field", - key, - ) - value = self.form_data.get(key) or [] - value_list = value if isinstance(value, list) else [value] - self.columns += value_list - - for key in SPATIAL_COLUMN_FORM_DATA_PARAMS: - spatial = self.form_data.get(key) - if not isinstance(spatial, dict): - continue - logger.warning( - f"The form field %s is deprecated. Viz plugins should " - f"pass all selectables via the columns field", - key, - ) - if spatial.get("type") == "latlong": - self.columns += [spatial["lonCol"], spatial["latCol"]] - elif spatial.get("type") == "delimited": - self.columns.append(spatial["lonlatCol"]) - elif spatial.get("type") == "geohash": - self.columns.append(spatial["geohashCol"]) - + self.groupby = self.form_data.get("groupby") or [] self.time_shift = timedelta() self.status: Optional[str] = None @@ -246,6 +204,7 @@ def get_samples(self): query_obj = self.query_obj() query_obj.update( { + "groupby": [], "metrics": [], "row_limit": 1000, "columns": [o.column_name for o in self.datasource.columns], @@ -332,12 +291,14 @@ def query_obj(self) -> Dict[str, Any]: """Building a query object""" form_data = self.form_data self.process_query_filters() + gb = form_data.get("groupby") or [] metrics = self.all_metrics or [] - columns = self.columns + columns = form_data.get("columns") or [] + groupby = list(set(gb + columns)) is_timeseries = self.is_timeseries - if DTTM_ALIAS in columns: - columns.remove(DTTM_ALIAS) + if DTTM_ALIAS in groupby: + groupby.remove(DTTM_ALIAS) is_timeseries = True granularity = form_data.get("granularity") or form_data.get("granularity_sqla") @@ -381,7 +342,7 @@ def query_obj(self) -> Dict[str, Any]: "from_dttm": from_dttm, "to_dttm": to_dttm, "is_timeseries": is_timeseries, - "columns": columns, + "groupby": groupby, "metrics": metrics, "row_limit": row_limit, "filter": self.form_data.get("filters", []), @@ -606,6 +567,8 @@ def query_obj(self): sort_by = fd.get("timeseries_limit_metric") if fd.get("all_columns"): + d["columns"] = fd.get("all_columns") + d["groupby"] = [] order_by_cols = fd.get("order_by_cols") or [] d["orderby"] = [json.loads(t) for t in order_by_cols] elif sort_by: @@ -853,6 +816,11 @@ class WordCloudViz(BaseViz): verbose_name = _("Word Cloud") is_timeseries = False + def query_obj(self): + d = super().query_obj() + d["groupby"] = [self.form_data.get("series")] + return d + class TreemapViz(BaseViz): @@ -1056,6 +1024,12 @@ class BubbleViz(NVD3Viz): def query_obj(self): form_data = self.form_data d = super().query_obj() + d["groupby"] = [form_data.get("entity")] + if form_data.get("series"): + d["groupby"].append(form_data.get("series")) + + # dedup groupby if it happens to be the same + d["groupby"] = list(dict.fromkeys(d["groupby"])) self.x_metric = form_data.get("x") self.y_metric = form_data.get("y") @@ -1266,7 
+1240,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: if aggregate: df = df.pivot_table( index=DTTM_ALIAS, - columns=self.columns, + columns=fd.get("groupby"), values=self.metric_labels, fill_value=0, aggfunc=sum, @@ -1274,7 +1248,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: else: df = df.pivot_table( index=DTTM_ALIAS, - columns=self.columns, + columns=fd.get("groupby"), values=self.metric_labels, fill_value=self.pivot_fill_value, ) @@ -1567,7 +1541,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None metric = self.metric_labels[0] - df = df.pivot_table(index=self.columns, values=[metric]) + df = df.pivot_table(index=self.groupby, values=[metric]) df.sort_values(by=metric, ascending=False, inplace=True) df = df.reset_index() df.columns = ["x", "y"] @@ -1586,8 +1560,13 @@ def query_obj(self): """Returns the query object for this visualization""" d = super().query_obj() d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) - if not self.form_data.get("all_columns_x"): + numeric_columns = self.form_data.get("all_columns_x") + if numeric_columns is None: raise Exception(_("Must have at least one numeric column specified")) + self.columns = numeric_columns + d["columns"] = numeric_columns + self.groupby + # override groupby entry to avoid aggregation + d["groupby"] = [] return d def labelify(self, keys, column): @@ -1595,21 +1574,19 @@ def labelify(self, keys, column): keys = (keys,) # removing undesirable characters labels = [re.sub(r"\W+", r"_", k) for k in keys] - if len(self.columns) > 1: + if len(self.columns) > 1 or not self.groupby: # Only show numeric column in label if there are many labels = [column] + labels return "__".join(labels) def get_data(self, df: pd.DataFrame) -> VizData: """Returns the chart data""" - groupby = self.form_data.get("groupby") - if df.empty: return None chart_data = [] - if groupby: - groups = df.groupby(groupby) + if len(self.groupby) > 0: + groups = df.groupby(self.groupby) else: groups = [((), df)] for keys, data in groups: @@ -1634,13 +1611,15 @@ class DistributionBarViz(DistributionPieViz): is_timeseries = False def query_obj(self): - # TODO: Refactor this plugin to either perform grouping or assume - # preaggregation of metrics ("numeric columns") d = super().query_obj() fd = self.form_data + if len(d["groupby"]) < len(fd.get("groupby") or []) + len( + fd.get("columns") or [] + ): + raise Exception(_("Can't have overlap between Series and Breakdowns")) - if not self.all_metrics: + if not fd.get("metrics"): raise Exception(_("Pick at least one metric")) - if not self.columns: + if not fd.get("groupby"): raise Exception(_("Pick at least one field for [Series]")) return d @@ -1650,25 +1629,23 @@ def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data metrics = self.metric_labels - # TODO: will require post transformation logic not currently available in - # /api/v1/query endpoint columns = fd.get("columns") or [] - groupby = fd.get("groupby") or [] # pandas will throw away nulls when grouping/pivoting, # so we substitute NULL_STRING for any nulls in the necessary columns - df[self.columns] = df[self.columns].fillna(value=NULL_STRING) + filled_cols = self.groupby + columns + df[filled_cols] = df[filled_cols].fillna(value=NULL_STRING) - row = df.groupby(self.columns).sum()[metrics[0]].copy() + row = df.groupby(self.groupby).sum()[metrics[0]].copy() row.sort_values(ascending=False, inplace=True) - pt =
df.pivot_table(index=groupby, columns=columns, values=metrics) + pt = df.pivot_table(index=self.groupby, columns=columns, values=metrics) if fd.get("contribution"): pt = pt.T pt = (pt / pt.sum()).T pt = pt.reindex(row.index) chart_data = [] for name, ys in pt.items(): - if pt[name].dtype.kind not in "biufc" or name in groupby: + if pt[name].dtype.kind not in "biufc" or name in self.groupby: continue if isinstance(name, str): series_title = name @@ -1737,6 +1714,13 @@ class SankeyViz(BaseViz): is_timeseries = False credits = 'd3-sankey on npm' + def query_obj(self): + qry = super().query_obj() + if len(qry["groupby"]) != 2: + raise Exception(_("Pick exactly 2 columns as [Source / Target]")) + qry["metrics"] = [self.form_data["metric"]] + return qry + def get_data(self, df: pd.DataFrame) -> VizData: df.columns = ["source", "target", "value"] df["source"] = df["source"].astype(str) @@ -1807,6 +1791,7 @@ class ChordViz(BaseViz): def query_obj(self): qry = super().query_obj() fd = self.form_data + qry["groupby"] = [fd.get("groupby"), fd.get("columns")] qry["metrics"] = [fd.get("metric")] return qry @@ -1836,6 +1821,12 @@ class CountryMapViz(BaseViz): is_timeseries = False credits = "From bl.ocks.org By john-guerra" + def query_obj(self): + qry = super().query_obj() + qry["metrics"] = [self.form_data["metric"]] + qry["groupby"] = [self.form_data["entity"]] + return qry + def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data cols = [fd.get("entity")] @@ -1857,6 +1848,11 @@ class WorldMapViz(BaseViz): is_timeseries = False credits = 'datamaps on npm' + def query_obj(self): + qry = super().query_obj() + qry["groupby"] = [self.form_data["entity"]] + return qry + def get_data(self, df: pd.DataFrame) -> VizData: from superset.examples import countries @@ -1919,7 +1915,7 @@ def run_extra_queries(self): raise Exception( _("Invalid filter configuration, please select a column") ) - qry["columns"] = [col] + qry["groupby"] = [col] metric = flt.get("metric") qry["metrics"] = [metric] if metric else [] df = self.get_df_payload(query_obj=qry).get("df") @@ -1985,6 +1981,12 @@ class ParallelCoordinatesViz(BaseViz): ) is_timeseries = False + def query_obj(self): + d = super().query_obj() + fd = self.form_data + d["groupby"] = [fd.get("series")] + return d + def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @@ -2001,6 +2003,13 @@ class HeatmapViz(BaseViz): "bl.ocks.org" ) + def query_obj(self): + d = super().query_obj() + fd = self.form_data + d["metrics"] = [fd.get("metric")] + d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] + return d + def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None @@ -2218,6 +2227,21 @@ def get_metrics(self): self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] + def process_spatial_query_obj(self, key, group_by): + group_by.extend(self.get_spatial_columns(key)) + + def get_spatial_columns(self, key): + spatial = self.form_data.get(key) + if spatial is None: + raise ValueError(_("Bad spatial key")) + + if spatial.get("type") == "latlong": + return [spatial.get("lonCol"), spatial.get("latCol")] + elif spatial.get("type") == "delimited": + return [spatial.get("lonlatCol")] + elif spatial.get("type") == "geohash": + return [spatial.get("geohashCol")] + @staticmethod def parse_coordinates(s): if not s: @@ -2272,6 +2296,9 @@ def process_spatial_data_obj(self, key, df): def add_null_filters(self): fd = self.form_data spatial_columns = set() + for key in 
self.spatial_control_keys: + for column in self.get_spatial_columns(key): + spatial_columns.add(column) if fd.get("adhoc_filters") is None: fd["adhoc_filters"] = [] @@ -2294,6 +2321,9 @@ def query_obj(self): d = super().query_obj() gb = [] + for key in self.spatial_control_keys: + self.process_spatial_query_obj(key, gb) + if fd.get("dimension"): gb += [fd.get("dimension")] @@ -2302,7 +2332,11 @@ def query_obj(self): metrics = self.get_metrics() gb = list(set(gb)) if metrics: + d["groupby"] = gb d["metrics"] = metrics + d["columns"] = [] + else: + d["columns"] = gb return d def get_js_columns(self, d): @@ -2454,10 +2488,13 @@ def query_obj(self): self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") d = super().query_obj() self.metric = fd.get("metric") + line_col = fd.get("line_column") if d["metrics"]: self.has_metrics = True + d["groupby"].append(line_col) else: self.has_metrics = False + d["columns"].append(line_col) return d def get_properties(self, d): @@ -2536,6 +2573,7 @@ def query_obj(self): d = super().query_obj() d["columns"] += [self.form_data.get("geojson")] d["metrics"] = [] + d["groupby"] = [] return d def get_properties(self, d): diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py new file mode 100644 index 000000000000..9556aaa1550c --- /dev/null +++ b/superset/viz_sip38.py @@ -0,0 +1,2857 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=C,R,W +"""This module contains the 'Viz' objects + +These objects represent the backend of all the visualizations that +Superset can render. 
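+
+This module is a transitional, SIP-38 oriented copy of ``viz.py`` in which the
+deprecated ``groupby`` form field is folded into the ``columns`` field of the
+query object.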
+""" +import copy +import hashlib +import inspect +import logging +import math +import pickle as pkl +import re +import uuid +from collections import defaultdict, OrderedDict +from datetime import datetime, timedelta +from itertools import product +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING + +import geohash +import numpy as np +import pandas as pd +import polyline +import simplejson as json +from dateutil import relativedelta as rdelta +from flask import request +from flask_babel import lazy_gettext as _ +from geopy.point import Point +from markdown import markdown +from pandas.tseries.frequencies import to_offset + +from superset import app, cache, get_manifest_files, security_manager +from superset.constants import NULL_STRING +from superset.exceptions import NullValueException, SpatialException +from superset.models.helpers import QueryResult +from superset.typing import VizData +from superset.utils import core as utils +from superset.utils.core import ( + DTTM_ALIAS, + JS_MAX_INTEGER, + merge_extra_filters, + to_adhoc, +) + +if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource + +config = app.config +stats_logger = config["STATS_LOGGER"] +relative_start = config["DEFAULT_RELATIVE_START_TIME"] +relative_end = config["DEFAULT_RELATIVE_END_TIME"] +logger = logging.getLogger(__name__) + +METRIC_KEYS = [ + "metric", + "metrics", + "percent_metrics", + "metric_2", + "secondary_metric", + "x", + "y", + "size", +] + +COLUMN_FORM_DATA_PARAMS = [ + "all_columns", + "all_columns_x", + "all_columns_y", + "columns", + "dimension", + "entity", + "groupby", + "series", + "line_column", + "js_columns", +] + +SPATIAL_COLUMN_FORM_DATA_PARAMS = ["spatial", "start_spatial", "end_spatial"] + + +class BaseViz: + + """All visualizations derive this base class""" + + viz_type: Optional[str] = None + verbose_name = "Base Viz" + credits = "" + is_timeseries = False + cache_type = "df" + enforce_numerical_metrics = True + + def __init__( + self, + datasource: "BaseDatasource", + form_data: Dict[str, Any], + force: bool = False, + ): + if not datasource: + raise Exception(_("Viz is missing a datasource")) + + self.datasource = datasource + self.request = request + self.viz_type = form_data.get("viz_type") + self.form_data = form_data + + self.query = "" + self.token = self.form_data.get("token", "token_" + uuid.uuid4().hex[:8]) + + # merge all selectable columns into `columns` property + self.columns: List[str] = [] + for key in COLUMN_FORM_DATA_PARAMS: + logger.warning( + f"The form field %s is deprecated. Viz plugins should " + f"pass all selectables via the columns field", + key, + ) + value = self.form_data.get(key) or [] + value_list = value if isinstance(value, list) else [value] + self.columns += value_list + + for key in SPATIAL_COLUMN_FORM_DATA_PARAMS: + spatial = self.form_data.get(key) + if not isinstance(spatial, dict): + continue + logger.warning( + f"The form field %s is deprecated. 
Viz plugins should "
+                f"pass all selectables via the columns field",
+                key,
+            )
+            if spatial.get("type") == "latlong":
+                self.columns += [spatial["lonCol"], spatial["latCol"]]
+            elif spatial.get("type") == "delimited":
+                self.columns.append(spatial["lonlatCol"])
+            elif spatial.get("type") == "geohash":
+                self.columns.append(spatial["geohashCol"])
+
+        self.time_shift = timedelta()
+
+        self.status: Optional[str] = None
+        self.error_msg = ""
+        self.results: Optional[QueryResult] = None
+        self.error_message: Optional[str] = None
+        self.force = force
+
+        # Keeping track of whether some data came from cache
+        # this is useful to trigger the <CachedLabel /> when
+        # in the cases where visualizations have many queries
+        # (FilterBox for instance)
+        self._any_cache_key: Optional[str] = None
+        self._any_cached_dttm: Optional[str] = None
+        self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
+
+        self.process_metrics()
+
+    def process_metrics(self):
+        # metrics in TableViz are order sensitive, so metric_dict should be
+        # an OrderedDict
+        self.metric_dict = OrderedDict()
+        fd = self.form_data
+        for mkey in METRIC_KEYS:
+            val = fd.get(mkey)
+            if val:
+                if not isinstance(val, list):
+                    val = [val]
+                for o in val:
+                    label = utils.get_metric_name(o)
+                    self.metric_dict[label] = o
+
+        # Cast to list needed to return a serializable object in py3
+        self.all_metrics = list(self.metric_dict.values())
+        self.metric_labels = list(self.metric_dict.keys())
+
+    @staticmethod
+    def handle_js_int_overflow(data):
+        for d in data.get("records", dict()):
+            for k, v in list(d.items()):
+                if isinstance(v, int):
+                    # if an int is too big for JavaScript to handle
+                    # convert it to a string
+                    if abs(v) > JS_MAX_INTEGER:
+                        d[k] = str(v)
+        return data
+
+    def run_extra_queries(self):
+        """Lifecycle method to use when more than one query is needed
+
+        In rare-ish cases, a visualization may need to execute multiple
+        queries. That is the case for FilterBox or for time comparison
+        in Line chart for instance.
+
+        In those cases, we need to make sure these queries run before the
+        main `get_payload` method gets called, so that the overall caching
+        metadata can be right. The way it works here is that if any of
+        the previous `get_df_payload` calls hit the cache, the main
+        payload's metadata will reflect that.
+
+        The multi-query support may need more work to become a first class
+        use case in the framework, and for the UI to reflect the subtleties
+        (show that only some of the queries were served from cache for
+        instance). In the meantime, since multi-query is rare, we treat
+        it with a bit of a hack. Note that the hack became necessary
+        when moving from caching the visualization's data itself, to caching
+        the underlying query(ies).
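+
+        As a concrete example from this module, ``NVD3TimeSeriesViz``
+        overrides this method to issue one
+        ``get_df_payload(query_object, time_compare=option)`` call per
+        configured time shift, so the shifted queries are executed (and
+        cached) before the main payload is assembled.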
+        """
+        pass
+
+    def apply_rolling(self, df):
+        fd = self.form_data
+        rolling_type = fd.get("rolling_type")
+        rolling_periods = int(fd.get("rolling_periods") or 0)
+        min_periods = int(fd.get("min_periods") or 0)
+
+        if rolling_type in ("mean", "std", "sum") and rolling_periods:
+            kwargs = dict(window=rolling_periods, min_periods=min_periods)
+            if rolling_type == "mean":
+                df = df.rolling(**kwargs).mean()
+            elif rolling_type == "std":
+                df = df.rolling(**kwargs).std()
+            elif rolling_type == "sum":
+                df = df.rolling(**kwargs).sum()
+        elif rolling_type == "cumsum":
+            df = df.cumsum()
+        if min_periods:
+            df = df[min_periods:]
+        return df
+
+    def get_samples(self):
+        query_obj = self.query_obj()
+        query_obj.update(
+            {
+                "metrics": [],
+                "row_limit": 1000,
+                "columns": [o.column_name for o in self.datasource.columns],
+            }
+        )
+        df = self.get_df(query_obj)
+        return df.to_dict(orient="records")
+
+    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+        """Returns a pandas dataframe based on the query object"""
+        if not query_obj:
+            query_obj = self.query_obj()
+        if not query_obj:
+            return pd.DataFrame()
+
+        self.error_msg = ""
+
+        timestamp_format = None
+        if self.datasource.type == "table":
+            granularity_col = self.datasource.get_column(query_obj["granularity"])
+            if granularity_col:
+                timestamp_format = granularity_col.python_date_format
+
+        # The datasource here can be a different backend, but the interface
+        # is common
+        self.results = self.datasource.query(query_obj)
+        self.query = self.results.query
+        self.status = self.results.status
+        self.error_message = self.results.error_message
+
+        df = self.results.df
+        # Transform the timestamp we received from the database to a
+        # pandas-supported datetime format. If no python_date_format is
+        # specified, the pattern is assumed to be the default ISO date
+        # format. If the datetime format is unix, the parsing uses the
+        # corresponding epoch logic.
+        if not df.empty:
+            if DTTM_ALIAS in df.columns:
+                if timestamp_format in ("epoch_s", "epoch_ms"):
+                    # Column is formatted as an epoch timestamp (seconds or
+                    # milliseconds since the Unix epoch).
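+                    # Note: the probe below handles both int- and
+                    # string-typed epoch values. With
+                    # timestamp_format == "epoch_s", an integral value
+                    # converts via
+                    #     pd.to_datetime(1584662400, unit="s", origin="unix")
+                    #     -> Timestamp("2020-03-20 00:00:00")
+                    # while non-integral values fall back to
+                    # dttm_col.apply(pd.Timestamp).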
+ dttm_col = df[DTTM_ALIAS] + one_ts_val = dttm_col[0] + + # convert time column to pandas Timestamp, but different + # ways to convert depending on string or int types + try: + int(one_ts_val) + is_integral = True + except (ValueError, TypeError): + is_integral = False + if is_integral: + unit = "s" if timestamp_format == "epoch_s" else "ms" + df[DTTM_ALIAS] = pd.to_datetime( + dttm_col, utc=False, unit=unit, origin="unix" + ) + else: + df[DTTM_ALIAS] = dttm_col.apply(pd.Timestamp) + else: + df[DTTM_ALIAS] = pd.to_datetime( + df[DTTM_ALIAS], utc=False, format=timestamp_format + ) + if self.datasource.offset: + df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset) + df[DTTM_ALIAS] += self.time_shift + + if self.enforce_numerical_metrics: + self.df_metrics_to_num(df) + + df.replace([np.inf, -np.inf], np.nan, inplace=True) + return df + + def df_metrics_to_num(self, df): + """Converting metrics to numeric when pandas.read_sql cannot""" + metrics = self.metric_labels + for col, dtype in df.dtypes.items(): + if dtype.type == np.object_ and col in metrics: + df[col] = pd.to_numeric(df[col], errors="coerce") + + def process_query_filters(self): + utils.convert_legacy_filters_into_adhoc(self.form_data) + merge_extra_filters(self.form_data) + utils.split_adhoc_filters_into_base_filters(self.form_data) + + def query_obj(self) -> Dict[str, Any]: + """Building a query object""" + form_data = self.form_data + self.process_query_filters() + metrics = self.all_metrics or [] + columns = self.columns + + is_timeseries = self.is_timeseries + if DTTM_ALIAS in columns: + columns.remove(DTTM_ALIAS) + is_timeseries = True + + granularity = form_data.get("granularity") or form_data.get("granularity_sqla") + limit = int(form_data.get("limit") or 0) + timeseries_limit_metric = form_data.get("timeseries_limit_metric") + row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"]) + + # default order direction + order_desc = form_data.get("order_desc", True) + + since, until = utils.get_since_until( + relative_start=relative_start, + relative_end=relative_end, + time_range=form_data.get("time_range"), + since=form_data.get("since"), + until=form_data.get("until"), + ) + time_shift = form_data.get("time_shift", "") + self.time_shift = utils.parse_past_timedelta(time_shift) + from_dttm = None if since is None else (since - self.time_shift) + to_dttm = None if until is None else (until - self.time_shift) + if from_dttm and to_dttm and from_dttm > to_dttm: + raise Exception(_("From date cannot be larger than to date")) + + self.from_dttm = from_dttm + self.to_dttm = to_dttm + + # extras are used to query elements specific to a datasource type + # for instance the extra where clause that applies only to Tables + extras = { + "druid_time_origin": form_data.get("druid_time_origin", ""), + "having": form_data.get("having", ""), + "having_druid": form_data.get("having_filters", []), + "time_grain_sqla": form_data.get("time_grain_sqla", ""), + "time_range_endpoints": form_data.get("time_range_endpoints"), + "where": form_data.get("where", ""), + } + + d = { + "granularity": granularity, + "from_dttm": from_dttm, + "to_dttm": to_dttm, + "is_timeseries": is_timeseries, + "columns": columns, + "metrics": metrics, + "row_limit": row_limit, + "filter": self.form_data.get("filters", []), + "timeseries_limit": limit, + "extras": extras, + "timeseries_limit_metric": timeseries_limit_metric, + "order_desc": order_desc, + } + return d + + @property + def cache_timeout(self): + if self.form_data.get("cache_timeout") is not 
None:
+            return int(self.form_data.get("cache_timeout"))
+        if self.datasource.cache_timeout is not None:
+            return self.datasource.cache_timeout
+        if (
+            hasattr(self.datasource, "database")
+            and self.datasource.database.cache_timeout is not None
+        ):
+            return self.datasource.database.cache_timeout
+        return config["CACHE_DEFAULT_TIMEOUT"]
+
+    def get_json(self):
+        return json.dumps(
+            self.get_payload(), default=utils.json_int_dttm_ser, ignore_nan=True
+        )
+
+    def cache_key(self, query_obj, **extra):
+        """
+        The cache key is made out of the key/values in `query_obj`, plus any
+        other key/values in `extra`.
+
+        We remove datetime bounds that are hard values, and replace them with
+        the user-provided inputs to bounds, which may be time-relative (as in
+        "5 days ago" or "now").
+
+        The `extra` arguments are currently used by time shift queries, since
+        different time shifts will differ only in the `from_dttm` and `to_dttm`
+        values which are stripped.
+        """
+        cache_dict = copy.copy(query_obj)
+        cache_dict.update(extra)
+
+        for k in ["from_dttm", "to_dttm"]:
+            del cache_dict[k]
+
+        cache_dict["time_range"] = self.form_data.get("time_range")
+        cache_dict["datasource"] = self.datasource.uid
+        cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
+        cache_dict["rls"] = security_manager.get_rls_ids(self.datasource)
+        cache_dict["changed_on"] = self.datasource.changed_on
+        json_data = self.json_dumps(cache_dict, sort_keys=True)
+        return hashlib.md5(json_data.encode("utf-8")).hexdigest()
+
+    def get_payload(self, query_obj=None):
+        """Returns a payload of metadata and data"""
+        self.run_extra_queries()
+        payload = self.get_df_payload(query_obj)
+
+        df = payload.get("df")
+        if self.status != utils.QueryStatus.FAILED:
+            payload["data"] = self.get_data(df)
+        if "df" in payload:
+            del payload["df"]
+        return payload
+
+    def get_df_payload(self, query_obj=None, **kwargs):
+        """Handles caching around the df payload retrieval"""
+        if not query_obj:
+            query_obj = self.query_obj()
+        cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
+        logger.info("Cache key: {}".format(cache_key))
+        is_loaded = False
+        stacktrace = None
+        df = None
+        cached_dttm = datetime.utcnow().isoformat().split(".")[0]
+        if cache_key and cache and not self.force:
+            cache_value = cache.get(cache_key)
+            if cache_value:
+                stats_logger.incr("loading_from_cache")
+                try:
+                    cache_value = pkl.loads(cache_value)
+                    df = cache_value["df"]
+                    self.query = cache_value["query"]
+                    self._any_cached_dttm = cache_value["dttm"]
+                    self._any_cache_key = cache_key
+                    self.status = utils.QueryStatus.SUCCESS
+                    is_loaded = True
+                    stats_logger.incr("loaded_from_cache")
+                except Exception as e:
+                    logger.exception(e)
+                    logger.error(
+                        "Error reading cache: " + utils.error_msg_from_exception(e)
+                    )
+                logger.info("Serving from cache")
+
+        if query_obj and not is_loaded:
+            try:
+                df = self.get_df(query_obj)
+                if self.status != utils.QueryStatus.FAILED:
+                    stats_logger.incr("loaded_from_source")
+                    if not self.force:
+                        stats_logger.incr("loaded_from_source_without_force")
+                    is_loaded = True
+            except Exception as e:
+                logger.exception(e)
+                if not self.error_message:
+                    self.error_message = "{}".format(e)
+                self.status = utils.QueryStatus.FAILED
+                stacktrace = utils.get_stacktrace()
+
+        if (
+            is_loaded
+            and cache_key
+            and cache
+            and self.status != utils.QueryStatus.FAILED
+        ):
+            try:
+                cache_value = dict(dttm=cached_dttm, df=df, query=self.query)
+                cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
+
+                
logger.info( + "Caching {} chars at key {}".format(len(cache_value), cache_key) + ) + + stats_logger.incr("set_cache_key") + cache.set(cache_key, cache_value, timeout=self.cache_timeout) + except Exception as e: + # cache.set call can fail if the backend is down or if + # the key is too large or whatever other reasons + logger.warning("Could not cache key {}".format(cache_key)) + logger.exception(e) + cache.delete(cache_key) + return { + "cache_key": self._any_cache_key, + "cached_dttm": self._any_cached_dttm, + "cache_timeout": self.cache_timeout, + "df": df, + "error": self.error_message, + "form_data": self.form_data, + "is_cached": self._any_cache_key is not None, + "query": self.query, + "status": self.status, + "stacktrace": stacktrace, + "rowcount": len(df.index) if df is not None else 0, + } + + def json_dumps(self, obj, sort_keys=False): + return json.dumps( + obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys + ) + + def payload_json_and_has_error(self, payload): + has_error = ( + payload.get("status") == utils.QueryStatus.FAILED + or payload.get("error") is not None + ) + return self.json_dumps(payload), has_error + + @property + def data(self): + """This is the data object serialized to the js layer""" + content = { + "form_data": self.form_data, + "token": self.token, + "viz_name": self.viz_type, + "filter_select_enabled": self.datasource.filter_select_enabled, + } + return content + + def get_csv(self): + df = self.get_df() + include_index = not isinstance(df.index, pd.RangeIndex) + return df.to_csv(index=include_index, **config["CSV_EXPORT"]) + + def get_data(self, df: pd.DataFrame) -> VizData: + return df.to_dict(orient="records") + + @property + def json_data(self): + return json.dumps(self.data) + + +class TableViz(BaseViz): + + """A basic html table that is sortable and searchable""" + + viz_type = "table" + verbose_name = _("Table View") + credits = 'a Superset original' + is_timeseries = False + enforce_numerical_metrics = False + + def should_be_timeseries(self): + fd = self.form_data + # TODO handle datasource-type-specific code in datasource + conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or ( + fd.get("granularity_sqla") and fd.get("time_grain_sqla") + ) + if fd.get("include_time") and not conditions_met: + raise Exception( + _("Pick a granularity in the Time section or " "uncheck 'Include Time'") + ) + return fd.get("include_time") + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + + if fd.get("all_columns") and ( + fd.get("groupby") or fd.get("metrics") or fd.get("percent_metrics") + ): + raise Exception( + _( + "Choose either fields to [Group By] and [Metrics] and/or " + "[Percentage Metrics], or [Columns], not both" + ) + ) + + sort_by = fd.get("timeseries_limit_metric") + if fd.get("all_columns"): + order_by_cols = fd.get("order_by_cols") or [] + d["orderby"] = [json.loads(t) for t in order_by_cols] + elif sort_by: + sort_by_label = utils.get_metric_name(sort_by) + if sort_by_label not in utils.get_metric_names(d["metrics"]): + d["metrics"] += [sort_by] + d["orderby"] = [(sort_by, not fd.get("order_desc", True))] + + # Add all percent metrics that are not already in the list + if "percent_metrics" in fd: + d["metrics"].extend( + m for m in fd["percent_metrics"] or [] if m not in d["metrics"] + ) + + d["is_timeseries"] = self.should_be_timeseries() + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + """ + Transform the query result to the table representation. 
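+
+        Note: percent metrics are normalized against their column totals and
+        prefixed with "%", so, for example, a (hypothetical) percent metric
+        ``sum__num`` with values [1, 3] comes out as a ``%sum__num`` column
+        with values [0.25, 0.75].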
+
+        :param df: The interim dataframe
+        :returns: The table visualization data
+
+        The interim dataframe comprises the group-by and non-group-by columns and
+        the union of the metrics representing the non-percent and percent metrics. Note
+        the percent metrics have yet to be transformed.
+        """
+
+        non_percent_metric_columns = []
+        # Transform the data frame to adhere to the UI ordering of the columns and
+        # metrics whilst simultaneously computing the percentages (via normalization)
+        # for the percent metrics.
+
+        if DTTM_ALIAS in df:
+            if self.should_be_timeseries():
+                non_percent_metric_columns.append(DTTM_ALIAS)
+            else:
+                del df[DTTM_ALIAS]
+
+        non_percent_metric_columns.extend(
+            self.form_data.get("all_columns") or self.form_data.get("groupby") or []
+        )
+
+        non_percent_metric_columns.extend(
+            utils.get_metric_names(self.form_data.get("metrics") or [])
+        )
+
+        timeseries_limit_metric = utils.get_metric_name(
+            self.form_data.get("timeseries_limit_metric")
+        )
+        if timeseries_limit_metric:
+            non_percent_metric_columns.append(timeseries_limit_metric)
+
+        percent_metric_columns = utils.get_metric_names(
+            self.form_data.get("percent_metrics") or []
+        )
+
+        df = pd.concat(
+            [
+                df[non_percent_metric_columns],
+                (
+                    df[percent_metric_columns]
+                    .div(df[percent_metric_columns].sum())
+                    .add_prefix("%")
+                ),
+            ],
+            axis=1,
+        )
+
+        data = self.handle_js_int_overflow(
+            dict(records=df.to_dict(orient="records"), columns=list(df.columns))
+        )
+
+        return data
+
+    def json_dumps(self, obj, sort_keys=False):
+        return json.dumps(
+            obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True
+        )
+
+
+class TimeTableViz(BaseViz):
+
+    """A data table with rich time-series related columns"""
+
+    viz_type = "time_table"
+    verbose_name = _("Time Table View")
+    credits = 'a Superset original'
+    is_timeseries = True
+
+    def query_obj(self):
+        d = super().query_obj()
+        fd = self.form_data
+
+        if not fd.get("metrics"):
+            raise Exception(_("Pick at least one metric"))
+
+        if fd.get("groupby") and len(fd.get("metrics")) > 1:
+            raise Exception(
+                _("When using 'Group By' you are limited to a single metric")
+            )
+        return d
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        fd = self.form_data
+        columns = None
+        values = self.metric_labels
+        if fd.get("groupby"):
+            values = self.metric_labels[0]
+            columns = fd.get("groupby")
+        pt = df.pivot_table(index=DTTM_ALIAS, columns=columns, values=values)
+        pt.index = pt.index.map(str)
+        pt = pt.sort_index()
+        return dict(
+            records=pt.to_dict(orient="index"),
+            columns=list(pt.columns),
+            is_group_by=len(fd.get("groupby", [])) > 0,
+        )
+
+
+class PivotTableViz(BaseViz):
+
+    """A pivot table view, define your rows, columns and metrics"""
+
+    viz_type = "pivot_table"
+    verbose_name = _("Pivot Table")
+    credits = 'a Superset original'
+    is_timeseries = False
+
+    def query_obj(self):
+        d = super().query_obj()
+        groupby = self.form_data.get("groupby")
+        columns = self.form_data.get("columns")
+        metrics = self.form_data.get("metrics")
+        transpose = self.form_data.get("transpose_pivot")
+        if not columns:
+            columns = []
+        if not groupby:
+            groupby = []
+        if not groupby:
+            raise Exception(_("Please choose at least one 'Group by' field"))
+        if transpose and not columns:
+            raise Exception(
+                _(
+                    (
+                        "Please choose at least one 'Columns' field when "
+                        "selecting the 'Transpose Pivot' option"
+                    )
+                )
+            )
+        if not metrics:
+            raise Exception(_("Please choose at least one metric"))
+        if set(groupby) & set(columns):
+            raise Exception(_("'Group By' and 'Columns' can't overlap"))
+        return d
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
+            del df[DTTM_ALIAS]
+
+        aggfunc = self.form_data.get("pandas_aggfunc") or "sum"
+
+        # Ensure that Pandas's sum function mimics that of SQL.
+        if aggfunc == "sum":
+            aggfunc = lambda x: x.sum(min_count=1)
+
+        groupby = self.form_data.get("groupby")
+        columns = self.form_data.get("columns")
+        if self.form_data.get("transpose_pivot"):
+            groupby, columns = columns, groupby
+        metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
+        df = df.pivot_table(
+            index=groupby,
+            columns=columns,
+            values=metrics,
+            aggfunc=aggfunc,
+            margins=self.form_data.get("pivot_margins"),
+        )
+
+        # Re-order the columns adhering to the metric ordering.
+        df = df[metrics]
+
+        # Display metrics side by side with each column
+        if self.form_data.get("combine_metric"):
+            df = df.stack(0).unstack()
+        return dict(
+            columns=list(df.columns),
+            html=df.to_html(
+                na_rep="null",
+                classes=(
+                    "dataframe table table-striped table-bordered "
+                    "table-condensed table-hover"
+                ).split(" "),
+            ),
+        )
+
+
+class MarkupViz(BaseViz):
+
+    """Use html or markdown to create a free form widget"""
+
+    viz_type = "markup"
+    verbose_name = _("Markup")
+    is_timeseries = False
+
+    def query_obj(self):
+        return None
+
+    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+        return pd.DataFrame()
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        markup_type = self.form_data.get("markup_type")
+        code = self.form_data.get("code", "")
+        if markup_type == "markdown":
+            code = markdown(code)
+        return dict(html=code, theme_css=get_manifest_files("theme", "css"))
+
+
+class SeparatorViz(MarkupViz):
+
+    """Use to create section headers in a dashboard, similar to `Markup`"""
+
+    viz_type = "separator"
+    verbose_name = _("Separator")
+
+
+class WordCloudViz(BaseViz):
+
+    """Build a colorful word cloud
+
+    Uses the nice library at:
+    https://github.com/jasondavies/d3-cloud
+    """
+
+    viz_type = "word_cloud"
+    verbose_name = _("Word Cloud")
+    is_timeseries = False
+
+
+class TreemapViz(BaseViz):
+
+    """Tree map visualisation for hierarchical data."""
+
+    viz_type = "treemap"
+    verbose_name = _("Treemap")
+    credits = 'd3.js'
+    is_timeseries = False
+
+    def _nest(self, metric, df):
+        nlevels = df.index.nlevels
+        if nlevels == 1:
+            result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])]
+        else:
+            result = [
+                {"name": l, "children": self._nest(metric, df.loc[l])}
+                for l in df.index.levels[0]
+            ]
+        return result
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        df = df.set_index(self.form_data.get("groupby"))
+        chart_data = [
+            {"name": metric, "children": self._nest(metric, df)}
+            for metric in df.columns
+        ]
+        return chart_data
+
+
+class CalHeatmapViz(BaseViz):
+
+    """Calendar heatmap."""
+
+    viz_type = "cal_heatmap"
+    verbose_name = _("Calendar Heatmap")
+    credits = "cal-heatmap"
+    is_timeseries = True
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        form_data = self.form_data
+
+        data = {}
+        records = df.to_dict("records")
+        for metric in self.metric_labels:
+            values = {}
+            for obj in records:
+                v = obj[DTTM_ALIAS]
+                if hasattr(v, "value"):
+                    v = v.value
+                values[str(v / 10 ** 9)] = obj.get(metric)
+            data[metric] = values
+
+        start, end = utils.get_since_until(
+            relative_start=relative_start,
+            relative_end=relative_end,
+            time_range=form_data.get("time_range"),
+            since=form_data.get("since"),
+            until=form_data.get("until"),
+        )
+        if not start or not end:
+            raise Exception("Please provide both time bounds (Since and Until)")
+        domain = form_data.get("domain_granularity")
+        diff_delta = rdelta.relativedelta(end, start)
+        diff_secs = (end - start).total_seconds()
+
+        if domain == "year":
+            range_ = diff_delta.years + 1
+        elif domain == "month":
+            range_ = diff_delta.years * 12 + diff_delta.months + 1
+        elif domain == "week":
+            range_ = diff_delta.years * 53 + diff_delta.weeks + 1
+        elif domain == "day":
+            range_ = diff_secs // (24 * 60 * 60) + 1  # type: ignore
+        else:
+            range_ = diff_secs // (60 * 60) + 1  # type: ignore
+
+        return {
+            "data": data,
+            "start": start,
+            "domain": domain,
+            "subdomain": form_data.get("subdomain_granularity"),
+            "range": range_,
+        }
+
+    def query_obj(self):
+        d = super().query_obj()
+        fd = self.form_data
+        d["metrics"] = fd.get("metrics")
+        return d
+
+
+class NVD3Viz(BaseViz):
+
+    """Base class for all nvd3 vizs"""
+
+    credits = 'NVD3.org'
+    viz_type: Optional[str] = None
+    verbose_name = "Base NVD3 Viz"
+    is_timeseries = False
+
+
+class BoxPlotViz(NVD3Viz):
+
+    """Box plot viz from NVD3"""
+
+    viz_type = "box_plot"
+    verbose_name = _("Box Plot")
+    sort_series = False
+    is_timeseries = True
+
+    def to_series(self, df, classed="", title_suffix=""):
+        label_sep = " - "
+        chart_data = []
+        for index_value, row in zip(df.index, df.to_dict(orient="records")):
+            if isinstance(index_value, tuple):
+                index_value = label_sep.join(index_value)
+            boxes = defaultdict(dict)
+            for (label, key), value in row.items():
+                if key == "nanmedian":
+                    key = "Q2"
+                boxes[label][key] = value
+            for label, box in boxes.items():
+                if len(self.form_data.get("metrics")) > 1:
+                    # need to render data labels with metrics
+                    chart_label = label_sep.join([index_value, label])
+                else:
+                    chart_label = index_value
+                chart_data.append({"label": chart_label, "values": box})
+        return chart_data
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        form_data = self.form_data
+
+        # conform to NVD3 names
+        def Q1(series):  # need to be named functions - can't use lambdas
+            return np.nanpercentile(series, 25)
+
+        def Q3(series):
+            return np.nanpercentile(series, 75)
+
+        whisker_type = form_data.get("whisker_options")
+        if whisker_type == "Tukey":
+
+            def whisker_high(series):
+                upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series))
+                return series[series <= upper_outer_lim].max()
+
+            def whisker_low(series):
+                lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series))
+                return series[series >= lower_outer_lim].min()
+
+        elif whisker_type == "Min/max (no outliers)":
+
+            def whisker_high(series):
+                return series.max()
+
+            def whisker_low(series):
+                return series.min()
+
+        elif " percentiles" in whisker_type:  # type: ignore
+            low, high = whisker_type.replace(" percentiles", "").split(  # type: ignore
+                "/"
+            )
+
+            def whisker_high(series):
+                return np.nanpercentile(series, int(high))
+
+            def whisker_low(series):
+                return np.nanpercentile(series, int(low))
+
+        else:
+            raise ValueError("Unknown whisker type: {}".format(whisker_type))
+
+        def outliers(series):
+            above = series[series > whisker_high(series)]
+            below = series[series < whisker_low(series)]
+            # pandas sometimes doesn't like getting lists back here
+            return set(above.tolist() + below.tolist())
+
+        aggregate = [Q1, np.nanmedian, Q3, whisker_high, whisker_low, outliers]
+        df = df.groupby(form_data.get("groupby")).agg(aggregate)
+        
chart_data = self.to_series(df) + return chart_data + + +class BubbleViz(NVD3Viz): + + """Based on the NVD3 bubble chart""" + + viz_type = "bubble" + verbose_name = _("Bubble Chart") + is_timeseries = False + + def query_obj(self): + form_data = self.form_data + d = super().query_obj() + + self.x_metric = form_data.get("x") + self.y_metric = form_data.get("y") + self.z_metric = form_data.get("size") + self.entity = form_data.get("entity") + self.series = form_data.get("series") or self.entity + d["row_limit"] = form_data.get("limit") + + d["metrics"] = [self.z_metric, self.x_metric, self.y_metric] + if len(set(self.metric_labels)) < 3: + raise Exception(_("Please use 3 different metric labels")) + if not all(d["metrics"] + [self.entity]): + raise Exception(_("Pick a metric for x, y and size")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + df["x"] = df[[utils.get_metric_name(self.x_metric)]] + df["y"] = df[[utils.get_metric_name(self.y_metric)]] + df["size"] = df[[utils.get_metric_name(self.z_metric)]] + df["shape"] = "circle" + df["group"] = df[[self.series]] + + series: Dict[Any, List[Any]] = defaultdict(list) + for row in df.to_dict(orient="records"): + series[row["group"]].append(row) + chart_data = [] + for k, v in series.items(): + chart_data.append({"key": k, "values": v}) + return chart_data + + +class BulletViz(NVD3Viz): + + """Based on the NVD3 bullet chart""" + + viz_type = "bullet" + verbose_name = _("Bullet Chart") + is_timeseries = False + + def query_obj(self): + form_data = self.form_data + d = super().query_obj() + self.metric = form_data.get("metric") + + def as_strings(field): + value = form_data.get(field) + return value.split(",") if value else [] + + def as_floats(field): + return [float(x) for x in as_strings(field)] + + self.ranges = as_floats("ranges") + self.range_labels = as_strings("range_labels") + self.markers = as_floats("markers") + self.marker_labels = as_strings("marker_labels") + self.marker_lines = as_floats("marker_lines") + self.marker_line_labels = as_strings("marker_line_labels") + + d["metrics"] = [self.metric] + if not self.metric: + raise Exception(_("Pick a metric to display")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + df["metric"] = df[[utils.get_metric_name(self.metric)]] + values = df["metric"].values + return { + "measures": values.tolist(), + "ranges": self.ranges or [0, values.max() * 1.1], + "rangeLabels": self.range_labels or None, + "markers": self.markers or None, + "markerLabels": self.marker_labels or None, + "markerLines": self.marker_lines or None, + "markerLineLabels": self.marker_line_labels or None, + } + + +class BigNumberViz(BaseViz): + + """Put emphasis on a single metric with this big number viz""" + + viz_type = "big_number" + verbose_name = _("Big Number with Trendline") + credits = 'a Superset original' + is_timeseries = True + + def query_obj(self): + d = super().query_obj() + metric = self.form_data.get("metric") + if not metric: + raise Exception(_("Pick a metric!")) + d["metrics"] = [self.form_data.get("metric")] + self.form_data["metric"] = metric + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=[], + values=self.metric_labels, + dropna=False, + aggfunc=np.min, # looking for any (only) value, preserving `None` + ) + df = self.apply_rolling(df) + df[DTTM_ALIAS] = df.index + return super().get_data(df) + + +class BigNumberTotalViz(BaseViz): + + """Put emphasis on a single metric 
with this big number viz""" + + viz_type = "big_number_total" + verbose_name = _("Big Number") + credits = 'a Superset original' + is_timeseries = False + + def query_obj(self): + d = super().query_obj() + metric = self.form_data.get("metric") + if not metric: + raise Exception(_("Pick a metric!")) + d["metrics"] = [self.form_data.get("metric")] + self.form_data["metric"] = metric + + # Limiting rows is not required as only one cell is returned + d["row_limit"] = None + return d + + +class NVD3TimeSeriesViz(NVD3Viz): + + """A rich line chart component with tons of options""" + + viz_type = "line" + verbose_name = _("Time Series - Line Chart") + sort_series = False + is_timeseries = True + pivot_fill_value: Optional[int] = None + + def to_series(self, df, classed="", title_suffix=""): + cols = [] + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + series = df.to_dict("series") + + chart_data = [] + for name in df.T.index.tolist(): + ys = series[name] + if df[name].dtype.kind not in "biufc": + continue + if isinstance(name, list): + series_title = [str(title) for title in name] + elif isinstance(name, tuple): + series_title = tuple(str(title) for title in name) + else: + series_title = str(name) + if ( + isinstance(series_title, (list, tuple)) + and len(series_title) > 1 + and len(self.metric_labels) == 1 + ): + # Removing metric from series name if only one metric + series_title = series_title[1:] + if title_suffix: + if isinstance(series_title, str): + series_title = (series_title, title_suffix) + elif isinstance(series_title, (list, tuple)): + series_title = series_title + (title_suffix,) + + values = [] + non_nan_cnt = 0 + for ds in df.index: + if ds in ys: + d = {"x": ds, "y": ys[ds]} + if not np.isnan(ys[ds]): + non_nan_cnt += 1 + else: + d = {} + values.append(d) + + if non_nan_cnt == 0: + continue + + d = {"key": series_title, "values": values} + if classed: + d["classed"] = classed + chart_data.append(d) + return chart_data + + def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: + fd = self.form_data + if fd.get("granularity") == "all": + raise Exception(_("Pick a time granularity for your time series")) + + if df.empty: + return df + + if aggregate: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=self.columns, + values=self.metric_labels, + fill_value=0, + aggfunc=sum, + ) + else: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=self.columns, + values=self.metric_labels, + fill_value=self.pivot_fill_value, + ) + + rule = fd.get("resample_rule") + method = fd.get("resample_method") + + if rule and method: + df = getattr(df.resample(rule), method)() + + if self.sort_series: + dfs = df.sum() + dfs.sort_values(ascending=False, inplace=True) + df = df[dfs.index] + + df = self.apply_rolling(df) + if fd.get("contribution"): + dft = df.T + df = (dft / dft.sum()).T + + return df + + def run_extra_queries(self): + fd = self.form_data + + time_compare = fd.get("time_compare") or [] + # backwards compatibility + if not isinstance(time_compare, list): + time_compare = [time_compare] + + for option in time_compare: + query_object = self.query_obj() + delta = utils.parse_past_timedelta(option) + query_object["inner_from_dttm"] = query_object["from_dttm"] + query_object["inner_to_dttm"] = query_object["to_dttm"] + + if not query_object["from_dttm"] or not query_object["to_dttm"]: + raise Exception( + _( + "`Since` and `Until` time bounds should be specified " + 
"when using the `Time Shift` feature." + ) + ) + query_object["from_dttm"] -= delta + query_object["to_dttm"] -= delta + + df2 = self.get_df_payload(query_object, time_compare=option).get("df") + if df2 is not None and DTTM_ALIAS in df2: + label = "{} offset".format(option) + df2[DTTM_ALIAS] += delta + df2 = self.process_data(df2) + self._extra_chart_data.append((label, df2)) + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + comparison_type = fd.get("comparison_type") or "values" + df = self.process_data(df) + if comparison_type == "values": + # Filter out series with all NaN + chart_data = self.to_series(df.dropna(axis=1, how="all")) + + for i, (label, df2) in enumerate(self._extra_chart_data): + chart_data.extend( + self.to_series( + df2, classed="time-shift-{}".format(i), title_suffix=label + ) + ) + else: + chart_data = [] + for i, (label, df2) in enumerate(self._extra_chart_data): + # reindex df2 into the df2 index + combined_index = df.index.union(df2.index) + df2 = ( + df2.reindex(combined_index) + .interpolate(method="time") + .reindex(df.index) + ) + + if comparison_type == "absolute": + diff = df - df2 + elif comparison_type == "percentage": + diff = (df - df2) / df2 + elif comparison_type == "ratio": + diff = df / df2 + else: + raise Exception( + "Invalid `comparison_type`: {0}".format(comparison_type) + ) + + # remove leading/trailing NaNs from the time shift difference + diff = diff[diff.first_valid_index() : diff.last_valid_index()] + + chart_data.extend( + self.to_series( + diff, classed="time-shift-{}".format(i), title_suffix=label + ) + ) + + if not self.sort_series: + chart_data = sorted(chart_data, key=lambda x: tuple(x["key"])) + return chart_data + + +class MultiLineViz(NVD3Viz): + + """Pile on multiple line charts""" + + viz_type = "line_multi" + verbose_name = _("Time Series - Multiple Line Charts") + + is_timeseries = True + + def query_obj(self): + return None + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + # Late imports to avoid circular import issues + from superset.models.slice import Slice + from superset import db + + slice_ids1 = fd.get("line_charts") + slices1 = db.session.query(Slice).filter(Slice.id.in_(slice_ids1)).all() + slice_ids2 = fd.get("line_charts_2") + slices2 = db.session.query(Slice).filter(Slice.id.in_(slice_ids2)).all() + return { + "slices": { + "axis1": [slc.data for slc in slices1], + "axis2": [slc.data for slc in slices2], + } + } + + +class NVD3DualLineViz(NVD3Viz): + + """A rich line chart with dual axis""" + + viz_type = "dual_line" + verbose_name = _("Time Series - Dual Axis Line Chart") + sort_series = False + is_timeseries = True + + def query_obj(self): + d = super().query_obj() + m1 = self.form_data.get("metric") + m2 = self.form_data.get("metric_2") + d["metrics"] = [m1, m2] + if not m1: + raise Exception(_("Pick a metric for left axis!")) + if not m2: + raise Exception(_("Pick a metric for right axis!")) + if m1 == m2: + raise Exception( + _("Please choose different metrics" " on left and right axis") + ) + return d + + def to_series(self, df, classed=""): + cols = [] + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + series = df.to_dict("series") + chart_data = [] + metrics = [self.form_data.get("metric"), self.form_data.get("metric_2")] + for i, m in enumerate(metrics): + m = utils.get_metric_name(m) + ys = series[m] + if df[m].dtype.kind not in "biufc": + continue + 
series_title = m + d = { + "key": series_title, + "classed": classed, + "values": [ + {"x": ds, "y": ys[ds] if ds in ys else None} for ds in df.index + ], + "yAxis": i + 1, + "type": "line", + } + chart_data.append(d) + return chart_data + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + + if self.form_data.get("granularity") == "all": + raise Exception(_("Pick a time granularity for your time series")) + + metric = utils.get_metric_name(fd.get("metric")) + metric_2 = utils.get_metric_name(fd.get("metric_2")) + df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2]) + + chart_data = self.to_series(df) + return chart_data + + +class NVD3TimeSeriesBarViz(NVD3TimeSeriesViz): + + """A bar chart where the x axis is time""" + + viz_type = "bar" + sort_series = True + verbose_name = _("Time Series - Bar Chart") + + +class NVD3TimePivotViz(NVD3TimeSeriesViz): + + """Time Series - Periodicity Pivot""" + + viz_type = "time_pivot" + sort_series = True + verbose_name = _("Time Series - Period Pivot") + + def query_obj(self): + d = super().query_obj() + d["metrics"] = [self.form_data.get("metric")] + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + df = self.process_data(df) + freq = to_offset(fd.get("freq")) + try: + freq = type(freq)(freq.n, normalize=True, **freq.kwds) + except ValueError: + freq = type(freq)(freq.n, **freq.kwds) + df.index.name = None + df[DTTM_ALIAS] = df.index.map(freq.rollback) + df["ranked"] = df[DTTM_ALIAS].rank(method="dense", ascending=False) - 1 + df.ranked = df.ranked.map(int) + df["series"] = "-" + df.ranked.map(str) + df["series"] = df["series"].str.replace("-0", "current") + rank_lookup = { + row["series"]: row["ranked"] for row in df.to_dict(orient="records") + } + max_ts = df[DTTM_ALIAS].max() + max_rank = df["ranked"].max() + df[DTTM_ALIAS] = df.index + (max_ts - df[DTTM_ALIAS]) + df = df.pivot_table( + index=DTTM_ALIAS, + columns="series", + values=utils.get_metric_name(fd.get("metric")), + ) + chart_data = self.to_series(df) + for serie in chart_data: + serie["rank"] = rank_lookup[serie["key"]] + serie["perc"] = 1 - (serie["rank"] / (max_rank + 1)) + return chart_data + + +class NVD3CompareTimeSeriesViz(NVD3TimeSeriesViz): + + """A line chart component where you can compare the % change over time""" + + viz_type = "compare" + verbose_name = _("Time Series - Percent Change") + + +class NVD3TimeSeriesStackedViz(NVD3TimeSeriesViz): + + """A rich stack area chart""" + + viz_type = "area" + verbose_name = _("Time Series - Stacked") + sort_series = True + pivot_fill_value = 0 + + +class DistributionPieViz(NVD3Viz): + + """Annoy visualization snobs with this controversial pie chart""" + + viz_type = "pie" + verbose_name = _("Distribution - NVD3 - Pie Chart") + is_timeseries = False + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + metric = self.metric_labels[0] + df = df.pivot_table(index=self.columns, values=[metric]) + df.sort_values(by=metric, ascending=False, inplace=True) + df = df.reset_index() + df.columns = ["x", "y"] + return df.to_dict(orient="records") + + +class HistogramViz(BaseViz): + + """Histogram""" + + viz_type = "histogram" + verbose_name = _("Histogram") + is_timeseries = False + + def query_obj(self): + """Returns the query object for this visualization""" + d = super().query_obj() + d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) + if not 
self.form_data.get("all_columns_x"):
+            raise Exception(_("Must have at least one numeric column specified"))
+        return d
+
+    def labelify(self, keys, column):
+        if isinstance(keys, str):
+            keys = (keys,)
+        # removing undesirable characters
+        labels = [re.sub(r"\W+", r"_", k) for k in keys]
+        if len(self.columns) > 1:
+            # Only show numeric column in label if there are many
+            labels = [column] + labels
+        return "__".join(labels)
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        """Returns the chart data"""
+        groupby = self.form_data.get("groupby")
+
+        if df.empty:
+            return None
+
+        chart_data = []
+        if groupby:
+            groups = df.groupby(groupby)
+        else:
+            groups = [((), df)]
+        for keys, data in groups:
+            chart_data.extend(
+                [
+                    {
+                        "key": self.labelify(keys, column),
+                        "values": data[column].tolist(),
+                    }
+                    for column in self.columns
+                ]
+            )
+        return chart_data
+
+
+class DistributionBarViz(DistributionPieViz):
+
+    """A good old bar chart"""
+
+    viz_type = "dist_bar"
+    verbose_name = _("Distribution - Bar Chart")
+    is_timeseries = False
+
+    def query_obj(self):
+        # TODO: Refactor this plugin to either perform grouping or assume
+        # preaggregation of metrics ("numeric columns")
+        d = super().query_obj()
+        fd = self.form_data
+        if not self.all_metrics:
+            raise Exception(_("Pick at least one metric"))
+        if not self.columns:
+            raise Exception(_("Pick at least one field for [Series]"))
+        return d
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        fd = self.form_data
+        metrics = self.metric_labels
+        # TODO: will require post transformation logic not currently available in
+        # /api/v1/query endpoint
+        columns = fd.get("columns") or []
+        groupby = fd.get("groupby") or []
+
+        # pandas will throw away nulls when grouping/pivoting,
+        # so we substitute NULL_STRING for any nulls in the necessary columns
+        df[self.columns] = df[self.columns].fillna(value=NULL_STRING)
+
+        row = df.groupby(groupby).sum()[metrics[0]].copy()
+        row.sort_values(ascending=False, inplace=True)
+        pt = df.pivot_table(index=groupby, columns=columns, values=metrics)
+        if fd.get("contribution"):
+            pt = pt.T
+            pt = (pt / pt.sum()).T
+        pt = pt.reindex(row.index)
+        chart_data = []
+        for name, ys in pt.items():
+            if pt[name].dtype.kind not in "biufc" or name in groupby:
+                continue
+            if isinstance(name, str):
+                series_title = name
+            else:
+                offset = 0 if len(metrics) > 1 else 1
+                series_title = ", ".join([str(s) for s in name[offset:]])
+            values = []
+            for i, v in ys.items():
+                x = i
+                if isinstance(x, (tuple, list)):
+                    x = ", ".join([str(s) for s in x])
+                else:
+                    x = str(x)
+                values.append({"x": x, "y": v})
+            d = {"key": series_title, "values": values}
+            chart_data.append(d)
+        return chart_data
+
+
+class SunburstViz(BaseViz):
+
+    """A multi level sunburst chart"""
+
+    viz_type = "sunburst"
+    verbose_name = _("Sunburst")
+    is_timeseries = False
+    credits = (
+        "Kerry Rodden "
+        '@bl.ocks.org'
+    )
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        fd = self.form_data
+        cols = fd.get("groupby") or []
+        cols.extend(["m1", "m2"])
+        metric = utils.get_metric_name(fd.get("metric"))
+        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        if metric == secondary_metric or secondary_metric is None:
+            df.rename(columns={df.columns[-1]: "m1"}, inplace=True)
+            df["m2"] = df["m1"]
+        else:
+            df.rename(columns={df.columns[-2]: "m1"}, inplace=True)
+            df.rename(columns={df.columns[-1]: "m2"}, inplace=True)
+
+        # Re-order the columns as the query result set column ordering may differ from
+        # that listed in the hierarchy.
+        df = df[cols]
+        return df.to_numpy().tolist()
+
+    def query_obj(self):
+        qry = super().query_obj()
+        fd = self.form_data
+        qry["metrics"] = [fd["metric"]]
+        secondary_metric = fd.get("secondary_metric")
+        if secondary_metric and secondary_metric != fd["metric"]:
+            qry["metrics"].append(secondary_metric)
+        return qry
+
+
+class SankeyViz(BaseViz):
+
+    """A Sankey diagram that requires a parent-child dataset"""
+
+    viz_type = "sankey"
+    verbose_name = _("Sankey")
+    is_timeseries = False
+    credits = 'd3-sankey on npm'
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        df.columns = ["source", "target", "value"]
+        df["source"] = df["source"].astype(str)
+        df["target"] = df["target"].astype(str)
+        recs = df.to_dict(orient="records")
+
+        hierarchy: Dict[str, Set[str]] = defaultdict(set)
+        for row in recs:
+            hierarchy[row["source"]].add(row["target"])
+
+        def find_cycle(g):
+            """Whether there's a cycle in a directed graph"""
+            path = set()
+
+            def visit(vertex):
+                path.add(vertex)
+                for neighbour in g.get(vertex, ()):
+                    if neighbour in path or visit(neighbour):
+                        return (vertex, neighbour)
+                path.remove(vertex)
+
+            for v in g:
+                cycle = visit(v)
+                if cycle:
+                    return cycle
+
+        cycle = find_cycle(hierarchy)
+        if cycle:
+            raise Exception(
+                _(
+                    "There's a loop in your Sankey, please provide a tree. "
+                    "Here's a faulty link: {}"
+                ).format(cycle)
+            )
+        return recs
+
+
+class DirectedForceViz(BaseViz):
+
+    """An animated directed force layout graph visualization"""
+
+    viz_type = "directed_force"
+    verbose_name = _("Directed Force Layout")
+    credits = 'd3noob @bl.ocks.org'
+    is_timeseries = False
+
+    def query_obj(self):
+        qry = super().query_obj()
+        if len(self.form_data["groupby"]) != 2:
+            raise Exception(_("Pick exactly 2 columns to 'Group By'"))
+        qry["metrics"] = [self.form_data["metric"]]
+        return qry
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        df.columns = ["source", "target", "value"]
+        return df.to_dict(orient="records")
+
+
+class ChordViz(BaseViz):
+
+    """A Chord diagram"""
+
+    viz_type = "chord"
+    verbose_name = _("Chord Diagram")
+    credits = 'Bostock'
+    is_timeseries = False
+
+    def query_obj(self):
+        qry = super().query_obj()
+        fd = self.form_data
+        qry["metrics"] = [fd.get("metric")]
+        return qry
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        df.columns = ["source", "target", "value"]
+
+        # Preparing a symmetrical matrix like d3.chords calls for
+        nodes = list(set(df["source"]) | set(df["target"]))
+        matrix = {}
+        for source, target in product(nodes, nodes):
+            matrix[(source, target)] = 0
+        for source, target, value in df.to_records(index=False):
+            matrix[(source, target)] = value
+        m = [[matrix[(n1, n2)] for n1 in nodes] for n2 in nodes]
+        return {"nodes": list(nodes), "matrix": m}
+
+
+class CountryMapViz(BaseViz):
+
+    """A country centric map"""
+
+    viz_type = "country_map"
+    verbose_name = _("Country Map")
+    is_timeseries = False
+    credits = "From bl.ocks.org By john-guerra"
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        fd = self.form_data
+        cols = [fd.get("entity")]
+        metric = self.metric_labels[0]
+        cols += [metric]
+        ndf = df[cols]
+        df = ndf
+        df.columns = ["country_id", "metric"]
+        d = df.to_dict(orient="records")
+        return d
+
+
+class WorldMapViz(BaseViz):
+
+    """A country centric world map"""
+
+    viz_type = "world_map"
+    verbose_name = _("World Map")
+    is_timeseries = False
+    credits = 'datamaps on npm'
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        from superset.examples import countries
+
+        fd = self.form_data
+        cols = [fd.get("entity")]
+        metric = utils.get_metric_name(fd.get("metric"))
+        secondary_metric = utils.get_metric_name(fd.get("secondary_metric"))
+        columns = ["country", "m1", "m2"]
+        if metric == secondary_metric:
+            ndf = df[cols]
+            ndf["m1"] = df[metric]
+            ndf["m2"] = ndf["m1"]
+        else:
+            if secondary_metric:
+                cols += [metric, secondary_metric]
+            else:
+                cols += [metric]
+                columns = ["country", "m1"]
+            ndf = df[cols]
+        df = ndf
+        df.columns = columns
+        d = df.to_dict(orient="records")
+        for row in d:
+            country = None
+            if isinstance(row["country"], str):
+                country = countries.get(fd.get("country_fieldtype"), row["country"])
+
+            if country:
+                row["country"] = country["cca3"]
+                row["latitude"] = country["lat"]
+                row["longitude"] = country["lng"]
+                row["name"] = country["name"]
+            else:
+                row["country"] = "XXX"
+        return d
+
+
+class FilterBoxViz(BaseViz):
+
+    """A multi filter, multi-choice filter box to make dashboards interactive"""
+
+    viz_type = "filter_box"
+    verbose_name = _("Filters")
+    is_timeseries = False
+    credits = 'a Superset original'
+    cache_type = "get_data"
+    filter_row_limit = 1000
+
+    def query_obj(self):
+        return None
+
+    def run_extra_queries(self):
+        qry = super().query_obj()
+        filters = self.form_data.get("filter_configs") or []
+        qry["row_limit"] = self.filter_row_limit
+        self.dataframes = {}
+        for flt in filters:
+            col = flt.get("column")
+            if not col:
+                raise Exception(
+                    _("Invalid filter configuration, please select a column")
+                )
+            qry["columns"] = [col]
+            metric = flt.get("metric")
+            qry["metrics"] = [metric] if metric else []
+            df = self.get_df_payload(query_obj=qry).get("df")
+            self.dataframes[col] = df
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        filters = self.form_data.get("filter_configs") or []
+        d = {}
+        for flt in filters:
+            col = flt.get("column")
+            metric = flt.get("metric")
+            df = self.dataframes.get(col)
+            if df is not None:
+                if metric:
+                    df = df.sort_values(
+                        utils.get_metric_name(metric), ascending=flt.get("asc")
+                    )
+                    d[col] = [
+                        {"id": row[0], "text": row[0], "metric": row[1]}
+                        for row in df.itertuples(index=False)
+                    ]
+                else:
+                    df = df.sort_values(col, ascending=flt.get("asc"))
+                    d[col] = [
+                        {"id": row[0], "text": row[0]}
+                        for row in df.itertuples(index=False)
+                    ]
+        return d
+
+
+class IFrameViz(BaseViz):
+
+    """You can squeeze just about anything in this iFrame component"""
+
+    viz_type = "iframe"
+    verbose_name = _("iFrame")
+    credits = 'a Superset original'
+    is_timeseries = False
+
+    def query_obj(self):
+        return None
+
+    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+        return pd.DataFrame()
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        return {"iframe": True}
+
+
+class ParallelCoordinatesViz(BaseViz):
+
+    """Interactive parallel coordinate implementation
+
+    Uses this amazing javascript library
+    https://github.com/syntagmatic/parallel-coordinates
+    """
+
+    viz_type = "para"
+    verbose_name = _("Parallel Coordinates")
+    credits = (
+        ''
+        "Syntagmatic's library"
+    )
+    is_timeseries = False
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        return df.to_dict(orient="records")
+
+
+class HeatmapViz(BaseViz):
+
+    """A nice heatmap visualization that supports high density through canvas"""
+
+    viz_type = "heatmap"
+    verbose_name = _("Heatmap")
+    is_timeseries = False
+    credits = (
+        'inspired from mbostock @'
+        "bl.ocks.org"
+    )
+
+    def get_data(self, df: pd.DataFrame) -> VizData:
+        if df.empty:
+            return None
+
+        fd = self.form_data
+        x = 
fd.get("all_columns_x") + y = fd.get("all_columns_y") + v = self.metric_labels[0] + if x == y: + df.columns = ["x", "y", "v"] + else: + df = df[[x, y, v]] + df.columns = ["x", "y", "v"] + norm = fd.get("normalize_across") + overall = False + max_ = df.v.max() + min_ = df.v.min() + if norm == "heatmap": + overall = True + else: + gb = df.groupby(norm, group_keys=False) + if len(gb) <= 1: + overall = True + else: + df["perc"] = gb.apply( + lambda x: (x.v - x.v.min()) / (x.v.max() - x.v.min()) + ) + df["rank"] = gb.apply(lambda x: x.v.rank(pct=True)) + if overall: + df["perc"] = (df.v - min_) / (max_ - min_) + df["rank"] = df.v.rank(pct=True) + return {"records": df.to_dict(orient="records"), "extents": [min_, max_]} + + +class HorizonViz(NVD3TimeSeriesViz): + + """Horizon chart + + https://www.npmjs.com/package/d3-horizon-chart + """ + + viz_type = "horizon" + verbose_name = _("Horizon Charts") + credits = ( + '' + "d3-horizon-chart" + ) + + +class MapboxViz(BaseViz): + + """Rich maps made with Mapbox""" + + viz_type = "mapbox" + verbose_name = _("Mapbox") + is_timeseries = False + credits = "Mapbox GL JS" + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + label_col = fd.get("mapbox_label") + + if not fd.get("groupby"): + if fd.get("all_columns_x") is None or fd.get("all_columns_y") is None: + raise Exception(_("[Longitude] and [Latitude] must be set")) + d["columns"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] + + if label_col and len(label_col) >= 1: + if label_col[0] == "count": + raise Exception( + _( + "Must have a [Group By] column to have 'count' as the " + + "[Label]" + ) + ) + d["columns"].append(label_col[0]) + + if fd.get("point_radius") != "Auto": + d["columns"].append(fd.get("point_radius")) + + d["columns"] = list(set(d["columns"])) + else: + # Ensuring columns chosen are all in group by + if ( + label_col + and len(label_col) >= 1 + and label_col[0] != "count" + and label_col[0] not in fd.get("groupby") + ): + raise Exception(_("Choice of [Label] must be present in [Group By]")) + + if fd.get("point_radius") != "Auto" and fd.get( + "point_radius" + ) not in fd.get("groupby"): + raise Exception( + _("Choice of [Point Radius] must be present in [Group By]") + ) + + if fd.get("all_columns_x") not in fd.get("groupby") or fd.get( + "all_columns_y" + ) not in fd.get("groupby"): + raise Exception( + _( + "[Longitude] and [Latitude] columns must be present in " + + "[Group By]" + ) + ) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + label_col = fd.get("mapbox_label") + has_custom_metric = label_col is not None and len(label_col) > 0 + metric_col = [None] * len(df.index) + if has_custom_metric: + if label_col[0] == fd.get("all_columns_x"): # type: ignore + metric_col = df[fd.get("all_columns_x")] + elif label_col[0] == fd.get("all_columns_y"): # type: ignore + metric_col = df[fd.get("all_columns_y")] + else: + metric_col = df[label_col[0]] # type: ignore + point_radius_col = ( + [None] * len(df.index) + if fd.get("point_radius") == "Auto" + else df[fd.get("point_radius")] + ) + + # limiting geo precision as long decimal values trigger issues + # around json-bignumber in Mapbox + GEO_PRECISION = 10 + # using geoJSON formatting + geo_json = { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": {"metric": metric, "radius": point_radius}, + "geometry": { + "type": "Point", + "coordinates": [ + round(lon, GEO_PRECISION), + round(lat, GEO_PRECISION), + 
], + }, + } + for lon, lat, metric, point_radius in zip( + df[fd.get("all_columns_x")], + df[fd.get("all_columns_y")], + metric_col, + point_radius_col, + ) + ], + } + + x_series, y_series = df[fd.get("all_columns_x")], df[fd.get("all_columns_y")] + south_west = [x_series.min(), y_series.min()] + north_east = [x_series.max(), y_series.max()] + + return { + "geoJSON": geo_json, + "hasCustomMetric": has_custom_metric, + "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapStyle": fd.get("mapbox_style"), + "aggregatorName": fd.get("pandas_aggfunc"), + "clusteringRadius": fd.get("clustering_radius"), + "pointRadiusUnit": fd.get("point_radius_unit"), + "globalOpacity": fd.get("global_opacity"), + "bounds": [south_west, north_east], + "renderWhileDragging": fd.get("render_while_dragging"), + "tooltip": fd.get("rich_tooltip"), + "color": fd.get("mapbox_color"), + } + + +class DeckGLMultiLayer(BaseViz): + + """Pile on multiple DeckGL layers""" + + viz_type = "deck_multi" + verbose_name = _("Deck.gl - Multiple Layers") + + is_timeseries = False + credits = 'deck.gl' + + def query_obj(self): + return None + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + # Late imports to avoid circular import issues + from superset.models.slice import Slice + from superset import db + + slice_ids = fd.get("deck_slices") + slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + return { + "mapboxApiKey": config["MAPBOX_API_KEY"], + "slices": [slc.data for slc in slices], + } + + +class BaseDeckGLViz(BaseViz): + + """Base class for deck.gl visualizations""" + + is_timeseries = False + credits = 'deck.gl' + spatial_control_keys: List[str] = [] + + def get_metrics(self): + self.metric = self.form_data.get("size") + return [self.metric] if self.metric else [] + + @staticmethod + def parse_coordinates(s): + if not s: + return None + try: + p = Point(s) + return (p.latitude, p.longitude) # pylint: disable=no-member + except Exception: + raise SpatialException(_("Invalid spatial point encountered: %s" % s)) + + @staticmethod + def reverse_geohash_decode(geohash_code): + lat, lng = geohash.decode(geohash_code) + return (lng, lat) + + @staticmethod + def reverse_latlong(df, key): + df[key] = [tuple(reversed(o)) for o in df[key] if isinstance(o, (list, tuple))] + + def process_spatial_data_obj(self, key, df): + spatial = self.form_data.get(key) + if spatial is None: + raise ValueError(_("Bad spatial key")) + + if spatial.get("type") == "latlong": + df[key] = list( + zip( + pd.to_numeric(df[spatial.get("lonCol")], errors="coerce"), + pd.to_numeric(df[spatial.get("latCol")], errors="coerce"), + ) + ) + elif spatial.get("type") == "delimited": + lon_lat_col = spatial.get("lonlatCol") + df[key] = df[lon_lat_col].apply(self.parse_coordinates) + del df[lon_lat_col] + elif spatial.get("type") == "geohash": + df[key] = df[spatial.get("geohashCol")].map(self.reverse_geohash_decode) + del df[spatial.get("geohashCol")] + + if spatial.get("reverseCheckbox"): + self.reverse_latlong(df, key) + + if df.get(key) is None: + raise NullValueException( + _( + "Encountered invalid NULL spatial entry, \ + please consider filtering those out" + ) + ) + return df + + def add_null_filters(self): + fd = self.form_data + spatial_columns = set() + + if fd.get("adhoc_filters") is None: + fd["adhoc_filters"] = [] + + line_column = fd.get("line_column") + if line_column: + spatial_columns.add(line_column) + + for column in sorted(spatial_columns): + filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""}) + 
fd["adhoc_filters"].append(filter_) + + def query_obj(self): + fd = self.form_data + + # add NULL filters + if fd.get("filter_nulls", True): + self.add_null_filters() + + d = super().query_obj() + gb = [] + + if fd.get("dimension"): + gb += [fd.get("dimension")] + + if fd.get("js_columns"): + gb += fd.get("js_columns") + metrics = self.get_metrics() + gb = list(set(gb)) + if metrics: + d["metrics"] = metrics + return d + + def get_js_columns(self, d): + cols = self.form_data.get("js_columns") or [] + return {col: d.get(col) for col in cols} + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + # Processing spatial info + for key in self.spatial_control_keys: + df = self.process_spatial_data_obj(key, df) + + features = [] + for d in df.to_dict(orient="records"): + feature = self.get_properties(d) + extra_props = self.get_js_columns(d) + if extra_props: + feature["extraProps"] = extra_props + features.append(feature) + + return { + "features": features, + "mapboxApiKey": config["MAPBOX_API_KEY"], + "metricLabels": self.metric_labels, + } + + def get_properties(self, d): + raise NotImplementedError() + + +class DeckScatterViz(BaseDeckGLViz): + + """deck.gl's ScatterLayer""" + + viz_type = "deck_scatter" + verbose_name = _("Deck.gl - Scatter plot") + spatial_control_keys = ["spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + self.point_radius_fixed = fd.get("point_radius_fixed") or { + "type": "fix", + "value": 500, + } + return super().query_obj() + + def get_metrics(self): + self.metric = None + if self.point_radius_fixed.get("type") == "metric": + self.metric = self.point_radius_fixed.get("value") + return [self.metric] + return None + + def get_properties(self, d): + return { + "metric": d.get(self.metric_label), + "radius": self.fixed_value + if self.fixed_value + else d.get(self.metric_label), + "cat_color": d.get(self.dim) if self.dim else None, + "position": d.get("spatial"), + DTTM_ALIAS: d.get(DTTM_ALIAS), + } + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + self.metric_label = utils.get_metric_name(self.metric) if self.metric else None + self.point_radius_fixed = fd.get("point_radius_fixed") + self.fixed_value = None + self.dim = self.form_data.get("dimension") + if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric": + self.fixed_value = self.point_radius_fixed.get("value") + return super().get_data(df) + + +class DeckScreengrid(BaseDeckGLViz): + + """deck.gl's ScreenGridLayer""" + + viz_type = "deck_screengrid" + verbose_name = _("Deck.gl - Screen Grid") + spatial_control_keys = ["spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") + return super().query_obj() + + def get_properties(self, d): + return { + "position": d.get("spatial"), + "weight": d.get(self.metric_label) or 1, + "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"), + } + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +class DeckGrid(BaseDeckGLViz): + + """deck.gl's DeckLayer""" + + viz_type = "deck_grid" + verbose_name = _("Deck.gl - 3D Grid") + spatial_control_keys = ["spatial"] + + def get_properties(self, d): + return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} + + def get_data(self, df: 
pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +def geohash_to_json(geohash_code): + p = geohash.bbox(geohash_code) + return [ + [p.get("w"), p.get("n")], + [p.get("e"), p.get("n")], + [p.get("e"), p.get("s")], + [p.get("w"), p.get("s")], + [p.get("w"), p.get("n")], + ] + + +class DeckPathViz(BaseDeckGLViz): + + """deck.gl's PathLayer""" + + viz_type = "deck_path" + verbose_name = _("Deck.gl - Paths") + deck_viz_key = "path" + is_timeseries = True + deser_map = { + "json": json.loads, + "polyline": polyline.decode, + "geohash": geohash_to_json, + } + + def query_obj(self): + fd = self.form_data + self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") + d = super().query_obj() + self.metric = fd.get("metric") + if d["metrics"]: + self.has_metrics = True + else: + self.has_metrics = False + return d + + def get_properties(self, d): + fd = self.form_data + line_type = fd.get("line_type") + deser = self.deser_map[line_type] + line_column = fd.get("line_column") + path = deser(d[line_column]) + if fd.get("reverse_long_lat"): + path = [(o[1], o[0]) for o in path] + d[self.deck_viz_key] = path + if line_type != "geohash": + del d[line_column] + d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time") + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +class DeckPolygon(DeckPathViz): + + """deck.gl's Polygon Layer""" + + viz_type = "deck_polygon" + deck_viz_key = "polygon" + verbose_name = _("Deck.gl - Polygon") + + def query_obj(self): + fd = self.form_data + self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500} + return super().query_obj() + + def get_metrics(self): + metrics = [self.form_data.get("metric")] + if self.elevation.get("type") == "metric": + metrics.append(self.elevation.get("value")) + return [metric for metric in metrics if metric] + + def get_properties(self, d): + super().get_properties(d) + fd = self.form_data + elevation = fd["point_radius_fixed"]["value"] + type_ = fd["point_radius_fixed"]["type"] + d["elevation"] = ( + d.get(utils.get_metric_name(elevation)) if type_ == "metric" else elevation + ) + return d + + +class DeckHex(BaseDeckGLViz): + + """deck.gl's DeckLayer""" + + viz_type = "deck_hex" + verbose_name = _("Deck.gl - 3D HEX") + spatial_control_keys = ["spatial"] + + def get_properties(self, d): + return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super(DeckHex, self).get_data(df) + + +class DeckGeoJson(BaseDeckGLViz): + + """deck.gl's GeoJSONLayer""" + + viz_type = "deck_geojson" + verbose_name = _("Deck.gl - GeoJSON") + + def query_obj(self): + d = super().query_obj() + d["columns"] += [self.form_data.get("geojson")] + d["metrics"] = [] + return d + + def get_properties(self, d): + geojson = d.get(self.form_data.get("geojson")) + return json.loads(geojson) + + +class DeckArc(BaseDeckGLViz): + + """deck.gl's Arc Layer""" + + viz_type = "deck_arc" + verbose_name = _("Deck.gl - Arc") + spatial_control_keys = ["start_spatial", "end_spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + return super().query_obj() + + def get_properties(self, d): + dim = self.form_data.get("dimension") + return { + 
"sourcePosition": d.get("start_spatial"), + "targetPosition": d.get("end_spatial"), + "cat_color": d.get(dim) if dim else None, + DTTM_ALIAS: d.get(DTTM_ALIAS), + } + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + d = super().get_data(df) + + return { + "features": d["features"], # type: ignore + "mapboxApiKey": config["MAPBOX_API_KEY"], + } + + +class EventFlowViz(BaseViz): + + """A visualization to explore patterns in event sequences""" + + viz_type = "event_flow" + verbose_name = _("Event flow") + credits = 'from @data-ui' + is_timeseries = True + + def query_obj(self): + query = super().query_obj() + form_data = self.form_data + + event_key = form_data.get("all_columns_x") + entity_key = form_data.get("entity") + meta_keys = [ + col + for col in form_data.get("all_columns") + if col != event_key and col != entity_key + ] + + query["columns"] = [event_key, entity_key] + meta_keys + + if form_data["order_by_entity"]: + query["orderby"] = [(entity_key, True)] + + return query + + def get_data(self, df: pd.DataFrame) -> VizData: + return df.to_dict(orient="records") + + +class PairedTTestViz(BaseViz): + + """A table displaying paired t-test values""" + + viz_type = "paired_ttest" + verbose_name = _("Time Series - Paired t-test") + sort_series = False + is_timeseries = True + + def get_data(self, df: pd.DataFrame) -> VizData: + """ + Transform received data frame into an object of the form: + { + 'metric1': [ + { + groups: ('groupA', ... ), + values: [ {x, y}, ... ], + }, ... + ], ... + } + """ + + if df.empty: + return None + + fd = self.form_data + groups = fd.get("groupby") + metrics = self.metric_labels + df = df.pivot_table(index=DTTM_ALIAS, columns=groups, values=metrics) + cols = [] + # Be rid of falsey keys + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + data: Dict = {} + series = df.to_dict("series") + for nameSet in df.columns: + # If no groups are defined, nameSet will be the metric name + hasGroup = not isinstance(nameSet, str) + Y = series[nameSet] + d = { + "group": nameSet[1:] if hasGroup else "All", + "values": [{"x": t, "y": Y[t] if t in Y else None} for t in df.index], + } + key = nameSet[0] if hasGroup else nameSet + if key in data: + data[key].append(d) + else: + data[key] = [d] + return data + + +class RoseViz(NVD3TimeSeriesViz): + + viz_type = "rose" + verbose_name = _("Time Series - Nightingale Rose Chart") + sort_series = False + is_timeseries = True + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + data = super().get_data(df) + result: Dict = {} + for datum in data: # type: ignore + key = datum["key"] + for val in datum["values"]: + timestamp = val["x"].value + if not result.get(timestamp): + result[timestamp] = [] + value = 0 if math.isnan(val["y"]) else val["y"] + result[timestamp].append( + { + "key": key, + "value": value, + "name": ", ".join(key) if isinstance(key, list) else key, + "time": val["x"], + } + ) + return result + + +class PartitionViz(NVD3TimeSeriesViz): + + """ + A hierarchical data visualization with support for time series. 
+ """ + + viz_type = "partition" + verbose_name = _("Partition Diagram") + + def query_obj(self): + query_obj = super().query_obj() + time_op = self.form_data.get("time_series_option", "not_time") + # Return time series data if the user specifies so + query_obj["is_timeseries"] = time_op != "not_time" + return query_obj + + def levels_for(self, time_op, groups, df): + """ + Compute the partition at each `level` from the dataframe. + """ + levels = {} + for i in range(0, len(groups) + 1): + agg_df = df.groupby(groups[:i]) if i else df + levels[i] = ( + agg_df.mean() + if time_op == "agg_mean" + else agg_df.sum(numeric_only=True) + ) + return levels + + def levels_for_diff(self, time_op, groups, df): + # Obtain a unique list of the time grains + times = list(set(df[DTTM_ALIAS])) + times.sort() + until = times[len(times) - 1] + since = times[0] + # Function describing how to calculate the difference + func = { + "point_diff": [pd.Series.sub, lambda a, b, fill_value: a - b], + "point_factor": [pd.Series.div, lambda a, b, fill_value: a / float(b)], + "point_percent": [ + lambda a, b, fill_value=0: a.div(b, fill_value=fill_value) - 1, + lambda a, b, fill_value: a / float(b) - 1, + ], + }[time_op] + agg_df = df.groupby(DTTM_ALIAS).sum() + levels = { + 0: pd.Series( + { + m: func[1](agg_df[m][until], agg_df[m][since], 0) + for m in agg_df.columns + } + ) + } + for i in range(1, len(groups) + 1): + agg_df = df.groupby([DTTM_ALIAS] + groups[:i]).sum() + levels[i] = pd.DataFrame( + { + m: func[0](agg_df[m][until], agg_df[m][since], fill_value=0) + for m in agg_df.columns + } + ) + return levels + + def levels_for_time(self, groups, df): + procs = {} + for i in range(0, len(groups) + 1): + self.form_data["groupby"] = groups[:i] + df_drop = df.drop(groups[i:], 1) + procs[i] = self.process_data(df_drop, aggregate=True) + self.form_data["groupby"] = groups + return procs + + def nest_values(self, levels, level=0, metric=None, dims=()): + """ + Nest values at each level on the back-end with + access and setting, instead of summing from the bottom. 
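+
+        Level 0 nodes are the metrics themselves; each deeper level is the
+        next groupby dimension. A sketch of the structure returned for a
+        single metric "m" grouped by one dimension (hypothetical values):
+
+            [{"name": "m", "val": 100.0, "children": [
+                {"name": "a1", "val": 40.0, "children": []},
+                {"name": "a2", "val": 60.0, "children": []},
+            ]}]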
+ """ + if not level: + return [ + { + "name": m, + "val": levels[0][m], + "children": self.nest_values(levels, 1, m), + } + for m in levels[0].index + ] + if level == 1: + return [ + { + "name": i, + "val": levels[1][metric][i], + "children": self.nest_values(levels, 2, metric, (i,)), + } + for i in levels[1][metric].index + ] + if level >= len(levels): + return [] + return [ + { + "name": i, + "val": levels[level][metric][dims][i], + "children": self.nest_values(levels, level + 1, metric, dims + (i,)), + } + for i in levels[level][metric][dims].index + ] + + def nest_procs(self, procs, level=-1, dims=(), time=None): + if level == -1: + return [ + {"name": m, "children": self.nest_procs(procs, 0, (m,))} + for m in procs[0].columns + ] + if not level: + return [ + { + "name": t, + "val": procs[0][dims[0]][t], + "children": self.nest_procs(procs, 1, dims, t), + } + for t in procs[0].index + ] + if level >= len(procs): + return [] + return [ + { + "name": i, + "val": procs[level][dims][i][time], + "children": self.nest_procs(procs, level + 1, dims + (i,), time), + } + for i in procs[level][dims].columns + ] + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + groups = fd.get("groupby", []) + time_op = fd.get("time_series_option", "not_time") + if not len(groups): + raise ValueError("Please choose at least one groupby") + if time_op == "not_time": + levels = self.levels_for("agg_sum", groups, df) + elif time_op in ["agg_sum", "agg_mean"]: + levels = self.levels_for(time_op, groups, df) + elif time_op in ["point_diff", "point_factor", "point_percent"]: + levels = self.levels_for_diff(time_op, groups, df) + elif time_op == "adv_anal": + procs = self.levels_for_time(groups, df) + return self.nest_procs(procs) + else: + levels = self.levels_for("agg_sum", [DTTM_ALIAS] + groups, df) + return self.nest_values(levels) + + +viz_types = { + o.viz_type: o + for o in globals().values() + if ( + inspect.isclass(o) + and issubclass(o, BaseViz) + and o.viz_type not in config["VIZ_TYPE_BLACKLIST"] + ) +} From e233da1e0644b9a9bfcdb49c4b55aeca3ae563b8 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 6 Apr 2020 17:22:16 +0300 Subject: [PATCH 08/14] Break out further SIP-38 related tests --- superset/connectors/druid/models.py | 16 +- superset/viz_sip38.py | 41 +- tests/druid_func_tests.py | 32 +- tests/druid_func_tests_sip38.py | 1157 ++++++++++++++++++++++++++ tests/model_tests.py | 6 +- tests/security_tests.py | 2 + tests/sqla_models_tests.py | 4 +- tests/viz_tests.py | 114 ++- tests/viz_tests_sip38.py | 1175 +++++++++++++++++++++++++++ 9 files changed, 2480 insertions(+), 67 deletions(-) create mode 100644 tests/druid_func_tests_sip38.py create mode 100644 tests/viz_tests_sip38.py diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 4f36987a4d09..db1d674c8794 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1247,14 +1247,9 @@ def run_query( # druid del qry["dimensions"] client.timeseries(**qry) elif ( - IS_SIP_38 - and not having_filters - and len(columns) == 1 - and order_desc - or not IS_SIP_38 - and not having_filters - and len(groupby) == 1 + not having_filters and order_desc + and (IS_SIP_38 and len(columns) == 1 or not IS_SIP_38 and len(groupby) == 1) ): dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) @@ -1308,10 +1303,7 @@ def run_query( # druid logger.info("Phase 2 Complete") elif ( having_filters - or IS_SIP_38 - and 
columns
-            or not IS_SIP_38
-            and len(groupby) > 0
+            or (IS_SIP_38 and columns or not IS_SIP_38 and len(groupby) > 0)
         ):
             # If grouping on multiple fields or using a having filter
             # we have to force a groupby query
@@ -1398,7 +1390,7 @@ def run_query(  # druid
 
     @staticmethod
     def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame:
-        """Converting all GROUPBY columns to strings
+        """Converting all columns to strings
 
         When grouping by a numeric (say FLOAT) column, pydruid returns
         strings in the dataframe. This creates issues downstream related
diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py
index 9556aaa1550c..87950d66e367 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -85,6 +85,7 @@
     "columns",
     "dimension",
     "entity",
+    "geojson",
     "groupby",
     "series",
     "line_column",
@@ -125,14 +126,15 @@ def __init__(
         # merge all selectable columns into `columns` property
         self.columns: List[str] = []
         for key in COLUMN_FORM_DATA_PARAMS:
-            logger.warning(
-                f"The form field %s is deprecated. Viz plugins should "
-                f"pass all selectables via the columns field",
-                key,
-            )
             value = self.form_data.get(key) or []
             value_list = value if isinstance(value, list) else [value]
-            self.columns += value_list
+            if value_list:
+                logger.warning(
+                    "The form field %s is deprecated. Viz plugins should "
+                    "pass all selectables via the columns field",
+                    key,
+                )
+                self.columns += value_list
 
         for key in SPATIAL_COLUMN_FORM_DATA_PARAMS:
             spatial = self.form_data.get(key)
@@ -2279,27 +2281,6 @@ def add_null_filters(self):
             filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""})
             fd["adhoc_filters"].append(filter_)
 
-    def query_obj(self):
-        fd = self.form_data
-
-        # add NULL filters
-        if fd.get("filter_nulls", True):
-            self.add_null_filters()
-
-        d = super().query_obj()
-        gb = []
-
-        if fd.get("dimension"):
-            gb += [fd.get("dimension")]
-
-        if fd.get("js_columns"):
-            gb += fd.get("js_columns")
-        metrics = self.get_metrics()
-        gb = list(set(gb))
-        if metrics:
-            d["metrics"] = metrics
-        return d
-
     def get_js_columns(self, d):
         cols = self.form_data.get("js_columns") or []
         return {col: d.get(col) for col in cols}
@@ -2527,12 +2508,6 @@ class DeckGeoJson(BaseDeckGLViz):
     viz_type = "deck_geojson"
     verbose_name = _("Deck.gl - GeoJSON")
 
-    def query_obj(self):
-        d = super().query_obj()
-        d["columns"] += [self.form_data.get("geojson")]
-        d["metrics"] = []
-        return d
-
     def get_properties(self, d):
         geojson = d.get(self.form_data.get("geojson"))
         return json.loads(geojson)
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index 5895a1b386b0..699afeaec835 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -407,7 +407,7 @@ def test_run_query_no_groupby(self):
         aggs = []
         post_aggs = ["some_agg"]
         ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
-        columns = []
+        groupby = []
         metrics = ["metric1"]
         ds.get_having_filters = Mock(return_value=[])
         client.query_builder = Mock()
@@ -415,11 +415,11 @@
         client.query_builder.last_query.query_dict = {"mock": 0}
         # no groupby calls client.timeseries
         ds.run_query(
-            columns,
             metrics,
             None,
             from_dttm,
             to_dttm,
+            groupby=groupby,
             client=client,
             filter=[],
             row_limit=100,
@@ -456,7 +456,7 @@ def test_run_query_with_adhoc_metric(self):
         all_metrics = []
         post_aggs = ["some_agg"]
         ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
-        columns = []
+        groupby = []
         metrics = [
             {
                 "expressionType": "SIMPLE",
@@ -472,11 +472,11 @@ def 
test_run_query_with_adhoc_metric(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - columns, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, filter=[], row_limit=100, @@ -513,17 +513,17 @@ def test_run_query_single_groupby(self): aggs = ["metric1"] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - columns = ["col1"] + groupby = ["col1"] metrics = ["metric1"] ds.get_having_filters = Mock(return_value=[]) client.query_builder.last_query.query_dict = {"mock": 0} # client.topn is called twice ds.run_query( - columns, metrics, None, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=100, client=client, order_desc=True, @@ -543,11 +543,11 @@ def test_run_query_single_groupby(self): client = Mock() client.query_builder.last_query.query_dict = {"mock": 0} ds.run_query( - columns, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, order_desc=False, filter=[], @@ -568,11 +568,11 @@ def test_run_query_single_groupby(self): client = Mock() client.query_builder.last_query.query_dict = {"mock": 0} ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, order_desc=True, timeseries_limit=5, @@ -611,7 +611,7 @@ def test_run_query_multiple_groupby(self): aggs = [] post_aggs = ["some_agg"] ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - columns = ["col1", "col2"] + groupby = ["col1", "col2"] metrics = ["metric1"] ds.get_having_filters = Mock(return_value=[]) client.query_builder = Mock() @@ -619,11 +619,11 @@ def test_run_query_multiple_groupby(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - columns, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, row_limit=100, filter=[], @@ -1016,16 +1016,16 @@ def test_run_query_order_by_metrics(self): ds.columns = [dim1, dim2] ds.metrics = list(metrics_dict.values()) - columns = ["dim1"] + groupby = ["dim1"] metrics = ["count1"] granularity = "all" # get the counts of the top 5 'dim1's, order by 'sum1' ds.run_query( - columns, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="sum1", client=client, @@ -1042,11 +1042,11 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 'dim1's, order by 'div1' ds.run_query( - columns, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="div1", client=client, @@ -1061,14 +1061,14 @@ def test_run_query_order_by_metrics(self): self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) self.assertEqual({"div1"}, set(post_aggregations.keys())) - columns = ["dim1", "dim2"] + groupby = ["dim1", "dim2"] # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' ds.run_query( - columns, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="sum1", client=client, @@ -1085,11 +1085,11 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' ds.run_query( - columns, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="div1", client=client, diff --git a/tests/druid_func_tests_sip38.py b/tests/druid_func_tests_sip38.py new file mode 100644 index 000000000000..058d8c1743bd --- /dev/null +++ b/tests/druid_func_tests_sip38.py @@ -0,0 +1,1157 @@ +# Licensed 
to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +import json +import unittest +from unittest.mock import Mock, patch + +import tests.test_app +import superset.connectors.druid.models as models +from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric +from superset.exceptions import SupersetException + +from .base_tests import SupersetTestCase + +try: + from pydruid.utils.dimensions import ( + MapLookupExtraction, + RegexExtraction, + RegisteredLookupExtraction, + TimeFormatExtraction, + ) + import pydruid.utils.postaggregator as postaggs +except ImportError: + pass + + +def mock_metric(metric_name, is_postagg=False): + metric = Mock() + metric.metric_name = metric_name + metric.metric_type = "postagg" if is_postagg else "metric" + return metric + + +def emplace(metrics_dict, metric_name, is_postagg=False): + metrics_dict[metric_name] = mock_metric(metric_name, is_postagg) + + +# Unit tests that can be run without initializing base tests +@patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"SIP_38_VIZ_REARCHITECTURE": True}, + clear=True, +) +class DruidFuncTestCase(SupersetTestCase): + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_map(self): + filters = [{"col": "deviceName", "val": ["iPhone X"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "device", + "outputName": "deviceName", + "outputType": "STRING", + "extractionFn": { + "type": "lookup", + "dimension": "dimensionName", + "outputName": "dimensionOutputName", + "replaceMissingValueWith": "missing_value", + "retainMissingValue": False, + "lookup": { + "type": "map", + "map": { + "iPhone10,1": "iPhone 8", + "iPhone10,4": "iPhone 8", + "iPhone10,2": "iPhone 8 Plus", + "iPhone10,5": "iPhone 8 Plus", + "iPhone10,3": "iPhone X", + "iPhone10,6": "iPhone X", + }, + "isOneToOne": False, + }, + }, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="deviceName", dimension_spec_json=spec_json) + column_dict = {"deviceName": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, MapLookupExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + f_ext_fn = f.extraction_function + self.assertEqual(dim_ext_fn["lookup"]["map"], f_ext_fn._mapping) + self.assertEqual(dim_ext_fn["lookup"]["isOneToOne"], f_ext_fn._injective) + self.assertEqual( + dim_ext_fn["replaceMissingValueWith"], f_ext_fn._replace_missing_values + ) + self.assertEqual( + dim_ext_fn["retainMissingValue"], f_ext_fn._retain_missing_values + ) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_regex(self): 
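+        # The dimension spec below extracts the first three alphanumeric
+        # characters of the raw "build" column into "buildPrefix", so an
+        # illustrative "22B82" would surface as "22B"; the assertions only
+        # check that the expression is passed through to RegexExtraction.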
+ filters = [{"col": "buildPrefix", "val": ["22B"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "build", + "outputName": "buildPrefix", + "outputType": "STRING", + "extractionFn": {"type": "regex", "expr": "(^[0-9A-Za-z]{3})"}, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="buildPrefix", dimension_spec_json=spec_json) + column_dict = {"buildPrefix": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, RegexExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + f_ext_fn = f.extraction_function + self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_registered_lookup_extraction(self): + filters = [{"col": "country", "val": ["Spain"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "country_name", + "outputName": "country", + "outputType": "STRING", + "extractionFn": {"type": "registeredLookup", "lookup": "country_name"}, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="country", dimension_spec_json=spec_json) + column_dict = {"country": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, RegisteredLookupExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) + self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_time_format(self): + filters = [{"col": "dayOfMonth", "val": ["1", "20"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "__time", + "outputName": "dayOfMonth", + "extractionFn": { + "type": "timeFormat", + "format": "d", + "timeZone": "Asia/Kolkata", + "locale": "en", + }, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="dayOfMonth", dimension_spec_json=spec_json) + column_dict = {"dayOfMonth": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, TimeFormatExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) + self.assertEqual(dim_ext_fn["format"], f.extraction_function._format) + self.assertEqual(dim_ext_fn["timeZone"], f.extraction_function._time_zone) + self.assertEqual(dim_ext_fn["locale"], f.extraction_function._locale) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_ignores_invalid_filter_objects(self): + filtr = {"col": "col1", "op": "=="} + filters = [filtr] + col = DruidColumn(column_name="col1") + column_dict = {"col1": col} + self.assertIsNone(DruidDatasource.get_filters(filters, [], column_dict)) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_in(self): + filtr = {"col": "A", "op": "in", "val": ["a", "b", "c"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIn("filter", res.filter) + self.assertIn("fields", res.filter["filter"]) + self.assertEqual("or", res.filter["filter"]["type"]) + self.assertEqual(3, 
len(res.filter["filter"]["fields"])) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_not_in(self): + filtr = {"col": "A", "op": "not in", "val": ["a", "b", "c"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIn("filter", res.filter) + self.assertIn("type", res.filter["filter"]) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertIn("field", res.filter["filter"]) + self.assertEqual( + 3, len(res.filter["filter"]["field"].filter["filter"]["fields"]) + ) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_equals(self): + filtr = {"col": "A", "op": "==", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + self.assertEqual("h", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_not_equals(self): + filtr = {"col": "A", "op": "!=", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertEqual("h", res.filter["filter"]["field"].filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_bounds_filter(self): + filtr = {"col": "A", "op": ">=", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertFalse(res.filter["filter"]["lowerStrict"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + self.assertEqual("h", res.filter["filter"]["lower"]) + self.assertEqual("lexicographic", res.filter["filter"]["ordering"]) + filtr["op"] = ">" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertTrue(res.filter["filter"]["lowerStrict"]) + filtr["op"] = "<=" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertFalse(res.filter["filter"]["upperStrict"]) + self.assertEqual("h", res.filter["filter"]["upper"]) + filtr["op"] = "<" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertTrue(res.filter["filter"]["upperStrict"]) + filtr["val"] = 1 + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual("numeric", res.filter["filter"]["ordering"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_is_null_filter(self): + filtr = {"col": "A", "op": "IS NULL"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + self.assertEqual("", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_is_not_null_filter(self): + filtr = {"col": "A", "op": "IS NOT NULL"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = 
DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertIn("field", res.filter["filter"]) + self.assertEqual( + "selector", res.filter["filter"]["field"].filter["filter"]["type"] + ) + self.assertEqual("", res.filter["filter"]["field"].filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_regex_filter(self): + filtr = {"col": "A", "op": "regex", "val": "[abc]"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("regex", res.filter["filter"]["type"]) + self.assertEqual("[abc]", res.filter["filter"]["pattern"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_composes_multiple_filters(self): + filtr1 = {"col": "A", "op": "!=", "val": "y"} + filtr2 = {"col": "B", "op": "in", "val": ["a", "b", "c"]} + cola = DruidColumn(column_name="A") + colb = DruidColumn(column_name="B") + column_dict = {"A": cola, "B": colb} + res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) + self.assertEqual("and", res.filter["filter"]["type"]) + self.assertEqual(2, len(res.filter["filter"]["fields"])) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_ignores_in_not_in_with_empty_value(self): + filtr1 = {"col": "A", "op": "in", "val": []} + filtr2 = {"col": "A", "op": "not in", "val": []} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) + self.assertIsNone(res) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_equals_for_in_not_in_single_value(self): + filtr = {"col": "A", "op": "in", "val": ["a"]} + cola = DruidColumn(column_name="A") + colb = DruidColumn(column_name="B") + column_dict = {"A": cola, "B": colb} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_handles_arrays_for_string_types(self): + filtr = {"col": "A", "op": "==", "val": ["a", "b"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a", res.filter["filter"]["value"]) + + filtr = {"col": "A", "op": "==", "val": []} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIsNone(res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_handles_none_for_string_types(self): + filtr = {"col": "A", "op": "==", "val": None} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIsNone(res) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extracts_values_in_quotes(self): + filtr = {"col": "A", "op": "in", "val": ['"a"']} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = 
DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_keeps_trailing_spaces(self): + filtr = {"col": "A", "op": "in", "val": ["a "]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a ", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_converts_strings_to_num(self): + filtr = {"col": "A", "op": "in", "val": ["6"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual(6, res.filter["filter"]["value"]) + filtr = {"col": "A", "op": "==", "val": "6"} + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual(6, res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_no_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = [] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(1, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.timeseries.call_args_list[0][1] + self.assertNotIn("dimensions", called_args) + self.assertIn("post_aggregations", called_args) + # restore functions + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_with_adhoc_metric(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + all_metrics = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = 
Mock(return_value=(all_metrics, post_aggs)) + columns = [] + metrics = [ + { + "expressionType": "SIMPLE", + "column": {"type": "DOUBLE", "column_name": "col1"}, + "aggregate": "SUM", + "label": "My Adhoc Metric", + } + ] + + ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(1, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.timeseries.call_args_list[0][1] + self.assertNotIn("dimensions", called_args) + self.assertIn("post_aggregations", called_args) + # restore functions + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_single_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = ["metric1"] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = ["col1"] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder.last_query.query_dict = {"mock": 0} + # client.topn is called twice + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=100, + client=client, + order_desc=True, + filter=[], + ) + self.assertEqual(2, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args_pre = client.topn.call_args_list[0][1] + self.assertNotIn("dimensions", called_args_pre) + self.assertIn("dimension", called_args_pre) + called_args = client.topn.call_args_list[1][1] + self.assertIn("dimension", called_args) + self.assertEqual("col1", called_args["dimension"]) + # not order_desc + client = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + order_desc=False, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(1, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + self.assertIn("dimensions", client.groupby.call_args_list[0][1]) + self.assertEqual(["col1"], client.groupby.call_args_list[0][1]["dimensions"]) + # order_desc but timeseries and dimension spec + # calls topn with single dimension spec 'dimension' + spec = {"outputName": "hello", "dimension": "matcho"} + spec_json = json.dumps(spec) + col3 = DruidColumn(column_name="col3", dimension_spec_json=spec_json) + ds.columns.append(col3) + 
groupby = ["col3"] + client = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=groupby, + client=client, + order_desc=True, + timeseries_limit=5, + filter=[], + row_limit=100, + ) + self.assertEqual(2, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + self.assertIn("dimension", client.topn.call_args_list[0][1]) + self.assertIn("dimension", client.topn.call_args_list[1][1]) + # uses dimension for pre query and full spec for final query + self.assertEqual("matcho", client.topn.call_args_list[0][1]["dimension"]) + self.assertEqual(spec, client.topn.call_args_list[1][1]["dimension"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_multiple_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = ["col1", "col2"] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + row_limit=100, + filter=[], + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(1, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.groupby.call_args_list[0][1] + self.assertIn("dimensions", called_args) + self.assertEqual(["col1", "col2"], called_args["dimensions"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_post_agg_returns_correct_agg_type(self): + get_post_agg = DruidDatasource.get_post_agg + # javascript PostAggregators + function = "function(field1, field2) { return field1 + field2; }" + conf = { + "type": "javascript", + "name": "postagg_name", + "fieldNames": ["field1", "field2"], + "function": function, + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["type"], "javascript") + self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"]) + self.assertEqual(postagg.post_aggregator["name"], "postagg_name") + self.assertEqual(postagg.post_aggregator["function"], function) + # Quantile + conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Quantile)) + self.assertEqual(postagg.name, "postagg_name") + 
self.assertEqual(postagg.post_aggregator["probability"], "0.5") + # Quantiles + conf = { + "type": "quantiles", + "name": "postagg_name", + "probabilities": "0.4,0.5,0.6", + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Quantiles)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["probabilities"], "0.4,0.5,0.6") + # FieldAccess + conf = {"type": "fieldAccess", "name": "field_name"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Field)) + self.assertEqual(postagg.name, "field_name") + # constant + conf = {"type": "constant", "value": 1234, "name": "postagg_name"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Const)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["value"], 1234) + # hyperUniqueCardinality + conf = {"type": "hyperUniqueCardinality", "name": "unique_name"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality)) + self.assertEqual(postagg.name, "unique_name") + # arithmetic + conf = { + "type": "arithmetic", + "fn": "+", + "fields": ["field1", "field2"], + "name": "postagg_name", + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Postaggregator)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["fn"], "+") + self.assertEqual(postagg.post_aggregator["fields"], ["field1", "field2"]) + # custom post aggregator + conf = {"type": "custom", "name": "custom_name", "stuff": "more_stuff"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, models.CustomPostAggregator)) + self.assertEqual(postagg.name, "custom_name") + self.assertEqual(postagg.post_aggregator["stuff"], "more_stuff") + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_find_postaggs_for_returns_postaggs_and_removes(self): + find_postaggs_for = DruidDatasource.find_postaggs_for + postagg_names = set(["pa2", "pa3", "pa4", "m1", "m2", "m3", "m4"]) + + metrics = {} + for i in range(1, 6): + emplace(metrics, "pa" + str(i), True) + emplace(metrics, "m" + str(i), False) + postagg_list = find_postaggs_for(postagg_names, metrics) + self.assertEqual(3, len(postagg_list)) + self.assertEqual(4, len(postagg_names)) + expected_metrics = ["m1", "m2", "m3", "m4"] + expected_postaggs = set(["pa2", "pa3", "pa4"]) + for postagg in postagg_list: + expected_postaggs.remove(postagg.metric_name) + for metric in expected_metrics: + postagg_names.remove(metric) + self.assertEqual(0, len(expected_postaggs)) + self.assertEqual(0, len(postagg_names)) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_recursive_get_fields(self): + conf = { + "type": "quantile", + "fieldName": "f1", + "field": { + "type": "custom", + "fields": [ + {"type": "fieldAccess", "fieldName": "f2"}, + {"type": "fieldAccess", "fieldName": "f3"}, + { + "type": "quantiles", + "fieldName": "f4", + "field": {"type": "custom"}, + }, + { + "type": "custom", + "fields": [ + {"type": "fieldAccess", "fieldName": "f5"}, + { + "type": "fieldAccess", + "fieldName": "f2", + "fields": [ + {"type": "fieldAccess", "fieldName": "f3"}, + {"type": "fieldIgnoreMe", "fieldName": "f6"}, + ], + }, + ], + }, + ], + }, + } + fields = DruidDatasource.recursive_get_fields(conf) + expected = set(["f1", "f2", "f3", "f4", "f5"]) + self.assertEqual(5, len(fields)) + 
for field in fields: + expected.remove(field) + self.assertEqual(0, len(expected)) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_metrics_and_post_aggs_tree(self): + metrics = ["A", "B", "m1", "m2"] + metrics_dict = {} + for i in range(ord("A"), ord("K") + 1): + emplace(metrics_dict, chr(i), True) + for i in range(1, 10): + emplace(metrics_dict, "m" + str(i), False) + + def depends_on(index, fields): + dependents = fields if isinstance(fields, list) else [fields] + metrics_dict[index].json_obj = {"fieldNames": dependents} + + depends_on("A", ["m1", "D", "C"]) + depends_on("B", ["B", "C", "E", "F", "m3"]) + depends_on("C", ["H", "I"]) + depends_on("D", ["m2", "m5", "G", "C"]) + depends_on("E", ["H", "I", "J"]) + depends_on("F", ["J", "m5"]) + depends_on("G", ["m4", "m7", "m6", "A"]) + depends_on("H", ["A", "m4", "I"]) + depends_on("I", ["H", "K"]) + depends_on("J", "K") + depends_on("K", ["m8", "m9"]) + aggs, postaggs = DruidDatasource.metrics_and_post_aggs(metrics, metrics_dict) + expected_metrics = set(aggs.keys()) + self.assertEqual(9, len(aggs)) + for i in range(1, 10): + expected_metrics.remove("m" + str(i)) + self.assertEqual(0, len(expected_metrics)) + self.assertEqual(11, len(postaggs)) + for i in range(ord("A"), ord("K") + 1): + del postaggs[chr(i)] + self.assertEqual(0, len(postaggs)) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_metrics_and_post_aggs(self): + """ + Test generation of metrics and post-aggregations from an initial list + of superset metrics (which may include the results of either). This + primarily tests that specifying a post-aggregator metric will also + require the raw aggregation of the associated druid metric column. 
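+
+        For example, requesting only the post-aggregated "quantile_p95"
+        metric should also pull in the raw "a_histogram" aggregation it is
+        computed from, while "quantile_p95" itself lands in the resulting
+        post-aggregations.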
+ """ + metrics_dict = { + "unused_count": DruidMetric( + metric_name="unused_count", + verbose_name="COUNT(*)", + metric_type="count", + json=json.dumps({"type": "count", "name": "unused_count"}), + ), + "some_sum": DruidMetric( + metric_name="some_sum", + verbose_name="SUM(*)", + metric_type="sum", + json=json.dumps({"type": "sum", "name": "sum"}), + ), + "a_histogram": DruidMetric( + metric_name="a_histogram", + verbose_name="APPROXIMATE_HISTOGRAM(*)", + metric_type="approxHistogramFold", + json=json.dumps({"type": "approxHistogramFold", "name": "a_histogram"}), + ), + "aCustomMetric": DruidMetric( + metric_name="aCustomMetric", + verbose_name="MY_AWESOME_METRIC(*)", + metric_type="aCustomType", + json=json.dumps({"type": "customMetric", "name": "aCustomMetric"}), + ), + "quantile_p95": DruidMetric( + metric_name="quantile_p95", + verbose_name="P95(*)", + metric_type="postagg", + json=json.dumps( + { + "type": "quantile", + "probability": 0.95, + "name": "p95", + "fieldName": "a_histogram", + } + ), + ), + "aCustomPostAgg": DruidMetric( + metric_name="aCustomPostAgg", + verbose_name="CUSTOM_POST_AGG(*)", + metric_type="postagg", + json=json.dumps( + { + "type": "customPostAgg", + "name": "aCustomPostAgg", + "field": {"type": "fieldAccess", "fieldName": "aCustomMetric"}, + } + ), + ), + } + + adhoc_metric = { + "expressionType": "SIMPLE", + "column": {"type": "DOUBLE", "column_name": "value"}, + "aggregate": "SUM", + "label": "My Adhoc Metric", + } + + metrics = ["some_sum"] + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict + ) + + assert set(saved_metrics.keys()) == {"some_sum"} + assert post_aggs == {} + + metrics = [adhoc_metric] + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict + ) + + assert set(saved_metrics.keys()) == set([adhoc_metric["label"]]) + assert post_aggs == {} + + metrics = ["some_sum", adhoc_metric] + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict + ) + + assert set(saved_metrics.keys()) == {"some_sum", adhoc_metric["label"]} + assert post_aggs == {} + + metrics = ["quantile_p95"] + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict + ) + + result_postaggs = set(["quantile_p95"]) + assert set(saved_metrics.keys()) == {"a_histogram"} + assert set(post_aggs.keys()) == result_postaggs + + metrics = ["aCustomPostAgg"] + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + metrics, metrics_dict + ) + + result_postaggs = set(["aCustomPostAgg"]) + assert set(saved_metrics.keys()) == {"aCustomMetric"} + assert set(post_aggs.keys()) == result_postaggs + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_druid_type_from_adhoc_metric(self): + + druid_type = DruidDatasource.druid_type_from_adhoc_metric( + { + "column": {"type": "DOUBLE", "column_name": "value"}, + "aggregate": "SUM", + "label": "My Adhoc Metric", + } + ) + assert druid_type == "doubleSum" + + druid_type = DruidDatasource.druid_type_from_adhoc_metric( + { + "column": {"type": "LONG", "column_name": "value"}, + "aggregate": "MAX", + "label": "My Adhoc Metric", + } + ) + assert druid_type == "longMax" + + druid_type = DruidDatasource.druid_type_from_adhoc_metric( + { + "column": {"type": "VARCHAR(255)", "column_name": "value"}, + "aggregate": "COUNT", + "label": "My Adhoc Metric", + } + ) + assert druid_type == "count" + + druid_type = 
DruidDatasource.druid_type_from_adhoc_metric( + { + "column": {"type": "VARCHAR(255)", "column_name": "value"}, + "aggregate": "COUNT_DISTINCT", + "label": "My Adhoc Metric", + } + ) + assert druid_type == "cardinality" + + druid_type = DruidDatasource.druid_type_from_adhoc_metric( + { + "column": {"type": "hyperUnique", "column_name": "value"}, + "aggregate": "COUNT_DISTINCT", + "label": "My Adhoc Metric", + } + ) + assert druid_type == "hyperUnique" + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_order_by_metrics(self): + client = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + from_dttm = Mock() + to_dttm = Mock() + ds = DruidDatasource(datasource_name="datasource") + ds.get_having_filters = Mock(return_value=[]) + dim1 = DruidColumn(column_name="dim1") + dim2 = DruidColumn(column_name="dim2") + metrics_dict = { + "count1": DruidMetric( + metric_name="count1", + metric_type="count", + json=json.dumps({"type": "count", "name": "count1"}), + ), + "sum1": DruidMetric( + metric_name="sum1", + metric_type="doubleSum", + json=json.dumps({"type": "doubleSum", "name": "sum1"}), + ), + "sum2": DruidMetric( + metric_name="sum2", + metric_type="doubleSum", + json=json.dumps({"type": "doubleSum", "name": "sum2"}), + ), + "div1": DruidMetric( + metric_name="div1", + metric_type="postagg", + json=json.dumps( + { + "fn": "/", + "type": "arithmetic", + "name": "div1", + "fields": [ + {"fieldName": "sum1", "type": "fieldAccess"}, + {"fieldName": "sum2", "type": "fieldAccess"}, + ], + } + ), + ), + } + ds.columns = [dim1, dim2] + ds.metrics = list(metrics_dict.values()) + + columns = ["dim1"] + metrics = ["count1"] + granularity = "all" + # get the counts of the top 5 'dim1's, order by 'sum1' + ds.run_query( + metrics, + granularity, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=5, + timeseries_limit_metric="sum1", + client=client, + order_desc=True, + filter=[], + ) + qry_obj = client.topn.call_args_list[0][1] + self.assertEqual("dim1", qry_obj["dimension"]) + self.assertEqual("sum1", qry_obj["metric"]) + aggregations = qry_obj["aggregations"] + post_aggregations = qry_obj["post_aggregations"] + self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) + self.assertEqual(set(), set(post_aggregations.keys())) + + # get the counts of the top 5 'dim1's, order by 'div1' + ds.run_query( + metrics, + granularity, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=5, + timeseries_limit_metric="div1", + client=client, + order_desc=True, + filter=[], + ) + qry_obj = client.topn.call_args_list[1][1] + self.assertEqual("dim1", qry_obj["dimension"]) + self.assertEqual("div1", qry_obj["metric"]) + aggregations = qry_obj["aggregations"] + post_aggregations = qry_obj["post_aggregations"] + self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) + self.assertEqual({"div1"}, set(post_aggregations.keys())) + + columns = ["dim1", "dim2"] + # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' + ds.run_query( + metrics, + granularity, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=5, + timeseries_limit_metric="sum1", + client=client, + order_desc=True, + filter=[], + ) + qry_obj = client.groupby.call_args_list[0][1] + self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) + self.assertEqual("sum1", qry_obj["limit_spec"]["columns"][0]["dimension"]) + aggregations = qry_obj["aggregations"] + post_aggregations = qry_obj["post_aggregations"] + 
self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) + self.assertEqual(set(), set(post_aggregations.keys())) + + # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' + ds.run_query( + metrics, + granularity, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=5, + timeseries_limit_metric="div1", + client=client, + order_desc=True, + filter=[], + ) + qry_obj = client.groupby.call_args_list[1][1] + self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) + self.assertEqual("div1", qry_obj["limit_spec"]["columns"][0]["dimension"]) + aggregations = qry_obj["aggregations"] + post_aggregations = qry_obj["post_aggregations"] + self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) + self.assertEqual({"div1"}, set(post_aggregations.keys())) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_aggregations(self): + ds = DruidDatasource(datasource_name="datasource") + metrics_dict = { + "sum1": DruidMetric( + metric_name="sum1", + metric_type="doubleSum", + json=json.dumps({"type": "doubleSum", "name": "sum1"}), + ), + "sum2": DruidMetric( + metric_name="sum2", + metric_type="doubleSum", + json=json.dumps({"type": "doubleSum", "name": "sum2"}), + ), + "div1": DruidMetric( + metric_name="div1", + metric_type="postagg", + json=json.dumps( + { + "fn": "/", + "type": "arithmetic", + "name": "div1", + "fields": [ + {"fieldName": "sum1", "type": "fieldAccess"}, + {"fieldName": "sum2", "type": "fieldAccess"}, + ], + } + ), + ), + } + metric_names = ["sum1", "sum2"] + aggs = ds.get_aggregations(metrics_dict, metric_names) + expected_agg = {name: metrics_dict[name].json_obj for name in metric_names} + self.assertEqual(expected_agg, aggs) + + metric_names = ["sum1", "col1"] + self.assertRaises( + SupersetException, ds.get_aggregations, metrics_dict, metric_names + ) + + metric_names = ["sum1", "div1"] + self.assertRaises( + SupersetException, ds.get_aggregations, metrics_dict, metric_names + ) diff --git a/tests/model_tests.py b/tests/model_tests.py index 52378eb2bfe1..4a203a96475d 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -234,10 +234,11 @@ def query_with_expr_helper(self, is_timeseries, inner_join=True): label="arbitrary", expressionType="SQL", sqlExpression="COUNT(1)" ) query_obj = dict( - columns=[arbitrary_gby, "name"], + groupby=[arbitrary_gby, "name"], metrics=[arbitrary_metric], filter=[], is_timeseries=is_timeseries, + columns=[], granularity="ds", from_dttm=None, to_dttm=None, @@ -277,10 +278,11 @@ def test_query_with_expr_groupby(self): def test_sql_mutator(self): tbl = self.get_table_by_name("birth_names") query_obj = dict( - columns=["name"], + groupby=[], metrics=[], filter=[], is_timeseries=False, + columns=["name"], granularity=None, from_dttm=None, to_dttm=None, diff --git a/tests/security_tests.py b/tests/security_tests.py index 183a059d97ac..7b8df262fea4 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -853,6 +853,7 @@ def test_rls_filter_alters_query(self): ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("birth_names") query_obj = dict( + groupby=[], metrics=[], filter=[], is_timeseries=False, @@ -871,6 +872,7 @@ def test_rls_filter_doesnt_alter_query(self): ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("birth_names") query_obj = dict( + groupby=[], metrics=[], filter=[], is_timeseries=False, diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 
85296d09fe85..700c2e2bb94a 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -79,7 +79,7 @@ def test_has_extra_cache_keys(self): "granularity": None, "from_dttm": None, "to_dttm": None, - "columns": ["user"], + "groupby": ["user"], "metrics": [], "is_timeseries": False, "filter": [], @@ -100,7 +100,7 @@ def test_has_no_extra_cache_keys(self): "granularity": None, "from_dttm": None, "to_dttm": None, - "columns": ["user"], + "groupby": ["user"], "metrics": [], "is_timeseries": False, "filter": [], diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 3f48696f6e8f..ab10f4970665 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -373,15 +373,21 @@ def test_query_obj_throws_columns_and_metrics(self, super_query_obj): with self.assertRaises(Exception): test_viz.query_obj() - def test_query_obj_merges_all_columns(self): + @patch("superset.viz.BaseViz.query_obj") + def test_query_obj_merges_all_columns(self, super_query_obj): datasource = self.get_datasource_mock() form_data = { "all_columns": ["colA", "colB", "colC"], "order_by_cols": ['["colA", "colB"]', '["colC"]'], } + super_query_obj.return_value = { + "columns": ["colD", "colC"], + "groupby": ["colA", "colB"], + } test_viz = viz.TableViz(datasource, form_data) query_obj = test_viz.query_obj() self.assertEqual(form_data["all_columns"], query_obj["columns"]) + self.assertEqual([], query_obj["groupby"]) self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"]) @patch("superset.viz.BaseViz.query_obj") @@ -942,6 +948,18 @@ def test_query_obj_throws_metrics_and_groupby(self, super_query_obj): class BaseDeckGLVizTestCase(SupersetTestCase): + def test_get_metrics(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + result = test_viz_deckgl.get_metrics() + assert result == [form_data.get("size")] + + form_data = {} + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + result = test_viz_deckgl.get_metrics() + assert result == [] + def test_scatterviz_get_metrics(self): form_data = load_fixture("deck_path_form_data.json") datasource = self.get_datasource_mock() @@ -978,6 +996,36 @@ def test_get_properties(self): self.assertTrue("" in str(context.exception)) + def test_process_spatial_query_obj(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + mock_key = "spatial_key" + mock_gb = [] + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + + with self.assertRaises(ValueError) as context: + test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb) + + self.assertTrue("Bad spatial key" in str(context.exception)) + + test_form_data = { + "latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"}, + "delimited_key": {"type": "delimited", "lonlatCol": "lonlat"}, + "geohash_key": {"type": "geohash", "geohashCol": "geo"}, + } + + datasource = self.get_datasource_mock() + expected_results = { + "latlong_key": ["lon", "lat"], + "delimited_key": ["lonlat"], + "geohash_key": ["geo"], + } + for mock_key in ["latlong_key", "delimited_key", "geohash_key"]: + mock_gb = [] + test_viz_deckgl = viz.BaseDeckGLViz(datasource, test_form_data) + test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb) + assert expected_results.get(mock_key) == mock_gb + def test_geojson_query_obj(self): form_data = load_fixture("deck_geojson_form_data.json") datasource = self.get_datasource_mock() @@ -985,7 +1033,8 @@ def 
test_geojson_query_obj(self): results = test_viz_deckgl.query_obj() assert results["metrics"] == [] - assert results["columns"] == ["color", "test_col"] + assert results["groupby"] == [] + assert results["columns"] == ["test_col"] def test_parse_coordinates(self): form_data = load_fixture("deck_path_form_data.json") @@ -1013,6 +1062,67 @@ def test_parse_coordinates_raises(self): with self.assertRaises(SpatialException): test_viz_deckgl.parse_coordinates("fldkjsalkj,fdlaskjfjadlksj") + @patch("superset.utils.core.uuid.uuid4") + def test_filter_nulls(self, mock_uuid4): + mock_uuid4.return_value = uuid.UUID("12345678123456781234567812345678") + test_form_data = { + "latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"}, + "delimited_key": {"type": "delimited", "lonlatCol": "lonlat"}, + "geohash_key": {"type": "geohash", "geohashCol": "geo"}, + } + + datasource = self.get_datasource_mock() + expected_results = { + "latlong_key": [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "filterOptionName": "12345678-1234-5678-1234-567812345678", + "comparator": "", + "operator": "IS NOT NULL", + "subject": "lat", + "isExtra": False, + }, + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "filterOptionName": "12345678-1234-5678-1234-567812345678", + "comparator": "", + "operator": "IS NOT NULL", + "subject": "lon", + "isExtra": False, + }, + ], + "delimited_key": [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "filterOptionName": "12345678-1234-5678-1234-567812345678", + "comparator": "", + "operator": "IS NOT NULL", + "subject": "lonlat", + "isExtra": False, + } + ], + "geohash_key": [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "filterOptionName": "12345678-1234-5678-1234-567812345678", + "comparator": "", + "operator": "IS NOT NULL", + "subject": "geo", + "isExtra": False, + } + ], + } + for mock_key in ["latlong_key", "delimited_key", "geohash_key"]: + test_viz_deckgl = viz.BaseDeckGLViz(datasource, test_form_data.copy()) + test_viz_deckgl.spatial_control_keys = [mock_key] + test_viz_deckgl.add_null_filters() + adhoc_filters = test_viz_deckgl.form_data["adhoc_filters"] + assert expected_results.get(mock_key) == adhoc_filters + class TimeSeriesVizTestCase(SupersetTestCase): def test_timeseries_unicode_data(self): diff --git a/tests/viz_tests_sip38.py b/tests/viz_tests_sip38.py new file mode 100644 index 000000000000..4f1e5bbf719b --- /dev/null +++ b/tests/viz_tests_sip38.py @@ -0,0 +1,1175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
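+"""Unit tests for superset.viz_sip38 behind the SIP_38_VIZ_REARCHITECTURE flag."""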
+# isort:skip_file +from datetime import datetime +import logging +from math import nan +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd + +import tests.test_app +import superset.viz_sip38 as viz +from superset import app +from superset.constants import NULL_STRING +from superset.exceptions import SpatialException +from superset.utils.core import DTTM_ALIAS + +from .base_tests import SupersetTestCase +from .utils import load_fixture + +logger = logging.getLogger(__name__) + + +@patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"SIP_38_VIZ_REARCHITECTURE": True}, + clear=True, +) +class Sip38TestCase(SupersetTestCase): + pass + + +class BaseVizTestCase(Sip38TestCase): + def test_constructor_exception_no_datasource(self): + form_data = {} + datasource = None + with self.assertRaises(Exception): + viz.BaseViz(datasource, form_data) + + def test_process_metrics(self): + # test TableViz metrics in correct order + form_data = { + "url_params": {}, + "row_limit": 500, + "metric": "sum__SP_POP_TOTL", + "entity": "country_code", + "secondary_metric": "sum__SP_POP_TOTL", + "granularity_sqla": "year", + "page_length": 0, + "all_columns": [], + "viz_type": "table", + "since": "2014-01-01", + "until": "2014-01-02", + "metrics": ["sum__SP_POP_TOTL", "SUM(SE_PRM_NENR_MA)", "SUM(SP_URB_TOTL)"], + "country_fieldtype": "cca3", + "percent_metrics": ["count"], + "slice_id": 74, + "time_grain_sqla": None, + "order_by_cols": [], + "groupby": ["country_name"], + "compare_lag": "10", + "limit": "25", + "datasource": "2__table", + "table_timestamp_format": "%Y-%m-%d %H:%M:%S", + "markup_type": "markdown", + "where": "", + "compare_suffix": "o10Y", + } + datasource = Mock() + datasource.type = "table" + test_viz = viz.BaseViz(datasource, form_data) + expect_metric_labels = [ + u"sum__SP_POP_TOTL", + u"SUM(SE_PRM_NENR_MA)", + u"SUM(SP_URB_TOTL)", + u"count", + ] + self.assertEqual(test_viz.metric_labels, expect_metric_labels) + self.assertEqual(test_viz.all_metrics, expect_metric_labels) + + def test_get_df_returns_empty_df(self): + form_data = {"dummy": 123} + query_obj = {"granularity": "day"} + datasource = self.get_datasource_mock() + test_viz = viz.BaseViz(datasource, form_data) + result = test_viz.get_df(query_obj) + self.assertEqual(type(result), pd.DataFrame) + self.assertTrue(result.empty) + + def test_get_df_handles_dttm_col(self): + form_data = {"dummy": 123} + query_obj = {"granularity": "day"} + results = Mock() + results.query = Mock() + results.status = Mock() + results.error_message = Mock() + datasource = Mock() + datasource.type = "table" + datasource.query = Mock(return_value=results) + mock_dttm_col = Mock() + datasource.get_column = Mock(return_value=mock_dttm_col) + + test_viz = viz.BaseViz(datasource, form_data) + test_viz.df_metrics_to_num = Mock() + test_viz.get_fillna_for_columns = Mock(return_value=0) + + results.df = pd.DataFrame(data={DTTM_ALIAS: ["1960-01-01 05:00:00"]}) + datasource.offset = 0 + mock_dttm_col = Mock() + datasource.get_column = Mock(return_value=mock_dttm_col) + mock_dttm_col.python_date_format = "epoch_ms" + result = test_viz.get_df(query_obj) + import logging + + logger.info(result) + pd.testing.assert_series_equal( + result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS) + ) + + mock_dttm_col.python_date_format = None + result = test_viz.get_df(query_obj) + pd.testing.assert_series_equal( + result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS) + ) + + 
datasource.offset = 1 + result = test_viz.get_df(query_obj) + pd.testing.assert_series_equal( + result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 6, 0)], name=DTTM_ALIAS) + ) + + datasource.offset = 0 + results.df = pd.DataFrame(data={DTTM_ALIAS: ["1960-01-01"]}) + mock_dttm_col.python_date_format = "%Y-%m-%d" + result = test_viz.get_df(query_obj) + pd.testing.assert_series_equal( + result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 0, 0)], name=DTTM_ALIAS) + ) + + def test_cache_timeout(self): + datasource = self.get_datasource_mock() + datasource.cache_timeout = 0 + test_viz = viz.BaseViz(datasource, form_data={}) + self.assertEqual(0, test_viz.cache_timeout) + + datasource.cache_timeout = 156 + test_viz = viz.BaseViz(datasource, form_data={}) + self.assertEqual(156, test_viz.cache_timeout) + + datasource.cache_timeout = None + datasource.database.cache_timeout = 0 + self.assertEqual(0, test_viz.cache_timeout) + + datasource.database.cache_timeout = 1666 + self.assertEqual(1666, test_viz.cache_timeout) + + datasource.database.cache_timeout = None + test_viz = viz.BaseViz(datasource, form_data={}) + self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout) + + +class TableVizTestCase(Sip38TestCase): + def test_get_data_applies_percentage(self): + form_data = { + "groupby": ["groupA", "groupB"], + "metrics": [ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "SUM(value1)", + "column": {"column_name": "value1", "type": "DOUBLE"}, + }, + "count", + "avg__C", + ], + "percent_metrics": [ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "SUM(value1)", + "column": {"column_name": "value1", "type": "DOUBLE"}, + }, + "avg__B", + ], + } + datasource = self.get_datasource_mock() + + df = pd.DataFrame( + { + "SUM(value1)": [15, 20, 25, 40], + "avg__B": [10, 20, 5, 15], + "avg__C": [11, 22, 33, 44], + "count": [6, 7, 8, 9], + "groupA": ["A", "B", "C", "C"], + "groupB": ["x", "x", "y", "z"], + } + ) + + test_viz = viz.TableViz(datasource, form_data) + data = test_viz.get_data(df) + # Check method correctly transforms data and computes percents + self.assertEqual( + [ + "groupA", + "groupB", + "SUM(value1)", + "count", + "avg__C", + "%SUM(value1)", + "%avg__B", + ], + list(data["columns"]), + ) + expected = [ + { + "groupA": "A", + "groupB": "x", + "SUM(value1)": 15, + "count": 6, + "avg__C": 11, + "%SUM(value1)": 0.15, + "%avg__B": 0.2, + }, + { + "groupA": "B", + "groupB": "x", + "SUM(value1)": 20, + "count": 7, + "avg__C": 22, + "%SUM(value1)": 0.2, + "%avg__B": 0.4, + }, + { + "groupA": "C", + "groupB": "y", + "SUM(value1)": 25, + "count": 8, + "avg__C": 33, + "%SUM(value1)": 0.25, + "%avg__B": 0.1, + }, + { + "groupA": "C", + "groupB": "z", + "SUM(value1)": 40, + "count": 9, + "avg__C": 44, + "%SUM(value1)": 0.4, + "%avg__B": 0.3, + }, + ] + self.assertEqual(expected, data["records"]) + + def test_parse_adhoc_filters(self): + form_data = { + "metrics": [ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "SUM(value1)", + "column": {"column_name": "value1", "type": "DOUBLE"}, + } + ], + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "clause": "WHERE", + "subject": "value2", + "operator": ">", + "comparator": "100", + }, + { + "expressionType": "SIMPLE", + "clause": "HAVING", + "subject": "SUM(value1)", + "operator": "<", + "comparator": "10", + }, + { + "expressionType": "SQL", + "clause": "HAVING", + "sqlExpression": "SUM(value1) > 5", + }, + { + "expressionType": "SQL", + "clause": "WHERE", + "sqlExpression": 
"value3 in ('North America')", + }, + ], + } + datasource = self.get_datasource_mock() + test_viz = viz.TableViz(datasource, form_data) + query_obj = test_viz.query_obj() + self.assertEqual( + [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] + ) + self.assertEqual( + [{"op": "<", "val": "10", "col": "SUM(value1)"}], + query_obj["extras"]["having_druid"], + ) + self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) + self.assertEqual("(SUM(value1) > 5)", query_obj["extras"]["having"]) + + def test_adhoc_filters_overwrite_legacy_filters(self): + form_data = { + "metrics": [ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "SUM(value1)", + "column": {"column_name": "value1", "type": "DOUBLE"}, + } + ], + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "clause": "WHERE", + "subject": "value2", + "operator": ">", + "comparator": "100", + }, + { + "expressionType": "SQL", + "clause": "WHERE", + "sqlExpression": "value3 in ('North America')", + }, + ], + "having": "SUM(value1) > 5", + } + datasource = self.get_datasource_mock() + test_viz = viz.TableViz(datasource, form_data) + query_obj = test_viz.query_obj() + self.assertEqual( + [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] + ) + self.assertEqual([], query_obj["extras"]["having_druid"]) + self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) + self.assertEqual("", query_obj["extras"]["having"]) + + @patch("superset.viz_sip38.BaseViz.query_obj") + def test_query_obj_merges_percent_metrics(self, super_query_obj): + datasource = self.get_datasource_mock() + form_data = { + "percent_metrics": ["sum__A", "avg__B", "max__Y"], + "metrics": ["sum__A", "count", "avg__C"], + } + test_viz = viz.TableViz(datasource, form_data) + f_query_obj = {"metrics": form_data["metrics"]} + super_query_obj.return_value = f_query_obj + query_obj = test_viz.query_obj() + self.assertEqual( + ["sum__A", "count", "avg__C", "avg__B", "max__Y"], query_obj["metrics"] + ) + + @patch("superset.viz_sip38.BaseViz.query_obj") + def test_query_obj_throws_columns_and_metrics(self, super_query_obj): + datasource = self.get_datasource_mock() + form_data = {"all_columns": ["A", "B"], "metrics": ["x", "y"]} + super_query_obj.return_value = {} + test_viz = viz.TableViz(datasource, form_data) + with self.assertRaises(Exception): + test_viz.query_obj() + del form_data["metrics"] + form_data["groupby"] = ["B", "C"] + test_viz = viz.TableViz(datasource, form_data) + with self.assertRaises(Exception): + test_viz.query_obj() + + def test_query_obj_merges_all_columns(self): + datasource = self.get_datasource_mock() + form_data = { + "all_columns": ["colA", "colB", "colC"], + "order_by_cols": ['["colA", "colB"]', '["colC"]'], + } + test_viz = viz.TableViz(datasource, form_data) + query_obj = test_viz.query_obj() + self.assertEqual(form_data["all_columns"], query_obj["columns"]) + self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"]) + + @patch("superset.viz_sip38.BaseViz.query_obj") + def test_query_obj_uses_sortby(self, super_query_obj): + datasource = self.get_datasource_mock() + form_data = {"timeseries_limit_metric": "__time__", "order_desc": False} + super_query_obj.return_value = {"metrics": ["colA", "colB"]} + test_viz = viz.TableViz(datasource, form_data) + query_obj = test_viz.query_obj() + self.assertEqual(["colA", "colB", "__time__"], query_obj["metrics"]) + self.assertEqual([("__time__", True)], query_obj["orderby"]) + + def 
test_should_be_timeseries_raises_when_no_granularity(self): + datasource = self.get_datasource_mock() + form_data = {"include_time": True} + test_viz = viz.TableViz(datasource, form_data) + with self.assertRaises(Exception): + test_viz.should_be_timeseries() + + def test_adhoc_metric_with_sortby(self): + metrics = [ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "sum_value", + "column": {"column_name": "value1", "type": "DOUBLE"}, + } + ] + form_data = { + "metrics": metrics, + "timeseries_limit_metric": { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "label": "SUM(value1)", + "column": {"column_name": "value1", "type": "DOUBLE"}, + }, + "order_desc": False, + } + + df = pd.DataFrame({"SUM(value1)": [15], "sum_value": [15]}) + datasource = self.get_datasource_mock() + test_viz = viz.TableViz(datasource, form_data) + data = test_viz.get_data(df) + self.assertEqual(["sum_value", "SUM(value1)"], data["columns"]) + + +class DistBarVizTestCase(Sip38TestCase): + def test_groupby_nulls(self): + form_data = { + "metrics": ["votes"], + "adhoc_filters": [], + "groupby": ["toppings"], + "columns": [], + } + datasource = self.get_datasource_mock() + df = pd.DataFrame( + { + "toppings": ["cheese", "pepperoni", "anchovies", None], + "votes": [3, 5, 1, 2], + } + ) + test_viz = viz.DistributionBarViz(datasource, form_data) + data = test_viz.get_data(df)[0] + self.assertEqual("votes", data["key"]) + expected_values = [ + {"x": "pepperoni", "y": 5}, + {"x": "cheese", "y": 3}, + {"x": NULL_STRING, "y": 2}, + {"x": "anchovies", "y": 1}, + ] + self.assertEqual(expected_values, data["values"]) + + def test_groupby_nans(self): + form_data = { + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["beds"], + "columns": [], + } + datasource = self.get_datasource_mock() + df = pd.DataFrame({"beds": [0, 1, nan, 2], "count": [30, 42, 3, 29]}) + test_viz = viz.DistributionBarViz(datasource, form_data) + data = test_viz.get_data(df)[0] + self.assertEqual("count", data["key"]) + expected_values = [ + {"x": "1.0", "y": 42}, + {"x": "0.0", "y": 30}, + {"x": "2.0", "y": 29}, + {"x": NULL_STRING, "y": 3}, + ] + self.assertEqual(expected_values, data["values"]) + + def test_column_nulls(self): + form_data = { + "metrics": ["votes"], + "adhoc_filters": [], + "groupby": ["toppings"], + "columns": ["role"], + } + datasource = self.get_datasource_mock() + df = pd.DataFrame( + { + "toppings": ["cheese", "pepperoni", "cheese", "pepperoni"], + "role": ["engineer", "engineer", None, None], + "votes": [3, 5, 1, 2], + } + ) + test_viz = viz.DistributionBarViz(datasource, form_data) + data = test_viz.get_data(df) + expected = [ + { + "key": NULL_STRING, + "values": [{"x": "pepperoni", "y": 2}, {"x": "cheese", "y": 1}], + }, + { + "key": "engineer", + "values": [{"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}], + }, + ] + self.assertEqual(expected, data) + + +class PairedTTestTestCase(Sip38TestCase): + def test_get_data_transforms_dataframe(self): + form_data = { + "groupby": ["groupA", "groupB", "groupC"], + "metrics": ["metric1", "metric2", "metric3"], + } + datasource = self.get_datasource_mock() + # Test data + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 
80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + pairedTTestViz = viz.viz_types["paired_ttest"](datasource, form_data) + data = pairedTTestViz.get_data(df) + # Check method correctly transforms data + expected = { + "metric1": [ + { + "values": [ + {"x": 100, "y": 1}, + {"x": 200, "y": 2}, + {"x": 300, "y": 3}, + ], + "group": ("a1", "a2", "a3"), + }, + { + "values": [ + {"x": 100, "y": 4}, + {"x": 200, "y": 5}, + {"x": 300, "y": 6}, + ], + "group": ("b1", "b2", "b3"), + }, + { + "values": [ + {"x": 100, "y": 7}, + {"x": 200, "y": 8}, + {"x": 300, "y": 9}, + ], + "group": ("c1", "c2", "c3"), + }, + ], + "metric2": [ + { + "values": [ + {"x": 100, "y": 10}, + {"x": 200, "y": 20}, + {"x": 300, "y": 30}, + ], + "group": ("a1", "a2", "a3"), + }, + { + "values": [ + {"x": 100, "y": 40}, + {"x": 200, "y": 50}, + {"x": 300, "y": 60}, + ], + "group": ("b1", "b2", "b3"), + }, + { + "values": [ + {"x": 100, "y": 70}, + {"x": 200, "y": 80}, + {"x": 300, "y": 90}, + ], + "group": ("c1", "c2", "c3"), + }, + ], + "metric3": [ + { + "values": [ + {"x": 100, "y": 100}, + {"x": 200, "y": 200}, + {"x": 300, "y": 300}, + ], + "group": ("a1", "a2", "a3"), + }, + { + "values": [ + {"x": 100, "y": 400}, + {"x": 200, "y": 500}, + {"x": 300, "y": 600}, + ], + "group": ("b1", "b2", "b3"), + }, + { + "values": [ + {"x": 100, "y": 700}, + {"x": 200, "y": 800}, + {"x": 300, "y": 900}, + ], + "group": ("c1", "c2", "c3"), + }, + ], + } + self.assertEqual(data, expected) + + def test_get_data_empty_null_keys(self): + form_data = {"groupby": [], "metrics": ["", None]} + datasource = self.get_datasource_mock() + # Test data + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300] + raw[""] = [1, 2, 3] + raw[None] = [10, 20, 30] + + df = pd.DataFrame(raw) + pairedTTestViz = viz.viz_types["paired_ttest"](datasource, form_data) + data = pairedTTestViz.get_data(df) + # Check method correctly transforms data + expected = { + "N/A": [ + { + "values": [ + {"x": 100, "y": 1}, + {"x": 200, "y": 2}, + {"x": 300, "y": 3}, + ], + "group": "All", + } + ], + "NULL": [ + { + "values": [ + {"x": 100, "y": 10}, + {"x": 200, "y": 20}, + {"x": 300, "y": 30}, + ], + "group": "All", + } + ], + } + self.assertEqual(data, expected) + + +class PartitionVizTestCase(Sip38TestCase): + @patch("superset.viz_sip38.BaseViz.query_obj") + def test_query_obj_time_series_option(self, super_query_obj): + datasource = self.get_datasource_mock() + form_data = {} + test_viz = viz.PartitionViz(datasource, form_data) + super_query_obj.return_value = {} + query_obj = test_viz.query_obj() + self.assertFalse(query_obj["is_timeseries"]) + test_viz.form_data["time_series_option"] = "agg_sum" + query_obj = test_viz.query_obj() + self.assertTrue(query_obj["is_timeseries"]) + + def test_levels_for_computes_levels(self): + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + groups = ["groupA", "groupB", "groupC"] + time_op = "agg_sum" + test_viz = viz.PartitionViz(Mock(), {}) + levels = test_viz.levels_for(time_op, groups, df) + self.assertEqual(4, len(levels)) + expected = {DTTM_ALIAS: 1800, 
"metric1": 45, "metric2": 450, "metric3": 4500} + self.assertEqual(expected, levels[0].to_dict()) + expected = { + DTTM_ALIAS: {"a1": 600, "b1": 600, "c1": 600}, + "metric1": {"a1": 6, "b1": 15, "c1": 24}, + "metric2": {"a1": 60, "b1": 150, "c1": 240}, + "metric3": {"a1": 600, "b1": 1500, "c1": 2400}, + } + self.assertEqual(expected, levels[1].to_dict()) + self.assertEqual(["groupA", "groupB"], levels[2].index.names) + self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + time_op = "agg_mean" + levels = test_viz.levels_for(time_op, groups, df) + self.assertEqual(4, len(levels)) + expected = { + DTTM_ALIAS: 200.0, + "metric1": 5.0, + "metric2": 50.0, + "metric3": 500.0, + } + self.assertEqual(expected, levels[0].to_dict()) + expected = { + DTTM_ALIAS: {"a1": 200, "c1": 200, "b1": 200}, + "metric1": {"a1": 2, "b1": 5, "c1": 8}, + "metric2": {"a1": 20, "b1": 50, "c1": 80}, + "metric3": {"a1": 200, "b1": 500, "c1": 800}, + } + self.assertEqual(expected, levels[1].to_dict()) + self.assertEqual(["groupA", "groupB"], levels[2].index.names) + self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + + def test_levels_for_diff_computes_difference(self): + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + groups = ["groupA", "groupB", "groupC"] + test_viz = viz.PartitionViz(Mock(), {}) + time_op = "point_diff" + levels = test_viz.levels_for_diff(time_op, groups, df) + expected = {"metric1": 6, "metric2": 60, "metric3": 600} + self.assertEqual(expected, levels[0].to_dict()) + expected = { + "metric1": {"a1": 2, "b1": 2, "c1": 2}, + "metric2": {"a1": 20, "b1": 20, "c1": 20}, + "metric3": {"a1": 200, "b1": 200, "c1": 200}, + } + self.assertEqual(expected, levels[1].to_dict()) + self.assertEqual(4, len(levels)) + self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + + def test_levels_for_time_calls_process_data_and_drops_cols(self): + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + groups = ["groupA", "groupB", "groupC"] + test_viz = viz.PartitionViz(Mock(), {"groupby": groups}) + + def return_args(df_drop, aggregate): + return df_drop + + test_viz.process_data = Mock(side_effect=return_args) + levels = test_viz.levels_for_time(groups, df) + self.assertEqual(4, len(levels)) + cols = [DTTM_ALIAS, "metric1", "metric2", "metric3"] + self.assertEqual(sorted(cols), sorted(levels[0].columns.tolist())) + cols += ["groupA"] + self.assertEqual(sorted(cols), sorted(levels[1].columns.tolist())) + cols += ["groupB"] + self.assertEqual(sorted(cols), sorted(levels[2].columns.tolist())) + cols += ["groupC"] + self.assertEqual(sorted(cols), sorted(levels[3].columns.tolist())) + self.assertEqual(4, 
len(test_viz.process_data.mock_calls)) + + def test_nest_values_returns_hierarchy(self): + raw = {} + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + test_viz = viz.PartitionViz(Mock(), {}) + groups = ["groupA", "groupB", "groupC"] + levels = test_viz.levels_for("agg_sum", groups, df) + nest = test_viz.nest_values(levels) + self.assertEqual(3, len(nest)) + for i in range(0, 3): + self.assertEqual("metric" + str(i + 1), nest[i]["name"]) + self.assertEqual(3, len(nest[0]["children"])) + self.assertEqual(1, len(nest[0]["children"][0]["children"])) + self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) + + def test_nest_procs_returns_hierarchy(self): + raw = {} + raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] + df = pd.DataFrame(raw) + test_viz = viz.PartitionViz(Mock(), {}) + groups = ["groupA", "groupB", "groupC"] + metrics = ["metric1", "metric2", "metric3"] + procs = {} + for i in range(0, 4): + df_drop = df.drop(groups[i:], 1) + pivot = df_drop.pivot_table( + index=DTTM_ALIAS, columns=groups[:i], values=metrics + ) + procs[i] = pivot + nest = test_viz.nest_procs(procs) + self.assertEqual(3, len(nest)) + for i in range(0, 3): + self.assertEqual("metric" + str(i + 1), nest[i]["name"]) + self.assertEqual(None, nest[i].get("val")) + self.assertEqual(3, len(nest[0]["children"])) + self.assertEqual(3, len(nest[0]["children"][0]["children"])) + self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) + self.assertEqual( + 1, len(nest[0]["children"][0]["children"][0]["children"][0]["children"]) + ) + + def test_get_data_calls_correct_method(self): + test_viz = viz.PartitionViz(Mock(), {}) + df = Mock() + with self.assertRaises(ValueError): + test_viz.get_data(df) + test_viz.levels_for = Mock(return_value=1) + test_viz.nest_values = Mock(return_value=1) + test_viz.form_data["groupby"] = ["groups"] + test_viz.form_data["time_series_option"] = "not_time" + test_viz.get_data(df) + self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[0][1][0]) + test_viz.form_data["time_series_option"] = "agg_sum" + test_viz.get_data(df) + self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[1][1][0]) + test_viz.form_data["time_series_option"] = "agg_mean" + test_viz.get_data(df) + self.assertEqual("agg_mean", test_viz.levels_for.mock_calls[2][1][0]) + test_viz.form_data["time_series_option"] = "point_diff" + test_viz.levels_for_diff = Mock(return_value=1) + test_viz.get_data(df) + self.assertEqual("point_diff", test_viz.levels_for_diff.mock_calls[0][1][0]) + test_viz.form_data["time_series_option"] = "point_percent" + test_viz.get_data(df) + self.assertEqual("point_percent", test_viz.levels_for_diff.mock_calls[1][1][0]) + test_viz.form_data["time_series_option"] = "point_factor" + test_viz.get_data(df) + 
self.assertEqual("point_factor", test_viz.levels_for_diff.mock_calls[2][1][0]) + test_viz.levels_for_time = Mock(return_value=1) + test_viz.nest_procs = Mock(return_value=1) + test_viz.form_data["time_series_option"] = "adv_anal" + test_viz.get_data(df) + self.assertEqual(1, len(test_viz.levels_for_time.mock_calls)) + self.assertEqual(1, len(test_viz.nest_procs.mock_calls)) + test_viz.form_data["time_series_option"] = "time_series" + test_viz.get_data(df) + self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[3][1][0]) + self.assertEqual(7, len(test_viz.nest_values.mock_calls)) + + +class RoseVisTestCase(Sip38TestCase): + def test_rose_vis_get_data(self): + raw = {} + t1 = pd.Timestamp("2000") + t2 = pd.Timestamp("2002") + t3 = pd.Timestamp("2004") + raw[DTTM_ALIAS] = [t1, t2, t3, t1, t2, t3, t1, t2, t3] + raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] + raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] + raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] + raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + df = pd.DataFrame(raw) + fd = {"metrics": ["metric1"], "groupby": ["groupA"]} + test_viz = viz.RoseViz(Mock(), fd) + test_viz.metrics = fd["metrics"] + res = test_viz.get_data(df) + expected = { + 946684800000000000: [ + {"time": t1, "value": 1, "key": ("a1",), "name": ("a1",)}, + {"time": t1, "value": 4, "key": ("b1",), "name": ("b1",)}, + {"time": t1, "value": 7, "key": ("c1",), "name": ("c1",)}, + ], + 1009843200000000000: [ + {"time": t2, "value": 2, "key": ("a1",), "name": ("a1",)}, + {"time": t2, "value": 5, "key": ("b1",), "name": ("b1",)}, + {"time": t2, "value": 8, "key": ("c1",), "name": ("c1",)}, + ], + 1072915200000000000: [ + {"time": t3, "value": 3, "key": ("a1",), "name": ("a1",)}, + {"time": t3, "value": 6, "key": ("b1",), "name": ("b1",)}, + {"time": t3, "value": 9, "key": ("c1",), "name": ("c1",)}, + ], + } + self.assertEqual(expected, res) + + +class TimeSeriesTableVizTestCase(Sip38TestCase): + def test_get_data_metrics(self): + form_data = {"metrics": ["sum__A", "count"], "groupby": []} + datasource = self.get_datasource_mock() + raw = {} + t1 = pd.Timestamp("2000") + t2 = pd.Timestamp("2002") + raw[DTTM_ALIAS] = [t1, t2] + raw["sum__A"] = [15, 20] + raw["count"] = [6, 7] + df = pd.DataFrame(raw) + test_viz = viz.TimeTableViz(datasource, form_data) + data = test_viz.get_data(df) + # Check method correctly transforms data + self.assertEqual(set(["count", "sum__A"]), set(data["columns"])) + time_format = "%Y-%m-%d %H:%M:%S" + expected = { + t1.strftime(time_format): {"sum__A": 15, "count": 6}, + t2.strftime(time_format): {"sum__A": 20, "count": 7}, + } + self.assertEqual(expected, data["records"]) + + def test_get_data_group_by(self): + form_data = {"metrics": ["sum__A"], "groupby": ["groupby1"]} + datasource = self.get_datasource_mock() + raw = {} + t1 = pd.Timestamp("2000") + t2 = pd.Timestamp("2002") + raw[DTTM_ALIAS] = [t1, t1, t1, t2, t2, t2] + raw["sum__A"] = [15, 20, 25, 30, 35, 40] + raw["groupby1"] = ["a1", "a2", "a3", "a1", "a2", "a3"] + df = pd.DataFrame(raw) + test_viz = viz.TimeTableViz(datasource, form_data) + data = test_viz.get_data(df) + # Check method correctly transforms data + self.assertEqual(set(["a1", "a2", "a3"]), set(data["columns"])) + time_format = "%Y-%m-%d %H:%M:%S" + expected = { + t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25}, + t2.strftime(time_format): {"a1": 30, "a2": 35, "a3": 40}, + } + self.assertEqual(expected, data["records"]) + + 
@patch("superset.viz_sip38.BaseViz.query_obj") + def test_query_obj_throws_metrics_and_groupby(self, super_query_obj): + datasource = self.get_datasource_mock() + form_data = {"groupby": ["a"]} + super_query_obj.return_value = {} + test_viz = viz.TimeTableViz(datasource, form_data) + with self.assertRaises(Exception): + test_viz.query_obj() + form_data["metrics"] = ["x", "y"] + test_viz = viz.TimeTableViz(datasource, form_data) + with self.assertRaises(Exception): + test_viz.query_obj() + + +class BaseDeckGLVizTestCase(Sip38TestCase): + def test_scatterviz_get_metrics(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + + form_data = {} + test_viz_deckgl = viz.DeckScatterViz(datasource, form_data) + test_viz_deckgl.point_radius_fixed = {"type": "metric", "value": "int"} + result = test_viz_deckgl.get_metrics() + assert result == ["int"] + + form_data = {} + test_viz_deckgl = viz.DeckScatterViz(datasource, form_data) + test_viz_deckgl.point_radius_fixed = {} + result = test_viz_deckgl.get_metrics() + assert result is None + + def test_get_js_columns(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + mock_d = {"a": "dummy1", "b": "dummy2", "c": "dummy3"} + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + result = test_viz_deckgl.get_js_columns(mock_d) + + assert result == {"color": None} + + def test_get_properties(self): + mock_d = {} + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + + with self.assertRaises(NotImplementedError) as context: + test_viz_deckgl.get_properties(mock_d) + + self.assertTrue("" in str(context.exception)) + + def test_geojson_query_obj(self): + form_data = load_fixture("deck_geojson_form_data.json") + datasource = self.get_datasource_mock() + test_viz_deckgl = viz.DeckGeoJson(datasource, form_data) + results = test_viz_deckgl.query_obj() + + assert results["metrics"] == [] + assert results["columns"] == ["color", "test_col"] + + def test_parse_coordinates(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + viz_instance = viz.BaseDeckGLViz(datasource, form_data) + + coord = viz_instance.parse_coordinates("1.23, 3.21") + self.assertEqual(coord, (1.23, 3.21)) + + coord = viz_instance.parse_coordinates("1.23 3.21") + self.assertEqual(coord, (1.23, 3.21)) + + self.assertEqual(viz_instance.parse_coordinates(None), None) + + self.assertEqual(viz_instance.parse_coordinates(""), None) + + def test_parse_coordinates_raises(self): + form_data = load_fixture("deck_path_form_data.json") + datasource = self.get_datasource_mock() + test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) + + with self.assertRaises(SpatialException): + test_viz_deckgl.parse_coordinates("NULL") + + with self.assertRaises(SpatialException): + test_viz_deckgl.parse_coordinates("fldkjsalkj,fdlaskjfjadlksj") + + +class TimeSeriesVizTestCase(Sip38TestCase): + def test_timeseries_unicode_data(self): + datasource = self.get_datasource_mock() + form_data = {"groupby": ["name"], "metrics": ["sum__payout"]} + raw = {} + raw["name"] = [ + "Real Madrid C.F.πŸ‡ΊπŸ‡ΈπŸ‡¬πŸ‡§", + "Real Madrid C.F.πŸ‡ΊπŸ‡ΈπŸ‡¬πŸ‡§", + "Real Madrid Basket", + "Real Madrid Basket", + ] + raw["__timestamp"] = [ + "2018-02-20T00:00:00", + "2018-03-09T00:00:00", + "2018-02-20T00:00:00", + "2018-03-09T00:00:00", + ] + raw["sum__payout"] = [2, 2, 4, 4] + df 
= pd.DataFrame(raw) + + test_viz = viz.NVD3TimeSeriesViz(datasource, form_data) + viz_data = {} + viz_data = test_viz.get_data(df) + expected = [ + { + u"values": [ + {u"y": 4, u"x": u"2018-02-20T00:00:00"}, + {u"y": 4, u"x": u"2018-03-09T00:00:00"}, + ], + u"key": (u"Real Madrid Basket",), + }, + { + u"values": [ + {u"y": 2, u"x": u"2018-02-20T00:00:00"}, + {u"y": 2, u"x": u"2018-03-09T00:00:00"}, + ], + u"key": (u"Real Madrid C.F.\U0001f1fa\U0001f1f8\U0001f1ec\U0001f1e7",), + }, + ] + self.assertEqual(expected, viz_data) + + def test_process_data_resample(self): + datasource = self.get_datasource_mock() + + df = pd.DataFrame( + { + "__timestamp": pd.to_datetime( + ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] + ), + "y": [1.0, 2.0, 5.0, 7.0], + } + ) + + self.assertEqual( + viz.NVD3TimeSeriesViz( + datasource, + {"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"}, + ) + .process_data(df)["y"] + .tolist(), + [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0], + ) + + np.testing.assert_equal( + viz.NVD3TimeSeriesViz( + datasource, + {"metrics": ["y"], "resample_method": "asfreq", "resample_rule": "1D"}, + ) + .process_data(df)["y"] + .tolist(), + [1.0, 2.0, np.nan, np.nan, 5.0, np.nan, 7.0], + ) + + def test_apply_rolling(self): + datasource = self.get_datasource_mock() + df = pd.DataFrame( + index=pd.to_datetime( + ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] + ), + data={"y": [1.0, 2.0, 3.0, 4.0]}, + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "cumsum", + "rolling_periods": 0, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 3.0, 6.0, 10.0], + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "sum", + "rolling_periods": 2, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 3.0, 5.0, 7.0], + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "mean", + "rolling_periods": 10, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 1.5, 2.0, 2.5], + ) + + +class BigNumberVizTestCase(Sip38TestCase): + def test_get_data(self): + datasource = self.get_datasource_mock() + df = pd.DataFrame( + data={ + DTTM_ALIAS: pd.to_datetime( + ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] + ), + "y": [1.0, 2.0, 3.0, 4.0], + } + ) + data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) + self.assertEqual(data[2], {DTTM_ALIAS: pd.Timestamp("2019-01-05"), "y": 3}) + + def test_get_data_with_none(self): + datasource = self.get_datasource_mock() + df = pd.DataFrame( + data={ + DTTM_ALIAS: pd.to_datetime( + ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] + ), + "y": [1.0, 2.0, None, 4.0], + } + ) + data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) + assert np.isnan(data[2]["y"])
From 202103de73189cb3b95abff5f3e53b502e4a5a56 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 7 Apr 2020 14:21:44 +0300 Subject: [PATCH 09/14] Resolve test errors --- tests/core_tests.py | 2 +- tests/viz_tests_sip38.py | 1175 -------------------------------------- 2 files changed, 1 insertion(+), 1176 deletions(-) delete mode 100644 tests/viz_tests_sip38.py
diff --git a/tests/core_tests.py b/tests/core_tests.py index ce1668e5f525..9ffeb1d6236f 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -119,7 +119,7 @@ def _get_query_context(self) -> Dict[str, Any]: "queries": [ { "granularity": "ds", - "columns": ["name"], + "groupby": ["name"], "metrics": [{"label": "sum__num"}], "filters": [], "row_limit": 100,
diff --git a/tests/viz_tests_sip38.py b/tests/viz_tests_sip38.py deleted file mode 100644 index 4f1e5bbf719b..000000000000 --- a/tests/viz_tests_sip38.py +++ /dev/null @@ -1,1175 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License.
-# isort:skip_file -from datetime import datetime -import logging -from math import nan -from unittest.mock import Mock, patch - -import numpy as np -import pandas as pd - -import tests.test_app -import superset.viz_sip38 as viz -from superset import app -from superset.constants import NULL_STRING -from superset.exceptions import SpatialException -from superset.utils.core import DTTM_ALIAS - -from .base_tests import SupersetTestCase -from .utils import load_fixture - -logger = logging.getLogger(__name__) - - -@patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", - {"SIP_38_VIZ_REARCHITECTURE": True}, - clear=True, -) -class Sip38TestCase(SupersetTestCase): - pass - - -class BaseVizTestCase(Sip38TestCase): - def test_constructor_exception_no_datasource(self): - form_data = {} - datasource = None - with self.assertRaises(Exception): - viz.BaseViz(datasource, form_data) - - def test_process_metrics(self): - # test TableViz metrics in correct order - form_data = { - "url_params": {}, - "row_limit": 500, - "metric": "sum__SP_POP_TOTL", - "entity": "country_code", - "secondary_metric": "sum__SP_POP_TOTL", - "granularity_sqla": "year", - "page_length": 0, - "all_columns": [], - "viz_type": "table", - "since": "2014-01-01", - "until": "2014-01-02", - "metrics": ["sum__SP_POP_TOTL", "SUM(SE_PRM_NENR_MA)", "SUM(SP_URB_TOTL)"], - "country_fieldtype": "cca3", - "percent_metrics": ["count"], - "slice_id": 74, - "time_grain_sqla": None, - "order_by_cols": [], - "groupby": ["country_name"], - "compare_lag": "10", - "limit": "25", - "datasource": "2__table", - "table_timestamp_format": "%Y-%m-%d %H:%M:%S", - "markup_type": "markdown", - "where": "", - "compare_suffix": "o10Y", - } - datasource = Mock() - datasource.type = "table" - test_viz = viz.BaseViz(datasource, form_data) - expect_metric_labels = [ - u"sum__SP_POP_TOTL", - u"SUM(SE_PRM_NENR_MA)", - u"SUM(SP_URB_TOTL)", - u"count", - ] - self.assertEqual(test_viz.metric_labels, expect_metric_labels) - self.assertEqual(test_viz.all_metrics, expect_metric_labels) - - def test_get_df_returns_empty_df(self): - form_data = {"dummy": 123} - query_obj = {"granularity": "day"} - datasource = self.get_datasource_mock() - test_viz = viz.BaseViz(datasource, form_data) - result = test_viz.get_df(query_obj) - self.assertEqual(type(result), pd.DataFrame) - self.assertTrue(result.empty) - - def test_get_df_handles_dttm_col(self): - form_data = {"dummy": 123} - query_obj = {"granularity": "day"} - results = Mock() - results.query = Mock() - results.status = Mock() - results.error_message = Mock() - datasource = Mock() - datasource.type = "table" - datasource.query = Mock(return_value=results) - mock_dttm_col = Mock() - datasource.get_column = Mock(return_value=mock_dttm_col) - - test_viz = viz.BaseViz(datasource, form_data) - test_viz.df_metrics_to_num = Mock() - test_viz.get_fillna_for_columns = Mock(return_value=0) - - results.df = pd.DataFrame(data={DTTM_ALIAS: ["1960-01-01 05:00:00"]}) - datasource.offset = 0 - mock_dttm_col = Mock() - datasource.get_column = Mock(return_value=mock_dttm_col) - mock_dttm_col.python_date_format = "epoch_ms" - result = test_viz.get_df(query_obj) - import logging - - logger.info(result) - pd.testing.assert_series_equal( - result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS) - ) - - mock_dttm_col.python_date_format = None - result = test_viz.get_df(query_obj) - pd.testing.assert_series_equal( - result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS) - ) - - 
datasource.offset = 1 - result = test_viz.get_df(query_obj) - pd.testing.assert_series_equal( - result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 6, 0)], name=DTTM_ALIAS) - ) - - datasource.offset = 0 - results.df = pd.DataFrame(data={DTTM_ALIAS: ["1960-01-01"]}) - mock_dttm_col.python_date_format = "%Y-%m-%d" - result = test_viz.get_df(query_obj) - pd.testing.assert_series_equal( - result[DTTM_ALIAS], pd.Series([datetime(1960, 1, 1, 0, 0)], name=DTTM_ALIAS) - ) - - def test_cache_timeout(self): - datasource = self.get_datasource_mock() - datasource.cache_timeout = 0 - test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual(0, test_viz.cache_timeout) - - datasource.cache_timeout = 156 - test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual(156, test_viz.cache_timeout) - - datasource.cache_timeout = None - datasource.database.cache_timeout = 0 - self.assertEqual(0, test_viz.cache_timeout) - - datasource.database.cache_timeout = 1666 - self.assertEqual(1666, test_viz.cache_timeout) - - datasource.database.cache_timeout = None - test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout) - - -class TableVizTestCase(Sip38TestCase): - def test_get_data_applies_percentage(self): - form_data = { - "groupby": ["groupA", "groupB"], - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "count", - "avg__C", - ], - "percent_metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "avg__B", - ], - } - datasource = self.get_datasource_mock() - - df = pd.DataFrame( - { - "SUM(value1)": [15, 20, 25, 40], - "avg__B": [10, 20, 5, 15], - "avg__C": [11, 22, 33, 44], - "count": [6, 7, 8, 9], - "groupA": ["A", "B", "C", "C"], - "groupB": ["x", "x", "y", "z"], - } - ) - - test_viz = viz.TableViz(datasource, form_data) - data = test_viz.get_data(df) - # Check method correctly transforms data and computes percents - self.assertEqual( - [ - "groupA", - "groupB", - "SUM(value1)", - "count", - "avg__C", - "%SUM(value1)", - "%avg__B", - ], - list(data["columns"]), - ) - expected = [ - { - "groupA": "A", - "groupB": "x", - "SUM(value1)": 15, - "count": 6, - "avg__C": 11, - "%SUM(value1)": 0.15, - "%avg__B": 0.2, - }, - { - "groupA": "B", - "groupB": "x", - "SUM(value1)": 20, - "count": 7, - "avg__C": 22, - "%SUM(value1)": 0.2, - "%avg__B": 0.4, - }, - { - "groupA": "C", - "groupB": "y", - "SUM(value1)": 25, - "count": 8, - "avg__C": 33, - "%SUM(value1)": 0.25, - "%avg__B": 0.1, - }, - { - "groupA": "C", - "groupB": "z", - "SUM(value1)": 40, - "count": 9, - "avg__C": 44, - "%SUM(value1)": 0.4, - "%avg__B": 0.3, - }, - ] - self.assertEqual(expected, data["records"]) - - def test_parse_adhoc_filters(self): - form_data = { - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ], - "adhoc_filters": [ - { - "expressionType": "SIMPLE", - "clause": "WHERE", - "subject": "value2", - "operator": ">", - "comparator": "100", - }, - { - "expressionType": "SIMPLE", - "clause": "HAVING", - "subject": "SUM(value1)", - "operator": "<", - "comparator": "10", - }, - { - "expressionType": "SQL", - "clause": "HAVING", - "sqlExpression": "SUM(value1) > 5", - }, - { - "expressionType": "SQL", - "clause": "WHERE", - "sqlExpression": 
"value3 in ('North America')", - }, - ], - } - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual( - [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] - ) - self.assertEqual( - [{"op": "<", "val": "10", "col": "SUM(value1)"}], - query_obj["extras"]["having_druid"], - ) - self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) - self.assertEqual("(SUM(value1) > 5)", query_obj["extras"]["having"]) - - def test_adhoc_filters_overwrite_legacy_filters(self): - form_data = { - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ], - "adhoc_filters": [ - { - "expressionType": "SIMPLE", - "clause": "WHERE", - "subject": "value2", - "operator": ">", - "comparator": "100", - }, - { - "expressionType": "SQL", - "clause": "WHERE", - "sqlExpression": "value3 in ('North America')", - }, - ], - "having": "SUM(value1) > 5", - } - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual( - [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] - ) - self.assertEqual([], query_obj["extras"]["having_druid"]) - self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) - self.assertEqual("", query_obj["extras"]["having"]) - - @patch("superset.viz_sip38.BaseViz.query_obj") - def test_query_obj_merges_percent_metrics(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = { - "percent_metrics": ["sum__A", "avg__B", "max__Y"], - "metrics": ["sum__A", "count", "avg__C"], - } - test_viz = viz.TableViz(datasource, form_data) - f_query_obj = {"metrics": form_data["metrics"]} - super_query_obj.return_value = f_query_obj - query_obj = test_viz.query_obj() - self.assertEqual( - ["sum__A", "count", "avg__C", "avg__B", "max__Y"], query_obj["metrics"] - ) - - @patch("superset.viz_sip38.BaseViz.query_obj") - def test_query_obj_throws_columns_and_metrics(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = {"all_columns": ["A", "B"], "metrics": ["x", "y"]} - super_query_obj.return_value = {} - test_viz = viz.TableViz(datasource, form_data) - with self.assertRaises(Exception): - test_viz.query_obj() - del form_data["metrics"] - form_data["groupby"] = ["B", "C"] - test_viz = viz.TableViz(datasource, form_data) - with self.assertRaises(Exception): - test_viz.query_obj() - - def test_query_obj_merges_all_columns(self): - datasource = self.get_datasource_mock() - form_data = { - "all_columns": ["colA", "colB", "colC"], - "order_by_cols": ['["colA", "colB"]', '["colC"]'], - } - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual(form_data["all_columns"], query_obj["columns"]) - self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"]) - - @patch("superset.viz_sip38.BaseViz.query_obj") - def test_query_obj_uses_sortby(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = {"timeseries_limit_metric": "__time__", "order_desc": False} - super_query_obj.return_value = {"metrics": ["colA", "colB"]} - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual(["colA", "colB", "__time__"], query_obj["metrics"]) - self.assertEqual([("__time__", True)], query_obj["orderby"]) - - def 
test_should_be_timeseries_raises_when_no_granularity(self): - datasource = self.get_datasource_mock() - form_data = {"include_time": True} - test_viz = viz.TableViz(datasource, form_data) - with self.assertRaises(Exception): - test_viz.should_be_timeseries() - - def test_adhoc_metric_with_sortby(self): - metrics = [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "sum_value", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ] - form_data = { - "metrics": metrics, - "timeseries_limit_metric": { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "order_desc": False, - } - - df = pd.DataFrame({"SUM(value1)": [15], "sum_value": [15]}) - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - data = test_viz.get_data(df) - self.assertEqual(["sum_value", "SUM(value1)"], data["columns"]) - - -class DistBarVizTestCase(Sip38TestCase): - def test_groupby_nulls(self): - form_data = { - "metrics": ["votes"], - "adhoc_filters": [], - "groupby": ["toppings"], - "columns": [], - } - datasource = self.get_datasource_mock() - df = pd.DataFrame( - { - "toppings": ["cheese", "pepperoni", "anchovies", None], - "votes": [3, 5, 1, 2], - } - ) - test_viz = viz.DistributionBarViz(datasource, form_data) - data = test_viz.get_data(df)[0] - self.assertEqual("votes", data["key"]) - expected_values = [ - {"x": "pepperoni", "y": 5}, - {"x": "cheese", "y": 3}, - {"x": NULL_STRING, "y": 2}, - {"x": "anchovies", "y": 1}, - ] - self.assertEqual(expected_values, data["values"]) - - def test_groupby_nans(self): - form_data = { - "metrics": ["count"], - "adhoc_filters": [], - "groupby": ["beds"], - "columns": [], - } - datasource = self.get_datasource_mock() - df = pd.DataFrame({"beds": [0, 1, nan, 2], "count": [30, 42, 3, 29]}) - test_viz = viz.DistributionBarViz(datasource, form_data) - data = test_viz.get_data(df)[0] - self.assertEqual("count", data["key"]) - expected_values = [ - {"x": "1.0", "y": 42}, - {"x": "0.0", "y": 30}, - {"x": "2.0", "y": 29}, - {"x": NULL_STRING, "y": 3}, - ] - self.assertEqual(expected_values, data["values"]) - - def test_column_nulls(self): - form_data = { - "metrics": ["votes"], - "adhoc_filters": [], - "groupby": ["toppings"], - "columns": ["role"], - } - datasource = self.get_datasource_mock() - df = pd.DataFrame( - { - "toppings": ["cheese", "pepperoni", "cheese", "pepperoni"], - "role": ["engineer", "engineer", None, None], - "votes": [3, 5, 1, 2], - } - ) - test_viz = viz.DistributionBarViz(datasource, form_data) - data = test_viz.get_data(df) - expected = [ - { - "key": NULL_STRING, - "values": [{"x": "pepperoni", "y": 2}, {"x": "cheese", "y": 1}], - }, - { - "key": "engineer", - "values": [{"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}], - }, - ] - self.assertEqual(expected, data) - - -class PairedTTestTestCase(Sip38TestCase): - def test_get_data_transforms_dataframe(self): - form_data = { - "groupby": ["groupA", "groupB", "groupC"], - "metrics": ["metric1", "metric2", "metric3"], - } - datasource = self.get_datasource_mock() - # Test data - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 
80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - pairedTTestViz = viz.viz_types["paired_ttest"](datasource, form_data) - data = pairedTTestViz.get_data(df) - # Check method correctly transforms data - expected = { - "metric1": [ - { - "values": [ - {"x": 100, "y": 1}, - {"x": 200, "y": 2}, - {"x": 300, "y": 3}, - ], - "group": ("a1", "a2", "a3"), - }, - { - "values": [ - {"x": 100, "y": 4}, - {"x": 200, "y": 5}, - {"x": 300, "y": 6}, - ], - "group": ("b1", "b2", "b3"), - }, - { - "values": [ - {"x": 100, "y": 7}, - {"x": 200, "y": 8}, - {"x": 300, "y": 9}, - ], - "group": ("c1", "c2", "c3"), - }, - ], - "metric2": [ - { - "values": [ - {"x": 100, "y": 10}, - {"x": 200, "y": 20}, - {"x": 300, "y": 30}, - ], - "group": ("a1", "a2", "a3"), - }, - { - "values": [ - {"x": 100, "y": 40}, - {"x": 200, "y": 50}, - {"x": 300, "y": 60}, - ], - "group": ("b1", "b2", "b3"), - }, - { - "values": [ - {"x": 100, "y": 70}, - {"x": 200, "y": 80}, - {"x": 300, "y": 90}, - ], - "group": ("c1", "c2", "c3"), - }, - ], - "metric3": [ - { - "values": [ - {"x": 100, "y": 100}, - {"x": 200, "y": 200}, - {"x": 300, "y": 300}, - ], - "group": ("a1", "a2", "a3"), - }, - { - "values": [ - {"x": 100, "y": 400}, - {"x": 200, "y": 500}, - {"x": 300, "y": 600}, - ], - "group": ("b1", "b2", "b3"), - }, - { - "values": [ - {"x": 100, "y": 700}, - {"x": 200, "y": 800}, - {"x": 300, "y": 900}, - ], - "group": ("c1", "c2", "c3"), - }, - ], - } - self.assertEqual(data, expected) - - def test_get_data_empty_null_keys(self): - form_data = {"groupby": [], "metrics": ["", None]} - datasource = self.get_datasource_mock() - # Test data - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300] - raw[""] = [1, 2, 3] - raw[None] = [10, 20, 30] - - df = pd.DataFrame(raw) - pairedTTestViz = viz.viz_types["paired_ttest"](datasource, form_data) - data = pairedTTestViz.get_data(df) - # Check method correctly transforms data - expected = { - "N/A": [ - { - "values": [ - {"x": 100, "y": 1}, - {"x": 200, "y": 2}, - {"x": 300, "y": 3}, - ], - "group": "All", - } - ], - "NULL": [ - { - "values": [ - {"x": 100, "y": 10}, - {"x": 200, "y": 20}, - {"x": 300, "y": 30}, - ], - "group": "All", - } - ], - } - self.assertEqual(data, expected) - - -class PartitionVizTestCase(Sip38TestCase): - @patch("superset.viz_sip38.BaseViz.query_obj") - def test_query_obj_time_series_option(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = {} - test_viz = viz.PartitionViz(datasource, form_data) - super_query_obj.return_value = {} - query_obj = test_viz.query_obj() - self.assertFalse(query_obj["is_timeseries"]) - test_viz.form_data["time_series_option"] = "agg_sum" - query_obj = test_viz.query_obj() - self.assertTrue(query_obj["is_timeseries"]) - - def test_levels_for_computes_levels(self): - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - groups = ["groupA", "groupB", "groupC"] - time_op = "agg_sum" - test_viz = viz.PartitionViz(Mock(), {}) - levels = test_viz.levels_for(time_op, groups, df) - self.assertEqual(4, len(levels)) - expected = {DTTM_ALIAS: 1800, 
"metric1": 45, "metric2": 450, "metric3": 4500} - self.assertEqual(expected, levels[0].to_dict()) - expected = { - DTTM_ALIAS: {"a1": 600, "b1": 600, "c1": 600}, - "metric1": {"a1": 6, "b1": 15, "c1": 24}, - "metric2": {"a1": 60, "b1": 150, "c1": 240}, - "metric3": {"a1": 600, "b1": 1500, "c1": 2400}, - } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(["groupA", "groupB"], levels[2].index.names) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) - time_op = "agg_mean" - levels = test_viz.levels_for(time_op, groups, df) - self.assertEqual(4, len(levels)) - expected = { - DTTM_ALIAS: 200.0, - "metric1": 5.0, - "metric2": 50.0, - "metric3": 500.0, - } - self.assertEqual(expected, levels[0].to_dict()) - expected = { - DTTM_ALIAS: {"a1": 200, "c1": 200, "b1": 200}, - "metric1": {"a1": 2, "b1": 5, "c1": 8}, - "metric2": {"a1": 20, "b1": 50, "c1": 80}, - "metric3": {"a1": 200, "b1": 500, "c1": 800}, - } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(["groupA", "groupB"], levels[2].index.names) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) - - def test_levels_for_diff_computes_difference(self): - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - groups = ["groupA", "groupB", "groupC"] - test_viz = viz.PartitionViz(Mock(), {}) - time_op = "point_diff" - levels = test_viz.levels_for_diff(time_op, groups, df) - expected = {"metric1": 6, "metric2": 60, "metric3": 600} - self.assertEqual(expected, levels[0].to_dict()) - expected = { - "metric1": {"a1": 2, "b1": 2, "c1": 2}, - "metric2": {"a1": 20, "b1": 20, "c1": 20}, - "metric3": {"a1": 200, "b1": 200, "c1": 200}, - } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(4, len(levels)) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) - - def test_levels_for_time_calls_process_data_and_drops_cols(self): - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - groups = ["groupA", "groupB", "groupC"] - test_viz = viz.PartitionViz(Mock(), {"groupby": groups}) - - def return_args(df_drop, aggregate): - return df_drop - - test_viz.process_data = Mock(side_effect=return_args) - levels = test_viz.levels_for_time(groups, df) - self.assertEqual(4, len(levels)) - cols = [DTTM_ALIAS, "metric1", "metric2", "metric3"] - self.assertEqual(sorted(cols), sorted(levels[0].columns.tolist())) - cols += ["groupA"] - self.assertEqual(sorted(cols), sorted(levels[1].columns.tolist())) - cols += ["groupB"] - self.assertEqual(sorted(cols), sorted(levels[2].columns.tolist())) - cols += ["groupC"] - self.assertEqual(sorted(cols), sorted(levels[3].columns.tolist())) - self.assertEqual(4, 
len(test_viz.process_data.mock_calls)) - - def test_nest_values_returns_hierarchy(self): - raw = {} - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - test_viz = viz.PartitionViz(Mock(), {}) - groups = ["groupA", "groupB", "groupC"] - levels = test_viz.levels_for("agg_sum", groups, df) - nest = test_viz.nest_values(levels) - self.assertEqual(3, len(nest)) - for i in range(0, 3): - self.assertEqual("metric" + str(i + 1), nest[i]["name"]) - self.assertEqual(3, len(nest[0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) - - def test_nest_procs_returns_hierarchy(self): - raw = {} - raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - raw["metric2"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] - raw["metric3"] = [100, 200, 300, 400, 500, 600, 700, 800, 900] - df = pd.DataFrame(raw) - test_viz = viz.PartitionViz(Mock(), {}) - groups = ["groupA", "groupB", "groupC"] - metrics = ["metric1", "metric2", "metric3"] - procs = {} - for i in range(0, 4): - df_drop = df.drop(groups[i:], 1) - pivot = df_drop.pivot_table( - index=DTTM_ALIAS, columns=groups[:i], values=metrics - ) - procs[i] = pivot - nest = test_viz.nest_procs(procs) - self.assertEqual(3, len(nest)) - for i in range(0, 3): - self.assertEqual("metric" + str(i + 1), nest[i]["name"]) - self.assertEqual(None, nest[i].get("val")) - self.assertEqual(3, len(nest[0]["children"])) - self.assertEqual(3, len(nest[0]["children"][0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) - self.assertEqual( - 1, len(nest[0]["children"][0]["children"][0]["children"][0]["children"]) - ) - - def test_get_data_calls_correct_method(self): - test_viz = viz.PartitionViz(Mock(), {}) - df = Mock() - with self.assertRaises(ValueError): - test_viz.get_data(df) - test_viz.levels_for = Mock(return_value=1) - test_viz.nest_values = Mock(return_value=1) - test_viz.form_data["groupby"] = ["groups"] - test_viz.form_data["time_series_option"] = "not_time" - test_viz.get_data(df) - self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[0][1][0]) - test_viz.form_data["time_series_option"] = "agg_sum" - test_viz.get_data(df) - self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[1][1][0]) - test_viz.form_data["time_series_option"] = "agg_mean" - test_viz.get_data(df) - self.assertEqual("agg_mean", test_viz.levels_for.mock_calls[2][1][0]) - test_viz.form_data["time_series_option"] = "point_diff" - test_viz.levels_for_diff = Mock(return_value=1) - test_viz.get_data(df) - self.assertEqual("point_diff", test_viz.levels_for_diff.mock_calls[0][1][0]) - test_viz.form_data["time_series_option"] = "point_percent" - test_viz.get_data(df) - self.assertEqual("point_percent", test_viz.levels_for_diff.mock_calls[1][1][0]) - test_viz.form_data["time_series_option"] = "point_factor" - test_viz.get_data(df) - 
self.assertEqual("point_factor", test_viz.levels_for_diff.mock_calls[2][1][0]) - test_viz.levels_for_time = Mock(return_value=1) - test_viz.nest_procs = Mock(return_value=1) - test_viz.form_data["time_series_option"] = "adv_anal" - test_viz.get_data(df) - self.assertEqual(1, len(test_viz.levels_for_time.mock_calls)) - self.assertEqual(1, len(test_viz.nest_procs.mock_calls)) - test_viz.form_data["time_series_option"] = "time_series" - test_viz.get_data(df) - self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[3][1][0]) - self.assertEqual(7, len(test_viz.nest_values.mock_calls)) - - -class RoseVisTestCase(Sip38TestCase): - def test_rose_vis_get_data(self): - raw = {} - t1 = pd.Timestamp("2000") - t2 = pd.Timestamp("2002") - t3 = pd.Timestamp("2004") - raw[DTTM_ALIAS] = [t1, t2, t3, t1, t2, t3, t1, t2, t3] - raw["groupA"] = ["a1", "a1", "a1", "b1", "b1", "b1", "c1", "c1", "c1"] - raw["groupB"] = ["a2", "a2", "a2", "b2", "b2", "b2", "c2", "c2", "c2"] - raw["groupC"] = ["a3", "a3", "a3", "b3", "b3", "b3", "c3", "c3", "c3"] - raw["metric1"] = [1, 2, 3, 4, 5, 6, 7, 8, 9] - df = pd.DataFrame(raw) - fd = {"metrics": ["metric1"], "groupby": ["groupA"]} - test_viz = viz.RoseViz(Mock(), fd) - test_viz.metrics = fd["metrics"] - res = test_viz.get_data(df) - expected = { - 946684800000000000: [ - {"time": t1, "value": 1, "key": ("a1",), "name": ("a1",)}, - {"time": t1, "value": 4, "key": ("b1",), "name": ("b1",)}, - {"time": t1, "value": 7, "key": ("c1",), "name": ("c1",)}, - ], - 1009843200000000000: [ - {"time": t2, "value": 2, "key": ("a1",), "name": ("a1",)}, - {"time": t2, "value": 5, "key": ("b1",), "name": ("b1",)}, - {"time": t2, "value": 8, "key": ("c1",), "name": ("c1",)}, - ], - 1072915200000000000: [ - {"time": t3, "value": 3, "key": ("a1",), "name": ("a1",)}, - {"time": t3, "value": 6, "key": ("b1",), "name": ("b1",)}, - {"time": t3, "value": 9, "key": ("c1",), "name": ("c1",)}, - ], - } - self.assertEqual(expected, res) - - -class TimeSeriesTableVizTestCase(Sip38TestCase): - def test_get_data_metrics(self): - form_data = {"metrics": ["sum__A", "count"], "groupby": []} - datasource = self.get_datasource_mock() - raw = {} - t1 = pd.Timestamp("2000") - t2 = pd.Timestamp("2002") - raw[DTTM_ALIAS] = [t1, t2] - raw["sum__A"] = [15, 20] - raw["count"] = [6, 7] - df = pd.DataFrame(raw) - test_viz = viz.TimeTableViz(datasource, form_data) - data = test_viz.get_data(df) - # Check method correctly transforms data - self.assertEqual(set(["count", "sum__A"]), set(data["columns"])) - time_format = "%Y-%m-%d %H:%M:%S" - expected = { - t1.strftime(time_format): {"sum__A": 15, "count": 6}, - t2.strftime(time_format): {"sum__A": 20, "count": 7}, - } - self.assertEqual(expected, data["records"]) - - def test_get_data_group_by(self): - form_data = {"metrics": ["sum__A"], "groupby": ["groupby1"]} - datasource = self.get_datasource_mock() - raw = {} - t1 = pd.Timestamp("2000") - t2 = pd.Timestamp("2002") - raw[DTTM_ALIAS] = [t1, t1, t1, t2, t2, t2] - raw["sum__A"] = [15, 20, 25, 30, 35, 40] - raw["groupby1"] = ["a1", "a2", "a3", "a1", "a2", "a3"] - df = pd.DataFrame(raw) - test_viz = viz.TimeTableViz(datasource, form_data) - data = test_viz.get_data(df) - # Check method correctly transforms data - self.assertEqual(set(["a1", "a2", "a3"]), set(data["columns"])) - time_format = "%Y-%m-%d %H:%M:%S" - expected = { - t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25}, - t2.strftime(time_format): {"a1": 30, "a2": 35, "a3": 40}, - } - self.assertEqual(expected, data["records"]) - - 
@patch("superset.viz_sip38.BaseViz.query_obj") - def test_query_obj_throws_metrics_and_groupby(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = {"groupby": ["a"]} - super_query_obj.return_value = {} - test_viz = viz.TimeTableViz(datasource, form_data) - with self.assertRaises(Exception): - test_viz.query_obj() - form_data["metrics"] = ["x", "y"] - test_viz = viz.TimeTableViz(datasource, form_data) - with self.assertRaises(Exception): - test_viz.query_obj() - - -class BaseDeckGLVizTestCase(Sip38TestCase): - def test_scatterviz_get_metrics(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - - form_data = {} - test_viz_deckgl = viz.DeckScatterViz(datasource, form_data) - test_viz_deckgl.point_radius_fixed = {"type": "metric", "value": "int"} - result = test_viz_deckgl.get_metrics() - assert result == ["int"] - - form_data = {} - test_viz_deckgl = viz.DeckScatterViz(datasource, form_data) - test_viz_deckgl.point_radius_fixed = {} - result = test_viz_deckgl.get_metrics() - assert result is None - - def test_get_js_columns(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - mock_d = {"a": "dummy1", "b": "dummy2", "c": "dummy3"} - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - result = test_viz_deckgl.get_js_columns(mock_d) - - assert result == {"color": None} - - def test_get_properties(self): - mock_d = {} - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - - with self.assertRaises(NotImplementedError) as context: - test_viz_deckgl.get_properties(mock_d) - - self.assertTrue("" in str(context.exception)) - - def test_geojson_query_obj(self): - form_data = load_fixture("deck_geojson_form_data.json") - datasource = self.get_datasource_mock() - test_viz_deckgl = viz.DeckGeoJson(datasource, form_data) - results = test_viz_deckgl.query_obj() - - assert results["metrics"] == [] - assert results["columns"] == ["color", "test_col"] - - def test_parse_coordinates(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - viz_instance = viz.BaseDeckGLViz(datasource, form_data) - - coord = viz_instance.parse_coordinates("1.23, 3.21") - self.assertEqual(coord, (1.23, 3.21)) - - coord = viz_instance.parse_coordinates("1.23 3.21") - self.assertEqual(coord, (1.23, 3.21)) - - self.assertEqual(viz_instance.parse_coordinates(None), None) - - self.assertEqual(viz_instance.parse_coordinates(""), None) - - def test_parse_coordinates_raises(self): - form_data = load_fixture("deck_path_form_data.json") - datasource = self.get_datasource_mock() - test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data) - - with self.assertRaises(SpatialException): - test_viz_deckgl.parse_coordinates("NULL") - - with self.assertRaises(SpatialException): - test_viz_deckgl.parse_coordinates("fldkjsalkj,fdlaskjfjadlksj") - - -class TimeSeriesVizTestCase(Sip38TestCase): - def test_timeseries_unicode_data(self): - datasource = self.get_datasource_mock() - form_data = {"groupby": ["name"], "metrics": ["sum__payout"]} - raw = {} - raw["name"] = [ - "Real Madrid C.F.πŸ‡ΊπŸ‡ΈπŸ‡¬πŸ‡§", - "Real Madrid C.F.πŸ‡ΊπŸ‡ΈπŸ‡¬πŸ‡§", - "Real Madrid Basket", - "Real Madrid Basket", - ] - raw["__timestamp"] = [ - "2018-02-20T00:00:00", - "2018-03-09T00:00:00", - "2018-02-20T00:00:00", - "2018-03-09T00:00:00", - ] - raw["sum__payout"] = [2, 2, 4, 4] - df 
= pd.DataFrame(raw) - - test_viz = viz.NVD3TimeSeriesViz(datasource, form_data) - viz_data = {} - viz_data = test_viz.get_data(df) - expected = [ - { - u"values": [ - {u"y": 4, u"x": u"2018-02-20T00:00:00"}, - {u"y": 4, u"x": u"2018-03-09T00:00:00"}, - ], - u"key": (u"Real Madrid Basket",), - }, - { - u"values": [ - {u"y": 2, u"x": u"2018-02-20T00:00:00"}, - {u"y": 2, u"x": u"2018-03-09T00:00:00"}, - ], - u"key": (u"Real Madrid C.F.\U0001f1fa\U0001f1f8\U0001f1ec\U0001f1e7",), - }, - ] - self.assertEqual(expected, viz_data) - - def test_process_data_resample(self): - datasource = self.get_datasource_mock() - - df = pd.DataFrame( - { - "__timestamp": pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - "y": [1.0, 2.0, 5.0, 7.0], - } - ) - - self.assertEqual( - viz.NVD3TimeSeriesViz( - datasource, - {"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"}, - ) - .process_data(df)["y"] - .tolist(), - [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0], - ) - - np.testing.assert_equal( - viz.NVD3TimeSeriesViz( - datasource, - {"metrics": ["y"], "resample_method": "asfreq", "resample_rule": "1D"}, - ) - .process_data(df)["y"] - .tolist(), - [1.0, 2.0, np.nan, np.nan, 5.0, np.nan, 7.0], - ) - - def test_apply_rolling(self): - datasource = self.get_datasource_mock() - df = pd.DataFrame( - index=pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - data={"y": [1.0, 2.0, 3.0, 4.0]}, - ) - self.assertEqual( - viz.BigNumberViz( - datasource, - { - "metrics": ["y"], - "rolling_type": "cumsum", - "rolling_periods": 0, - "min_periods": 0, - }, - ) - .apply_rolling(df)["y"] - .tolist(), - [1.0, 3.0, 6.0, 10.0], - ) - self.assertEqual( - viz.BigNumberViz( - datasource, - { - "metrics": ["y"], - "rolling_type": "sum", - "rolling_periods": 2, - "min_periods": 0, - }, - ) - .apply_rolling(df)["y"] - .tolist(), - [1.0, 3.0, 5.0, 7.0], - ) - self.assertEqual( - viz.BigNumberViz( - datasource, - { - "metrics": ["y"], - "rolling_type": "mean", - "rolling_periods": 10, - "min_periods": 0, - }, - ) - .apply_rolling(df)["y"] - .tolist(), - [1.0, 1.5, 2.0, 2.5], - ) - - -class BigNumberVizTestCase(Sip38TestCase): - def test_get_data(self): - datasource = self.get_datasource_mock() - df = pd.DataFrame( - data={ - DTTM_ALIAS: pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - "y": [1.0, 2.0, 3.0, 4.0], - } - ) - data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) - self.assertEqual(data[2], {DTTM_ALIAS: pd.Timestamp("2019-01-05"), "y": 3}) - - def test_get_data_with_none(self): - datasource = self.get_datasource_mock() - df = pd.DataFrame( - data={ - DTTM_ALIAS: pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - "y": [1.0, 2.0, None, 4.0], - } - ) - data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) - assert np.isnan(data[2]["y"]) From 9e14961e544fd951f91b73e863e2e00bf46a26db Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 7 Apr 2020 14:46:55 +0300 Subject: [PATCH 10/14] Add feature flag to QueryContext --- superset/common/query_object.py | 8 +++++++- superset/connectors/sqla/models.py | 24 +++++++++++------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index aac187b1bbb4..aa9197e21de3 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -78,6 +78,7 @@ def __init__( relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"], 
relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"], ): + is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") self.granularity = granularity self.from_dttm, self.to_dttm = utils.get_since_until( relative_start=relative_start, @@ -89,6 +90,8 @@ def __init__( self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) self.post_processing = post_processing or [] + if not is_sip_38: + self.groupby = groupby or [] # Temporary solution for backward compatibility issue due the new format of # non-ad-hoc metric which needs to adhere to superset-ui per @@ -109,7 +112,7 @@ def __init__( self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) self.columns = columns or [] - if groupby: + if is_sip_38 and groupby: self.columns += groupby logger.warning( f"The field groupby is deprecated. Viz plugins should " @@ -134,6 +137,9 @@ def to_dict(self) -> Dict[str, Any]: "columns": self.columns, "orderby": self.orderby, } + if not is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + query_object_dict["groupby"] = self.groupby + return query_object_dict def cache_key(self, **extra: Any) -> str: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8303d2f61c02..bd7308874b9c 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -65,9 +65,6 @@ logger = logging.getLogger(__name__) -IS_SIP_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") - - class SqlaQuery(NamedTuple): extra_cache_keys: List[Any] labels_expected: List[str] @@ -726,6 +723,7 @@ def get_sqla_query( # sqla "filter": filter, "columns": {col.column_name: col for col in self.columns}, } + is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys @@ -753,10 +751,10 @@ def get_sqla_query( # sqla ) ) if ( - IS_SIP_38 + is_sip_38 and not metrics and not columns - or not IS_SIP_38 + or not is_sip_38 and not groupby and not metrics and not columns @@ -779,9 +777,9 @@ def get_sqla_query( # sqla select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() - if IS_SIP_38 and metrics and columns or not IS_SIP_38 and groupby: + if is_sip_38 and metrics and columns or not is_sip_38 and groupby: # dedup columns while preserving order - if IS_SIP_38: + if is_sip_38: groupby = list(dict.fromkeys(columns)) else: groupby = list(dict.fromkeys(groupby)) @@ -843,7 +841,7 @@ def get_sqla_query( # sqla tbl = self.get_from_clause(template_processor) - if IS_SIP_38 and metrics or not IS_SIP_38 and not columns: + if is_sip_38 and metrics or not is_sip_38 and not columns: qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] @@ -907,10 +905,10 @@ def get_sqla_query( # sqla qry = qry.having(and_(*having_clause_and)) if ( - IS_SIP_38 + is_sip_38 and not orderby and metrics - or not IS_SIP_38 + or not is_sip_38 and not orderby and not columns ): @@ -936,12 +934,12 @@ def get_sqla_query( # sqla qry = qry.limit(row_limit) if ( - IS_SIP_38 + is_sip_38 and is_timeseries and timeseries_limit and columns and not time_groupby_inline - or not IS_SIP_38 + or not is_sip_38 and is_timeseries and timeseries_limit and groupby @@ -1014,7 +1012,7 @@ def get_sqla_query( # sqla "columns": columns, "order_desc": True, } - if not IS_SIP_38: + if not is_sip_38: prequery_obj["groupby"] = groupby result = self.query(prequery_obj) From 
a99c7ba9cf6fba72661f57163e974162d5dfb6b9 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 8 Apr 2020 16:37:50 +0300 Subject: [PATCH 11/14] Resolve tests and bad rebase --- superset/models/slice.py | 4 ++-- superset/views/utils.py | 4 ++-- superset/viz_sip38.py | 4 ++-- tests/core_tests.py | 7 ------- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/superset/models/slice.py b/superset/models/slice.py index ad0ae5dfbfa4..59544407c8bf 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -33,9 +33,9 @@ from superset.utils import core as utils if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): - from superset.viz_sip38 import BaseViz, viz_types + from superset.viz_sip38 import BaseViz, viz_types # type: ignore else: - from superset.viz import BaseViz, viz_types + from superset.viz import BaseViz, viz_types # type: ignore if TYPE_CHECKING: # pylint: disable=unused-import diff --git a/superset/views/utils.py b/superset/views/utils.py index 780dd99af005..edb2987ab429 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -32,9 +32,9 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): - from superset import viz_sip38 as viz + from superset import viz_sip38 as viz # type: ignore else: - from superset import viz + from superset import viz # type: ignore FORM_DATA_KEY_BLACKLIST: List[str] = [] diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py index 87950d66e367..1639dcd04e9b 100644 --- a/superset/viz_sip38.py +++ b/superset/viz_sip38.py @@ -1879,8 +1879,8 @@ def get_data(self, df: pd.DataFrame) -> VizData: for row in d: country = None if isinstance(row["country"], str): - country = countries.get(fd.get("country_fieldtype"), row["country"]) - + if "country_fieldtype" in fd: + country = countries.get(fd["country_fieldtype"], row["country"]) if country: row["country"] = country["cca3"] row["latitude"] = country["lat"] diff --git a/tests/core_tests.py b/tests/core_tests.py index 9ffeb1d6236f..73711d9301d3 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -119,11 +119,7 @@ def _get_query_context(self) -> Dict[str, Any]: "queries": [ { "granularity": "ds", -<<<<<<< HEAD - "columns": ["name"], -======= "groupby": ["name"], ->>>>>>> Reslove test errors "metrics": [{"label": "sum__num"}], "filters": [], "row_limit": 100, @@ -718,7 +714,6 @@ def test_templated_sql_json(self): data = self.run_sql(sql, "fdaklj3ws") self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00") -<<<<<<< HEAD @mock.patch("tests.superset_test_custom_template_processors.datetime") def test_custom_process_template(self, mock_dt) -> None: """Test macro defined in custom template processor works.""" @@ -802,8 +797,6 @@ def test_custom_templated_sql_json(self, sql_lab_mock, mock_dt) -> None: self.delete_fake_presto_db() -======= ->>>>>>> Reslove test errors def test_fetch_datasource_metadata(self): self.login(username="admin") url = "/superset/fetch_datasource_metadata?" 
"datasourceKey=1__table" From 49a7f69adadf0be325e3dcb2b0a5d945446e3bbe Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 10 Apr 2020 13:25:19 +0300 Subject: [PATCH 12/14] Backport recent changes from viz.py and fix broken DeckGL charts --- superset/viz_sip38.py | 58 ++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py index 1639dcd04e9b..cc47cbc677cc 100644 --- a/superset/viz_sip38.py +++ b/superset/viz_sip38.py @@ -371,7 +371,7 @@ def query_obj(self) -> Dict[str, Any]: "druid_time_origin": form_data.get("druid_time_origin", ""), "having": form_data.get("having", ""), "having_druid": form_data.get("having_filters", []), - "time_grain_sqla": form_data.get("time_grain_sqla", ""), + "time_grain_sqla": form_data.get("time_grain_sqla"), "time_range_endpoints": form_data.get("time_range_endpoints"), "where": form_data.get("where", ""), } @@ -472,10 +472,10 @@ def get_df_payload(self, query_obj=None, **kwargs): self.status = utils.QueryStatus.SUCCESS is_loaded = True stats_logger.incr("loaded_from_cache") - except Exception as e: - logger.exception(e) + except Exception as ex: + logger.exception(ex) logger.error( - "Error reading cache: " + utils.error_msg_from_exception(e) + "Error reading cache: " + utils.error_msg_from_exception(ex) ) logger.info("Serving from cache") @@ -487,10 +487,10 @@ def get_df_payload(self, query_obj=None, **kwargs): if not self.force: stats_logger.incr("loaded_from_source_without_force") is_loaded = True - except Exception as e: - logger.exception(e) + except Exception as ex: + logger.exception(ex) if not self.error_message: - self.error_message = "{}".format(e) + self.error_message = "{}".format(ex) self.status = utils.QueryStatus.FAILED stacktrace = utils.get_stacktrace() @@ -510,12 +510,13 @@ def get_df_payload(self, query_obj=None, **kwargs): stats_logger.incr("set_cache_key") cache.set(cache_key, cache_value, timeout=self.cache_timeout) - except Exception as e: + except Exception as ex: # cache.set call can fail if the backend is down or if # the key is too large or whatever other reasons logger.warning("Could not cache key {}".format(cache_key)) - logger.exception(e) + logger.exception(ex) cache.delete(cache_key) + return { "cache_key": self._any_cache_key, "cached_dttm": self._any_cached_dttm, @@ -525,6 +526,8 @@ def get_df_payload(self, query_obj=None, **kwargs): "form_data": self.form_data, "is_cached": self._any_cache_key is not None, "query": self.query, + "from_dttm": self.from_dttm if hasattr(self, "from_dttm") else None, + "to_dttm": self.to_dttm if hasattr(self, "to_dttm") else None, "status": self.status, "stacktrace": stacktrace, "rowcount": len(df.index) if df is not None else 0, @@ -662,17 +665,18 @@ def get_data(self, df: pd.DataFrame) -> VizData: self.form_data.get("percent_metrics") or [] ) - df = pd.concat( - [ - df[non_percent_metric_columns], - ( - df[percent_metric_columns] - .div(df[percent_metric_columns].sum()) - .add_prefix("%") - ), - ], - axis=1, - ) + if not df.empty: + df = pd.concat( + [ + df[non_percent_metric_columns], + ( + df[percent_metric_columns] + .div(df[percent_metric_columns].sum()) + .add_prefix("%") + ), + ], + axis=1, + ) data = self.handle_js_int_overflow( dict(records=df.to_dict(orient="records"), columns=list(df.columns)) @@ -2281,6 +2285,20 @@ def add_null_filters(self): filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""}) fd["adhoc_filters"].append(filter_) + def query_obj(self): + fd = 
self.form_data + + # add NULL filters + if fd.get("filter_nulls", True): + self.add_null_filters() + + d = super().query_obj() + + metrics = self.get_metrics() + if metrics: + d["metrics"] = metrics + return d + def get_js_columns(self, d): cols = self.form_data.get("js_columns") or [] return {col: d.get(col) for col in cols} From e6f28f9ebd3441d0a6197e3b00513ab6830effba Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 13 Apr 2020 09:27:37 +0300 Subject: [PATCH 13/14] Fix bad rebase --- superset/common/query_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index aa9197e21de3..63158c6751bb 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -24,7 +24,7 @@ from flask_babel import gettext as _ from pandas import DataFrame -from superset import app +from superset import app, is_feature_enabled from superset.exceptions import QueryObjectValidationError from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints From 0c6845e7d1cc5e7943f5f481cabc0da6bfe83f60 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Tue, 14 Apr 2020 09:48:00 +0300 Subject: [PATCH 14/14] backport #9522 and address comments --- superset/connectors/druid/models.py | 30 ++++++++++-------------- superset/connectors/sqla/models.py | 36 +++++++---------------------- superset/viz_sip38.py | 6 +++-- 3 files changed, 24 insertions(+), 48 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index db1d674c8794..a4cc52748a68 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1190,10 +1190,9 @@ def run_query( # druid # the dimensions list with dimensionSpecs expanded - if IS_SIP_38: - dimensions = self.get_dimensions(columns, columns_dict) - else: - dimensions = self.get_dimensions(groupby, columns_dict) + dimensions = self.get_dimensions( + columns if IS_SIP_38 else groupby, columns_dict + ) extras = extras or {} qry = dict( @@ -1220,12 +1219,8 @@ def run_query( # druid order_direction = "descending" if order_desc else "ascending" - if ( - IS_SIP_38 - and not metrics - and "__time" not in columns - or not IS_SIP_38 - and columns + if (IS_SIP_38 and not metrics and "__time" not in columns) or ( + not IS_SIP_38 and columns ): columns.append("__time") del qry["post_aggregations"] @@ -1236,12 +1231,8 @@ def run_query( # druid qry["granularity"] = "all" qry["limit"] = row_limit client.scan(**qry) - elif ( - IS_SIP_38 - and columns - or not IS_SIP_38 - and len(groupby) == 0 - and not having_filters + elif (IS_SIP_38 and columns) or ( + not IS_SIP_38 and len(groupby) == 0 and not having_filters ): logger.info("Running timeseries query for no groupby values") del qry["dimensions"] @@ -1249,7 +1240,10 @@ def run_query( # druid elif ( not having_filters and order_desc - and (IS_SIP_38 and len(columns) == 1 or not IS_SIP_38 and len(groupby) == 1) + and ( + (IS_SIP_38 and len(columns) == 1) + or (not IS_SIP_38 and len(groupby) == 1) + ) ): dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) @@ -1303,7 +1297,7 @@ def run_query( # druid logger.info("Phase 2 Complete") elif ( having_filters - or (IS_SIP_38 and columns or not IS_SIP_38 and len(groupby)) > 0 + or ((IS_SIP_38 and columns) or (not IS_SIP_38 and len(groupby))) > 0 ): # If grouping on multiple fields or using a having filter # we have to force a 
groupby query
diff --git a/superset/connectors/sqla/models.py
index bd7308874b9c..c290363dd8c5 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -751,13 +751,9 @@ def get_sqla_query(  # sqla
                 )
             )
         if (
-            is_sip_38
-            and not metrics
-            and not columns
-            or not is_sip_38
-            and not groupby
-            and not metrics
+            not metrics and not columns
+            and (is_sip_38 or (not is_sip_38 and not groupby))
         ):
             raise Exception(_("Empty query?"))
         metrics_exprs: List[ColumnElement] = []
@@ -777,12 +773,9 @@ def get_sqla_query(  # sqla
         select_exprs: List[Column] = []
         groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()

-        if is_sip_38 and metrics and columns or not is_sip_38 and groupby:
+        if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
             # dedup columns while preserving order
-            if is_sip_38:
-                groupby = list(dict.fromkeys(columns))
-            else:
-                groupby = list(dict.fromkeys(groupby))
+            groupby = list(dict.fromkeys(columns if is_sip_38 else groupby))

             select_exprs = []
             for s in groupby:
@@ -841,7 +834,7 @@ def get_sqla_query(  # sqla

         tbl = self.get_from_clause(template_processor)

-        if is_sip_38 and metrics or not is_sip_38 and not columns:
+        if (is_sip_38 and metrics) or (not is_sip_38 and not columns):
             qry = qry.group_by(*groupby_exprs_with_timestamp.values())

         where_clause_and = []
@@ -904,14 +897,7 @@ def get_sqla_query(  # sqla
             qry = qry.where(and_(*where_clause_and))
         qry = qry.having(and_(*having_clause_and))

-        if (
-            is_sip_38
-            and not orderby
-            and metrics
-            or not is_sip_38
-            and not orderby
-            and not columns
-        ):
+        if not orderby and ((is_sip_38 and metrics) or (not is_sip_38 and not columns)):
             orderby = [(main_metric_expr, not order_desc)]

         # To ensure correct handling of the ORDER BY labeling we need to reference the
@@ -934,16 +920,10 @@ def get_sqla_query(  # sqla
             qry = qry.limit(row_limit)

         if (
-            is_sip_38
-            and is_timeseries
-            and timeseries_limit
-            and columns
-            and not time_groupby_inline
-            or not is_sip_38
-            and is_timeseries
+            is_timeseries
             and timeseries_limit
-            and groupby
             and not time_groupby_inline
+            and ((is_sip_38 and columns) or (not is_sip_38 and groupby))
         ):
             if self.database.db_engine_spec.allows_joins:
                 # some sql dialects require for order by expressions
diff --git a/superset/viz_sip38.py
index cc47cbc677cc..1992bd1c14bb 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -159,6 +159,8 @@ def __init__(
         self.results: Optional[QueryResult] = None
         self.error_message: Optional[str] = None
         self.force = force
+        self.from_dttm: Optional[datetime] = None
+        self.to_dttm: Optional[datetime] = None

         # Keeping track of whether some data came from cache
         # this is useful to trigger the <CachedLabel /> when
@@ -526,8 +528,8 @@ def get_df_payload(self, query_obj=None, **kwargs):
             "form_data": self.form_data,
             "is_cached": self._any_cache_key is not None,
             "query": self.query,
-            "from_dttm": self.from_dttm if hasattr(self, "from_dttm") else None,
-            "to_dttm": self.to_dttm if hasattr(self, "to_dttm") else None,
+            "from_dttm": self.from_dttm,
+            "to_dttm": self.to_dttm,
             "status": self.status,
             "stacktrace": stacktrace,
             "rowcount": len(df.index) if df is not None else 0,
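
The change running through patches 10-14 is a feature-flagged deprecation: when SIP_38_VIZ_REARCHITECTURE is enabled, the deprecated groupby field is folded into columns and a warning is logged; when it is disabled, the legacy groupby field is kept so existing viz plugins continue to work. A minimal, runnable sketch of that contract follows; is_feature_enabled and the query object below are simplified stand-ins for illustration, not Superset's actual classes.

import logging
from typing import List, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Stand-in for superset.is_feature_enabled (illustrative only).
FEATURE_FLAGS = {"SIP_38_VIZ_REARCHITECTURE": True}


def is_feature_enabled(flag: str) -> bool:
    return FEATURE_FLAGS.get(flag, False)


class QueryObjectSketch:
    """Reduced model of the groupby -> columns migration in PATCH 10/14."""

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        groupby: Optional[List[str]] = None,
    ) -> None:
        is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
        self.columns = columns or []
        if is_sip_38 and groupby:
            # SIP-38 path: fold the deprecated field into columns,
            # preserving order, and warn the caller.
            self.columns += groupby
            logger.warning(
                "The field groupby is deprecated. Viz plugins should "
                "pass all selectables via the columns field"
            )
        if not is_sip_38:
            # Legacy path: keep groupby as its own field.
            self.groupby = groupby or []


if __name__ == "__main__":
    qo = QueryObjectSketch(columns=["region"], groupby=["country_name"])
    assert qo.columns == ["region", "country_name"]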