Skip to content

Commit

Permalink
feat: add optional prophet forecasting functionality to chart data api (
Browse files Browse the repository at this point in the history
apache#10324)

* feat: add prophet post processing operation

* add tests

* lint

* whitespace

* remove whitespace

* address comments

* add note to UPDATING.md
  • Loading branch information
villebro authored and auxten committed Nov 20, 2020
1 parent 7d0c8cb commit bff4346
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 23 deletions.
2 changes: 2 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ assists people when migrating to a new version.

## Next

* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`.

* [10320](https://github.com/apache/incubator-superset/pull/10320): References to blacklist/whitelist language have been replaced with more appropriate alternatives. All configs referencing or containing `WHITE`/`BLACK` have been replaced with `ALLOW`/`DENY`. Affected config variables that need to be updated: `TIME_GRAIN_BLACKLIST`, `VIZ_TYPE_BLACKLIST`, `DRUID_DATA_SOURCE_BLACKLIST`.

* [9964](https://github.com/apache/incubator-superset/pull/9964): Breaking change on Flask-AppBuilder 3. If you're using OAuth, find out what needs to be changed [here](https://github.com/dpgaspar/Flask-AppBuilder/blob/master/README.rst#change-log).
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def get_git_sha():
"cockroachdb": ["cockroachdb==0.3.3"],
"thumbnails": ["Pillow>=7.0.0, <8.0.0"],
"excel": ["xlrd>=1.2.0, <1.3"],
"prophet": ["fbprophet>=0.6, <0.7"],
},
python_requires="~=3.6",
author="Apache Software Foundation",
Expand Down
98 changes: 78 additions & 20 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@
}


# Time grain identifiers accepted as `OneOf` choices by the chart data schemas
# in this module.
TIME_GRAINS = (
    # Sub-daily ISO 8601 durations
    ("PT1S", "PT1M", "PT5M", "PT10M", "PT15M", "PT0.5H", "PT1H")
    # Calendar ISO 8601 durations
    + ("P1D", "P1W", "P1M", "P0.25Y", "P1Y")
    # Weekly grains pinned to a specific weekday
    + (
        "1969-12-28T00:00:00Z/P1W",  # Week starting Sunday
        "1969-12-29T00:00:00Z/P1W",  # Week starting Monday
        "P1W/1970-01-03T00:00:00Z",  # Week ending Saturday
        "P1W/1970-01-04T00:00:00Z",  # Week ending Sunday
    )
)


class ChartPostSchema(Schema):
"""
Schema to add a new chart.
Expand Down Expand Up @@ -423,6 +443,62 @@ class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptions
)


class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
    """
    Prophet operation config.

    The field names are intended to line up with the keyword arguments of the
    `prophet` post processing operation (`time_grain`, `periods`,
    `confidence_interval` and the three `*_seasonality` switches).
    """

    time_grain = fields.String(
        description="Time grain used to specify time period increments in prediction. "
        "Supports [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
        "durations.",
        validate=validate.OneOf(choices=TIME_GRAINS),
        example="P1D",
        required=True,
    )
    periods = fields.Integer(
        description="Time periods (in units of `time_grain`) to predict into the future",
        # An unknown keyword such as `min=1` is only stored as field metadata and
        # never enforced, so the lower bound is expressed as a real validator.
        validate=[Range(min=1, error=_("Periods must be a positive integer value"))],
        example=7,
        required=True,
    )
    confidence_interval = fields.Float(
        description="Width of predicted confidence interval",
        validate=[
            Range(
                min=0,
                max=1,
                min_inclusive=False,
                max_inclusive=False,
                error=_("`confidence_interval` must be between 0 and 1 (exclusive)"),
            )
        ],
        example=0.8,
        required=True,
    )
    yearly_seasonality = fields.Raw(
        # TODO: add correct union type once supported by Marshmallow
        description="Should yearly seasonality be applied. "
        "An integer value will specify Fourier order of seasonality, `None` will "
        "automatically detect seasonality.",
        example=False,
    )
    weekly_seasonality = fields.Raw(
        # TODO: add correct union type once supported by Marshmallow
        description="Should weekly seasonality be applied. "
        "An integer value will specify Fourier order of seasonality, `None` will "
        "automatically detect seasonality.",
        example=False,
    )
    # Named `daily_seasonality` (not `monthly_seasonality`) because that is the
    # keyword the `prophet` operation accepts.
    daily_seasonality = fields.Raw(
        # TODO: add correct union type once supported by Marshmallow
        description="Should daily seasonality be applied. "
        "An integer value will specify Fourier order of seasonality, `None` will "
        "automatically detect seasonality.",
        example=False,
    )


class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Pivot operation config.
Expand Down Expand Up @@ -534,6 +610,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
"geohash_decode",
"geohash_encode",
"pivot",
"prophet",
"rolling",
"select",
"sort",
Expand Down Expand Up @@ -613,26 +690,7 @@ class ChartDataExtrasSchema(Schema):
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
validate=validate.OneOf(
choices=(
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-28T00:00:00Z/P1W", # Week starting Sunday
"1969-12-29T00:00:00Z/P1W", # Week starting Monday
"P1W/1970-01-03T00:00:00Z", # Week ending Saturday
"P1W/1970-01-04T00:00:00Z", # Week ending Sunday
),
),
validate=validate.OneOf(choices=TIME_GRAINS),
example="P1D",
allow_none=True,
)
Expand Down
142 changes: 142 additions & 0 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@
"cumsum",
)

# Maps the ISO 8601 time grains accepted by the `prophet` operation to the pandas
# frequency strings handed to Prophet when building the future dataframe.
PROPHET_TIME_GRAIN_MAP = {
    "PT1S": "S",
    "PT1M": "min",
    "PT5M": "5min",
    "PT10M": "10min",
    "PT15M": "15min",
    "PT0.5H": "30min",
    "PT1H": "H",
    "P1D": "D",
    "P1W": "W",
    "P1M": "M",
    "P0.25Y": "Q",
    "P1Y": "A",
    # A bare "W" is anchored on Sunday in pandas, so the weekday-anchored grains
    # each need the anchor matching the weekday their timestamps fall on;
    # otherwise the generated future dates land on a different weekday than
    # the historical observations.
    "1969-12-28T00:00:00Z/P1W": "W-SUN",  # Week starting Sunday
    "1969-12-29T00:00:00Z/P1W": "W-MON",  # Week starting Monday
    "P1W/1970-01-03T00:00:00Z": "W-SAT",  # Week ending Saturday
    "P1W/1970-01-04T00:00:00Z": "W-SUN",  # Week ending Sunday
}


def _flatten_column_after_pivot(
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
Expand Down Expand Up @@ -544,3 +563,126 @@ def contribution(
if temporal_series is not None:
contribution_df.insert(0, DTTM_ALIAS, temporal_series)
return contribution_df


def _prophet_parse_seasonality(
    input_value: Optional[Union[bool, int]]
) -> Union[bool, str, int]:
    """
    Normalize a seasonality option before it is handed to Prophet.

    :param input_value: `None`, a boolean, or a value convertible with `int()`
    :return: `"auto"` for `None`, booleans unchanged, otherwise the integer
             conversion of the value; a value whose conversion raises
             `ValueError` is returned as is
    """
    if input_value is None:
        return "auto"
    # bool is tested explicitly because `int(True)` would otherwise succeed and
    # turn the on/off switch into a Fourier order
    if not isinstance(input_value, bool):
        try:
            return int(input_value)
        except ValueError:
            pass
    return input_value


def _prophet_fit_and_predict(  # pylint: disable=too-many-arguments
    df: DataFrame,
    confidence_interval: float,
    yearly_seasonality: Union[bool, str, int],
    weekly_seasonality: Union[bool, str, int],
    daily_seasonality: Union[bool, str, int],
    periods: int,
    freq: str,
) -> DataFrame:
    """
    Fit a prophet model on a single series and return the forecast joined with
    the original observations.

    :param df: DataFrame that is fitted as is; it is indexed on its `ds` column
           when joined back onto the forecast
    :param confidence_interval: passed to Prophet as `interval_width`
    :param yearly_seasonality: yearly seasonality setting passed to Prophet
    :param weekly_seasonality: weekly seasonality setting passed to Prophet
    :param daily_seasonality: daily seasonality setting passed to Prophet
    :param periods: number of future periods to generate
    :param freq: pandas frequency string used for the future periods
    :return: DataFrame indexed by `ds` holding `yhat`, `yhat_lower`, `yhat_upper`
             followed by the remaining columns of `df`
    :raises QueryObjectValidationError: if `fbprophet` is not installed
    """
    # fbprophet is an optional dependency, hence the import at call time
    try:
        from fbprophet import Prophet  # pylint: disable=import-error
    except ModuleNotFoundError:
        raise QueryObjectValidationError(_("`fbprophet` package not installed"))

    seasonality_options = {
        "yearly_seasonality": yearly_seasonality,
        "weekly_seasonality": weekly_seasonality,
        "daily_seasonality": daily_seasonality,
    }
    model = Prophet(interval_width=confidence_interval, **seasonality_options)
    model.fit(df)

    future_frame = model.make_future_dataframe(periods=periods, freq=freq)
    prediction = model.predict(future_frame)
    forecast = prediction[["ds", "yhat", "yhat_lower", "yhat_upper"]]

    # bring the observed values back next to the forecast, keyed by timestamp
    observed = df.set_index("ds")
    combined = forecast.join(observed, on="ds")
    return combined.set_index(["ds"])


def prophet(  # pylint: disable=too-many-arguments
    df: DataFrame,
    time_grain: str,
    periods: int,
    confidence_interval: float,
    yearly_seasonality: Optional[Union[bool, int]] = None,
    weekly_seasonality: Optional[Union[bool, int]] = None,
    daily_seasonality: Optional[Union[bool, int]] = None,
) -> DataFrame:
    """
    Add forecasts to each series in a timeseries dataframe, along with confidence
    intervals for the prediction. For each series, the operation creates three
    new columns with the column name suffixed with the following values:

    - `__yhat`: the forecast for the given date
    - `__yhat_lower`: the lower bound of the forecast for the given date
    - `__yhat_upper`: the upper bound of the forecast for the given date

    :param df: DataFrame containing all-numeric data (temporal column ignored)
    :param time_grain: Time grain used to specify time period increments in prediction
    :param periods: Time periods (in units of `time_grain`) to predict into the future
    :param confidence_interval: Width of predicted confidence interval
    :param yearly_seasonality: Should yearly seasonality be applied.
           An integer value will specify Fourier order of seasonality, `None` will
           automatically detect seasonality.
    :param weekly_seasonality: Should weekly seasonality be applied.
           An integer value will specify Fourier order of seasonality, `None` will
           automatically detect seasonality.
    :param daily_seasonality: Should daily seasonality be applied.
           An integer value will specify Fourier order of seasonality, `None` will
           automatically detect seasonality.
    :return: DataFrame with the temporal column first, followed by the forecast,
             lower bound, upper bound and original values of each series
    :raises QueryObjectValidationError: If the time grain is missing or
            unsupported, `periods` is not a positive integer,
            `confidence_interval` is not strictly between 0 and 1, or the
            DataFrame lacks the temporal column or a series
    """
    # validate inputs
    if not time_grain:
        raise QueryObjectValidationError(_("Time grain missing"))
    if time_grain not in PROPHET_TIME_GRAIN_MAP:
        raise QueryObjectValidationError(
            _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
        )
    freq = PROPHET_TIME_GRAIN_MAP[time_grain]
    # check type at runtime due to marshmallow schema not being able to handle
    # union types; the type test has to run before the comparison so that a
    # non-numeric value is reported as a validation error, not a TypeError
    if not isinstance(periods, int) or periods <= 0:
        raise QueryObjectValidationError(_("Periods must be a positive integer value"))
    if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
        raise QueryObjectValidationError(
            _("Confidence interval must be between 0 and 1 (exclusive)")
        )
    if DTTM_ALIAS not in df.columns:
        raise QueryObjectValidationError(_("DataFrame must include temporal column"))
    if len(df.columns) < 2:
        raise QueryObjectValidationError(_("DataFrame include at least one series"))

    # every non-temporal column is an independent series with its own model
    series_columns = [column for column in df.columns if column != DTTM_ALIAS]
    target_df = DataFrame()
    for column in series_columns:
        fit_df = _prophet_fit_and_predict(
            df=df[[DTTM_ALIAS, column]].rename(columns={DTTM_ALIAS: "ds", column: "y"}),
            confidence_interval=confidence_interval,
            yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality),
            weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality),
            daily_seasonality=_prophet_parse_seasonality(daily_seasonality),
            periods=periods,
            freq=freq,
        )
        # the helper returns yhat, yhat_lower, yhat_upper and the observed "y",
        # in that order; prefix them with the series name
        new_columns = [
            f"{column}__yhat",
            f"{column}__yhat_lower",
            f"{column}__yhat_upper",
            f"{column}",
        ]
        fit_df.columns = new_columns
        if target_df.empty:
            target_df = fit_df
        else:
            target_df = target_df.assign(
                **{new_column: fit_df[new_column] for new_column in new_columns}
            )
    # move the "ds" index back into a regular column and restore its name
    target_df.reset_index(level=0, inplace=True)
    return target_df.rename(columns={"ds": DTTM_ALIAS})
39 changes: 38 additions & 1 deletion tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from datetime import datetime
from unittest import mock

import prison
import humanize
import prison
import pytest
from sqlalchemy.sql import func

from tests.test_app import app
Expand Down Expand Up @@ -796,6 +797,42 @@ def test_chart_data_mixed_case_filter_op(self):
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)

def test_chart_data_prophet(self):
    """
    Chart data API: Ensure prophet post transformation works
    """
    # fbprophet is an optional dependency; skip instead of failing without it
    pytest.importorskip("fbprophet")
    self.login(username="admin")
    table = self.get_table_by_name("birth_names")
    request_payload = get_query_context(table.name, table.id, table.type)
    time_grain = "P1Y"
    # the same grain drives both the query aggregation and the forecast step
    request_payload["queries"][0]["is_timeseries"] = True
    request_payload["queries"][0]["groupby"] = []
    request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
    request_payload["queries"][0]["granularity"] = "ds"
    request_payload["queries"][0]["post_processing"] = [
        {
            "operation": "prophet",
            "options": {
                "time_grain": time_grain,
                "periods": 3,
                "confidence_interval": 0.9,
            },
        }
    ]
    rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    self.assertEqual(rv.status_code, 200)
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    row = result["data"][0]
    # the original metric plus the three forecast columns must be present
    self.assertIn("__timestamp", row)
    self.assertIn("sum__num", row)
    self.assertIn("sum__num__yhat", row)
    self.assertIn("sum__num__yhat_upper", row)
    self.assertIn("sum__num__yhat_lower", row)
    # NOTE(review): presumably the historical yearly rows plus the 3 forecast
    # periods - confirm against the birth_names fixture data
    self.assertEqual(result["rowcount"], 47)

def test_chart_data_no_data(self):
"""
Chart data API: Test chart data with empty result
Expand Down
15 changes: 14 additions & 1 deletion tests/fixtures/dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date
from datetime import date, datetime

from pandas import DataFrame, to_datetime

Expand Down Expand Up @@ -133,3 +133,16 @@
],
}
)

# Timeseries fixture with four year-end timestamps (2018-2021) and two numeric
# series, "a" and "b".
prophet_df = DataFrame(
    {
        "__timestamp": [datetime(year, 12, 31) for year in range(2018, 2022)],
        "a": [1.1, 1, 1.9, 3.15],
        "b": [4, 3, 4.1, 3.95],
    }
)
Loading

0 comments on commit bff4346

Please sign in to comment.