Skip to content
Draft
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4c6d618
Add configs and parser for scalar adjustments
m-d-bowerman Jul 11, 2024
67cb204
Remove comment
m-d-bowerman Jul 11, 2024
f00ec82
ScalarAdjustments docstring update
m-d-bowerman Jul 11, 2024
ae4d1f3
Update to use dataclass and remove classmethod
m-d-bowerman Jul 11, 2024
843c8f2
Add ScalarForecast class
m-d-bowerman Jul 11, 2024
eeb2308
Remove unneeded class
m-d-bowerman Jul 11, 2024
b7b079d
Remove unused imports
m-d-bowerman Jul 11, 2024
cfd7b81
Fillna for scalar columns with 1
m-d-bowerman Jul 11, 2024
da13c85
Fix error, caps for module-level variables
m-d-bowerman Jul 11, 2024
4797042
Fix Ruff errors
m-d-bowerman Jul 11, 2024
ae6d4a6
Update post_init to remove arguments
m-d-bowerman Jul 15, 2024
5d8797d
Reference alias instead of slug
m-d-bowerman Jul 15, 2024
9971b46
Add base data pull class and class to pull forecast data
m-d-bowerman Jul 16, 2024
6d3fb98
Update link to metric hub github.io
m-d-bowerman Jul 16, 2024
b64c21f
Keep metric_hub attribute in BaseForecast
m-d-bowerman Jul 16, 2024
53ab92b
Add scalar model configs and updates
m-d-bowerman Jul 18, 2024
963dd0a
Metric hub query logic updates
m-d-bowerman Jul 18, 2024
c169af2
Remove unneeded break and add comment
m-d-bowerman Jul 18, 2024
436a364
Updates to summarize methods for scalars
m-d-bowerman Jul 18, 2024
ba6c544
Ad clicks scalar model config
m-d-bowerman Jul 18, 2024
df6bac2
Rename config
m-d-bowerman Jul 18, 2024
04d956e
Rename config again
m-d-bowerman Jul 18, 2024
6fde34d
Fix error in inputs
m-d-bowerman Jul 18, 2024
fd27d5a
Fix start date for monthly forecasts
m-d-bowerman Jul 18, 2024
809fa2c
Config update
m-d-bowerman Jul 18, 2024
86dff3f
Remove forecast adjustment methods
m-d-bowerman Jul 26, 2024
4b5b5e6
Add tests
m-d-bowerman Jul 26, 2024
df69c0f
Update tests
m-d-bowerman Jul 26, 2024
3f1ece7
Remove unneeded variable
m-d-bowerman Jul 26, 2024
b75c46a
Fixes error in pulling training data from Prophet model
m-d-bowerman Aug 7, 2024
18dead8
Adjust post_init for testing
m-d-bowerman Aug 14, 2024
c619059
Ruff format changes
m-d-bowerman Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions jobs/kpi-forecasting/kpi_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kpi_forecasting.models.prophet_forecast import ProphetForecast
from kpi_forecasting.models.funnel_forecast import FunnelForecast
from kpi_forecasting.metric_hub import MetricHub

from kpi_forecasting.metric_hub import ForecastDataPull

