Skip to content
82 changes: 82 additions & 0 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,13 @@ def query_obj(self) -> QueryObjectDict:
raise QueryObjectValidationError(_("Please choose at least one metric"))
if set(groupby) & set(columns):
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
sort_by = self.form_data.get("timeseries_limit_metric")
Copy link
Member

@zhaoyongjie zhaoyongjie Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious, why use timeseries_limit_metric control in a not time-series chart.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is just from legacy form_data field names. We'd need a db migration if we want to rename it.

Superset has a timeseries-first UI. Each pivot group can be considered a series since the base query has the ability to aggregate by time + pivot column.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait a minute, since this is a new field for pivot table and apache-superset/superset-ui#952 hasn't been merged, maybe we CAN rename it to something else?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we can keep it as timeseries_limit_metric to be able to reuse the existing code and do a wholesale migration later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @villebro what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was super confused by the name yesterday, and agree we need to get rid of it. I'm ok both ways, but I think it might be simpler to do a bulk rename later to avoid partially migrated names.

if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d

@staticmethod
Expand Down Expand Up @@ -963,6 +970,18 @@ class TreemapViz(BaseViz):
credits = '<a href="https://d3js.org">d3.js</a>'
is_timeseries = False

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metrics = self.form_data.get("metrics")
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d

def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
nlevels = df.index.nlevels
if nlevels == 1:
Expand Down Expand Up @@ -1621,6 +1640,18 @@ class NVD3TimeSeriesStackedViz(NVD3TimeSeriesViz):
sort_series = True
pivot_fill_value = 0

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metrics = self.form_data.get("metrics")
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d


class HistogramViz(BaseViz):

Expand Down Expand Up @@ -2080,6 +2111,13 @@ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
d["groupby"] = [fd.get("series")]
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d

def get_data(self, df: pd.DataFrame) -> VizData:
Expand Down Expand Up @@ -2153,6 +2191,18 @@ class HorizonViz(NVD3TimeSeriesViz):
"d3-horizon-chart</a>"
)

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metrics = self.form_data.get("metrics")
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d


class MapboxViz(BaseViz):

Expand Down Expand Up @@ -2769,6 +2819,18 @@ class PairedTTestViz(BaseViz):
sort_series = False
is_timeseries = True

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metrics = self.form_data.get("metrics")
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d

def get_data(self, df: pd.DataFrame) -> VizData:
"""
Transform received data frame into an object of the form:
Expand Down Expand Up @@ -2824,6 +2886,18 @@ class RoseViz(NVD3TimeSeriesViz):
sort_series = False
is_timeseries = True

def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metrics = self.form_data.get("metrics")
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in d["metrics"]:
d["metrics"].append(sort_by)
if self.form_data.get("order_desc"):
d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))]
return d

def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
Expand Down Expand Up @@ -2862,6 +2936,14 @@ def query_obj(self) -> QueryObjectDict:
time_op = self.form_data.get("time_series_option", "not_time")
# Return time series data if the user specifies so
query_obj["is_timeseries"] = time_op != "not_time"
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in query_obj["metrics"]:
query_obj["metrics"].append(sort_by)
query_obj["orderby"] = [
(sort_by, not self.form_data.get("order_desc", True))
]
return query_obj

def levels_for(
Expand Down