# A dictionary of available models in the `models` directory.
MODELS = {
Expand All @@ -16,10 +16,15 @@ def main() -> None:
config = YAML(filepath=CLI().args.config).data
model_type = config.forecast_model.model_type

if model_type in MODELS:
metric_hub = MetricHub(**config.metric_hub)
model = MODELS[model_type](metric_hub=metric_hub, **config.forecast_model)
if hasattr(config, "metric_hub"):
data_puller = MetricHub(**config.metric_hub)
elif hasattr(config, "forecast_data_pull"):
data_puller = ForecastDataPull(**config.forecast_data_pull)
else:
raise KeyError("No metric_hub or forecast_data_pull key in config to pull data.")

if model_type in MODELS:
model = MODELS[model_type](data_puller=data_puller, **config.forecast_model)
model.fit()
model.predict()
model.summarize(**config.summarize)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import attr
from typing import List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from dotmap import DotMap
from pathlib import Path
from typing import List, Optional, Union

import pandas as pd

from kpi_forecasting.inputs import YAML


PARENT_PATH = Path(__file__).parent
HOLIDAY_PATH = PARENT_PATH / "holidays.yaml"
REGRESSOR_PATH = PARENT_PATH / "regressors.yaml"
SCALAR_PATH = PARENT_PATH / "scalar_adjustments.yaml"

holiday_collection = YAML(HOLIDAY_PATH)
regressor_collection = YAML(REGRESSOR_PATH)
HOLIDAY_COLLECTION = YAML(HOLIDAY_PATH)
REGRESSOR_COLLECTION = YAML(REGRESSOR_PATH)
SCALAR_ADJUSTMENTS = YAML(SCALAR_PATH)


@attr.s(auto_attribs=True, frozen=False)
Expand All @@ -38,3 +44,84 @@ class ProphetHoliday:
ds: List
lower_window: int
upper_window: int


@dataclass
class ScalarAdjustments:
    """
    Holds one dated version of a named scalar adjustment, flattened into a DataFrame.

    Args:
        name (str): The name of the adjustment from the scalar_adjustments.yaml file.
        adjustment_dotmap (DotMap): One entry of the named adjustment's `adjustments` list.
            It must provide `forecast_start_date` (a "%Y-%m-%d" string) and `segments`, where
            every segment carries a `segment` mapping of dimension names to values and an
            `adjustments` list of mappings describing each scalar adjustment.

    Attributes:
        forecast_start_date (datetime): The first forecast_start_date where this iteration of the
            adjustment should be applied. This adjustment will apply to any subsequent forecast
            until another update of this adjustment is made.
        adjustments_dataframe (DataFrame): One row per (segment, adjustment) pair. The columns
            are the segment's dimensions followed by the keys of the adjustment entry.

    Raises:
        ValueError: If `forecast_start_date` does not match "%Y-%m-%d", or if the adjustment
            defines no segments.
    """

    name: str
    adjustment_dotmap: DotMap

    def __post_init__(self):
        self.forecast_start_date = datetime.strptime(
            self.adjustment_dotmap.forecast_start_date, "%Y-%m-%d"
        )

        segments = list(self.adjustment_dotmap.segments)
        if not segments:
            # Fail with a message that names the offending adjustment instead of
            # pandas' generic "No objects to concatenate".
            raise ValueError(
                f"Scalar adjustment '{self.name}' with forecast_start_date "
                f"{self.adjustment_dotmap.forecast_start_date} has no segments."
            )

        # Every adjustment row repeats its segment's dimensions so the frame can
        # be matched against segment columns; adjustment keys win on a name clash.
        records = [
            {**segment_dat.segment, **adj}
            for segment_dat in segments
            for adj in segment_dat.adjustments
        ]
        self.adjustments_dataframe = pd.DataFrame(records)


def parse_scalar_adjustments(
    metric_hub_slug: str, forecast_start_date: datetime
) -> List[ScalarAdjustments]:
    """
    Parses the SCALAR_ADJUSTMENTS to find the applicable scalar adjustments for a given metric hub slug
    and forecast start date.

    Args:
        metric_hub_slug (str): The metric hub slug being forecasted. It must be present by name in the
            scalar_adjustments.yaml.
        forecast_start_date (datetime): The first date being forecasted. Used here to map to the correct
            scalar adjustments as the adjustments will be updated over time.

    Returns:
        List[ScalarAdjustments]: A list of ScalarAdjustments, where each ScalarAdjustments is a named scalar
            adjustment with the dates that the adjustment should be applied for each segment being modeled.
            A named adjustment whose versions all start after `forecast_start_date` contributes nothing.

    Raises:
        KeyError: If no adjustments are found for `metric_hub_slug`.
    """
    metric_adjustments = getattr(SCALAR_ADJUSTMENTS.data, metric_hub_slug, None)
    if not metric_adjustments:
        raise KeyError(f"No adjustments found for {metric_hub_slug} in {SCALAR_PATH}.")

    # Creates a list of ScalarAdjustments objects that apply for this metric and forecast_start_date
    applicable_adjustments = []
    for named_adjustment in metric_adjustments:
        # Every version is parsed (so a malformed entry still surfaces), then only
        # versions that took effect on or before the forecast start are kept.
        eligible_adjustments = [
            parsed_adjustment
            for parsed_adjustment in (
                ScalarAdjustments(named_adjustment.name, adj_dotmap)
                for adj_dotmap in named_adjustment.adjustments
            )
            if parsed_adjustment.forecast_start_date <= forecast_start_date
        ]

        # The most recent eligible version wins. max() keeps the first maximum it
        # meets, so scanning in reverse makes the last-listed entry win a tie on
        # forecast_start_date.
        matched_adjustment = max(
            reversed(eligible_adjustments),
            key=lambda adjustment: adjustment.forecast_start_date,
            default=None,
        )

        if matched_adjustment is not None:
            applicable_adjustments.append(matched_adjustment)

    return applicable_adjustments
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
---
# Scalar adjustments, read by parse_scalar_adjustments.
#
# Layout:
#   <metric hub slug>:                 top-level key, looked up by name
#     - name / description             one entry per named adjustment
#       adjustments:                   dated versions of that named adjustment
#         - forecast_start_date        "%Y-%m-%d"; the latest version dated on or
#                                      before a forecast's start date is the one used
#           segments:
#             - segment                dimension name -> value for one modeled segment
#               adjustments:           list of start_date / value entries
#                                      (value is presumably a multiplicative scalar
#                                      effective from start_date - TODO confirm)
search_forecasting_revenue_per_ad_click:
- name: "year_over_year_growth"
description: "Estimate of YoY growth in RPC, from input from stakeholders."
adjustments:
- forecast_start_date: "2024-01-01"
segments:
- segment:
{
partner: "Google",
country: "US",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.03
- segment:
{
partner: "Google",
country: "ROW",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.04
- segment:
{
partner: "Google",
country: "ROW",
device: "mobile",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.04
- forecast_start_date: "2024-04-01"
segments:
- segment:
{
partner: "Google",
country: "US",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.10
- segment:
{
partner: "Google",
country: "ROW",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.10
- segment:
{
partner: "Google",
country: "ROW",
device: "mobile",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.04
- forecast_start_date: "2024-05-01"
segments:
- segment:
{
partner: "Google",
country: "US",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.10
- start_date: "2024-08-01"
value: 1.03
- segment:
{
partner: "Google",
country: "ROW",
device: "desktop",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.10
- start_date: "2024-08-01"
value: 1.04
- segment:
{
partner: "Google",
country: "ROW",
device: "mobile",
channel: "all",
}
adjustments:
- start_date: "2024-01-01"
value: 1.04
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
# Forecast configuration for search_forecasting_revenue_per_ad_click.

# Passed as keyword arguments to MetricHub in kpi_forecasting.py's main().
metric_hub:
app_name: "multi_product"
# slug and alias both equal the top-level key used in scalar_adjustments.yaml.
slug: "search_forecasting_revenue_per_ad_click"
alias: "search_forecasting_revenue_per_ad_click"
start_date: "2020-01-01"
end_date: "last complete month"
# Segment name -> expression; the values look like SQL expressions - TODO confirm.
segments:
device: "device"
channel: "'all'"
country: "CASE WHEN country = 'US' THEN 'US' ELSE 'ROW' END"
partner: "partner"
where: "partner = 'Google'"

# model_type is looked up in MODELS; the whole section is then passed as keyword
# arguments to the selected model class.
forecast_model:
model_type: "scalar"
# NOTE(review): NULL dates presumably let the model pick its defaults - confirm.
start_date: NULL
end_date: NULL
use_holidays: False
parameters:
formula: "search_forecasting_revenue_per_ad_click:YOY * scalar"

# Passed as keyword arguments to model.summarize().
summarize:
requires_summarization: False
periods: ["month"]

# NOTE(review): the consumer of this section is not visible in this diff;
# presumably the destination for forecast output - confirm.
write_results:
project: "moz-fx-data-shared-prod"
dataset: "revenue_derived"
table: "search_revenue_forecasts_v1"
components_table: "search_revenue_model_components_v1"
Loading