Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ python/src/robynpy.egg-info*
python/oldportedcode
python/src/tutorials/mytestenv
*.log
<<<<<<< HEAD
python/src/tutorials/test_modeling.py
python/src/tutorials/data/*
python/src/tutorials/test_modeling.py
*.txt
*.txt
=======
python/src/tutorials/data/R/*
>>>>>>> 8c2e2cbf (add data mapper code)
6 changes: 6 additions & 0 deletions python/src/robyn/data/entities/mmmdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(
context_signs: Optional[List[ContextSigns]] = None,
factor_vars: Optional[List[str]] = None,
all_media: Optional[List[str]] = None,
day_interval: Optional[int] = 7,
interval_type: Optional[str] = "week",
) -> None:
self.dep_var: Optional[str] = dep_var
self.dep_var_type: DependentVarType = dep_var_type
Expand All @@ -61,6 +63,8 @@ def __init__(
self.context_signs: Optional[List[str]] = context_signs
self.factor_vars: Optional[List[str]] = factor_vars
self.all_media = all_media or paid_media_spends
self.day_interval: Optional[int] = day_interval
self.interval_type: Optional[str] = interval_type

def __str__(self) -> str:
return f"""
Expand All @@ -79,6 +83,8 @@ def __str__(self) -> str:
context_signs: {self.context_signs}
factor_vars: {self.factor_vars}
all_media: {self.all_media}
day_interval: {self.day_interval}
interval_type: {self.interval_type}
"""

def update(self, **kwargs: Any) -> None:
Expand Down
28 changes: 28 additions & 0 deletions python/src/robyn/modeling/entities/convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import List, Dict, Optional
import pandas as pd


@dataclass
class Convergence:
moo_distrb_plot: Optional[str] # Hexadecimal string of plot image data
moo_cloud_plot: Optional[str] # Hexadecimal string of plot image data
errors: pd.DataFrame
conv_msg: List[str]

@classmethod
def from_dict(cls, data: Dict[str, any]) -> "Convergence":
return cls(
moo_distrb_plot=data.get("moo_distrb_plot"),
moo_cloud_plot=data.get("moo_cloud_plot"),
errors=pd.DataFrame(data.get("errors", [])),
conv_msg=data.get("conv_msg", []),
)

def to_dict(self) -> Dict[str, any]:
return {
"moo_distrb_plot": self.moo_distrb_plot,
"moo_cloud_plot": self.moo_cloud_plot,
"errors": self.errors.to_dict(orient="records"),
"conv_msg": self.conv_msg,
}
5 changes: 3 additions & 2 deletions python/src/robyn/modeling/entities/modeloutputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ class ModelOutputs:
convergence: Dict[str, Any]
select_id: str
seed: int
hyper_bound_ng: Dict[str, Any] # For hyperBoundNG
hyper_bound_fixed: Dict[str, Any] # For hyperBoundFixed
hyper_bound_ng: Dict[str, Any]
hyper_bound_fixed: Dict[str, Any]
ts_validation_plot: Optional[str]
189 changes: 136 additions & 53 deletions python/src/robyn/modeling/feature_engineering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# feature_engineering.py
# pyre-strict
from typing import List, Optional, Dict, Any, Tuple
import pandas as pd
import warnings
import numpy as np
from dataclasses import dataclass
from scipy.optimize import curve_fit
Expand Down Expand Up @@ -42,56 +45,71 @@ def __init__(
self.holidays_data = holidays_data

def perform_feature_engineering(self, quiet: bool = False) -> FeaturizedMMMData:
dt_mod = self._prepare_data()
dt_modRollWind = self._create_rolling_window_data(dt_mod)
dt_transform = self._prepare_data()

if any(var in self.holidays_data.prophet_vars for var in ["trend", "season", "holiday", "monthly", "weekday"]):
dt_transform = self._prophet_decomposition(dt_transform)
if not quiet:
print("Prophet decomposition complete.")

# Include all independent variables
all_ind_vars = (
self.holidays_data.prophet_vars
+ self.mmm_data.mmmdata_spec.context_vars
+ self.mmm_data.mmmdata_spec.paid_media_spends
+ self.mmm_data.mmmdata_spec.organic_vars
)

dt_mod = dt_transform
dt_modRollWind = self._create_rolling_window_data(dt_transform)
media_cost_factor = self._calculate_media_cost_factor(dt_modRollWind)
modNLS = self._run_models(dt_modRollWind, media_cost_factor)

if "trend" in self.holidays_data.prophet_vars:
dt_mod = self._prophet_decomposition(dt_mod)
print("Prophet decomposition complete.")
columns_to_keep = ["ds", "dep_var"] + all_ind_vars
# Only keep columns that exist in both dataframes
columns_to_keep = [col for col in columns_to_keep if col in dt_mod.columns and col in dt_modRollWind.columns]

dt_mod = dt_transform[columns_to_keep]
dt_modRollWind = dt_modRollWind[columns_to_keep]

if not quiet:
print("Feature engineering complete.")

return FeaturizedMMMData(dt_mod=dt_mod, dt_modRollWind=dt_modRollWind, modNLS=modNLS)

def _prepare_data(self) -> pd.DataFrame:
dt_mod = self.mmm_data.data.copy()
dt_mod["ds"] = pd.to_datetime(dt_mod[self.mmm_data.mmmdata_spec.date_var])
dt_mod["dep_var"] = dt_mod[self.mmm_data.mmmdata_spec.dep_var]
return dt_mod
dt_transform = self.mmm_data.data.copy()
dt_transform["ds"] = pd.to_datetime(dt_transform[self.mmm_data.mmmdata_spec.date_var]).dt.strftime("%Y-%m-%d")
dt_transform["dep_var"] = dt_transform[self.mmm_data.mmmdata_spec.dep_var]
dt_transform["competitor_sales_B"] = dt_transform["competitor_sales_B"].astype("int64")
return dt_transform

def _create_rolling_window_data(self, dt_transform: pd.DataFrame) -> pd.DataFrame:
window_start = self.mmm_data.mmmdata_spec.window_start
window_end = self.mmm_data.mmmdata_spec.window_end

if window_start is None and window_end is None:
# If both are None, return the entire DataFrame
return dt_transform
elif window_start is None:
# If only start is None, filter up to end
return dt_transform[dt_transform["ds"] <= window_end]
elif window_end is None:
# If only end is None, filter from start
return dt_transform[dt_transform["ds"] >= window_start]
else:
# If both are provided, filter between start and end
return dt_transform[(dt_transform["ds"] >= window_start) & (dt_transform["ds"] <= window_end)]

def _calculate_media_cost_factor(self, dt_input_roll_wind: pd.DataFrame) -> pd.Series:
total_spend = dt_input_roll_wind[self.mmm_data.mmmdata_spec.paid_media_spends].sum().sum()
return dt_input_roll_wind[self.mmm_data.mmmdata_spec.paid_media_spends].sum() / total_spend

def _run_models(self, dt_modRollWind: pd.DataFrame, media_cost_factor: float) -> Dict[str, Dict[str, Any]]:
modNLS = {}
modNLS = {"results": {}, "yhat": pd.DataFrame(), "plots": {}}

for paid_media_var in self.mmm_data.mmmdata_spec.paid_media_spends:
result = self._fit_spend_exposure(dt_modRollWind, paid_media_var, media_cost_factor)
if result is not None:
modNLS[paid_media_var] = result

# Keep the plot windows open
plt.show()
modNLS["results"][paid_media_var] = result["res"]
modNLS["yhat"] = pd.concat([modNLS["yhat"], result["plot"]], ignore_index=True)
modNLS["plots"][paid_media_var] = result["plot"]

return modNLS

Expand All @@ -111,9 +129,6 @@ def michaelis_menten(x, Vmax, Km):
spend_data = dt_modRollWind[spend_var]
exposure_data = dt_modRollWind[exposure_var]

print(f"spend_data range: {spend_data.min()} - {spend_data.max()}")
print(f"exposure_data range: {exposure_data.min()} - {exposure_data.max()}")

try:
# Fit Michaelis-Menten model
popt_nls, _ = curve_fit(
Expand Down Expand Up @@ -175,56 +190,124 @@ def _hill_function(x, alpha, gamma):

def _prophet_decomposition(self, dt_mod: pd.DataFrame) -> pd.DataFrame:
prophet_vars = self.holidays_data.prophet_vars
print(f"Prophet variables: {prophet_vars}")

if not any(var in prophet_vars for var in ["trend", "season", "holiday", "monthly", "weekday"]):
return dt_mod
recurrence = dt_mod[["ds", "dep_var"]].rename(columns={"dep_var": "y"}).copy()
recurrence["ds"] = pd.to_datetime(recurrence["ds"])

prophet_data = dt_mod[[self.mmm_data.mmmdata_spec.date_var, self.mmm_data.mmmdata_spec.dep_var]].copy()
prophet_data.columns = ["ds", "y"]
holidays = self._set_holidays(
dt_mod, self.holidays_data.dt_holidays.copy(), self.mmm_data.mmmdata_spec.interval_type
)

use_trend = "trend" in prophet_vars
use_holiday = "holiday" in prophet_vars
use_season = "season" in prophet_vars or "yearly.seasonality" in prophet_vars
use_monthly = "monthly" in prophet_vars
use_weekday = "weekday" in prophet_vars or "weekly.seasonality" in prophet_vars

model = Prophet(yearly_seasonality=use_season, weekly_seasonality=use_weekday, daily_seasonality=False)

if use_holiday and self.holidays_data is not None:
holidays_df = self._prepare_holidays_for_prophet(self.holidays_data.dt_holidays)
model.add_country_holidays(country_name=self.holidays_data.prophet_country)
model.holidays = holidays_df
elif use_holiday:
print(
"Warning: Holiday decomposition requested but no holiday data provided. Skipping holiday decomposition."
)
dt_regressors = pd.concat(
[
recurrence,
dt_mod[
self.mmm_data.mmmdata_spec.paid_media_spends
+ self.mmm_data.mmmdata_spec.context_vars
+ self.mmm_data.mmmdata_spec.organic_vars
],
],
axis=1,
)
dt_regressors["ds"] = pd.to_datetime(dt_regressors["ds"])

# Handle the case where prophet_country is a string
prophet_country = self.holidays_data.prophet_country
if isinstance(prophet_country, str):
prophet_country = [prophet_country]

prophet_params = {
"holidays": (holidays[holidays["country"].isin(prophet_country)] if use_holiday else None),
"yearly_seasonality": use_season,
"weekly_seasonality": use_weekday,
"daily_seasonality": False,
}

# Add custom parameters (assuming they're stored in self.custom_params)
if hasattr(self, "custom_params"):
if "yearly.seasonality" in self.custom_params:
prophet_params["yearly_seasonality"] = self.custom_params["yearly.seasonality"]
if "weekly.seasonality" in self.custom_params and self.mmm_data.mmmdata_spec.day_interval <= 7:
prophet_params["weekly_seasonality"] = self.custom_params["weekly.seasonality"]
# Add other custom parameters as needed

model = Prophet(**prophet_params)

if use_monthly:
model.add_seasonality(name="monthly", period=30.5, fourier_order=5)

model.fit(prophet_data)
future = model.make_future_dataframe(periods=0)
forecast = model.predict(future)

if self.mmm_data.mmmdata_spec.factor_vars:
dt_ohe = pd.get_dummies(dt_regressors[self.mmm_data.mmmdata_spec.factor_vars], drop_first=False)
ohe_names = [col for col in dt_ohe.columns if col not in self.mmm_data.mmmdata_spec.factor_vars]
for addreg in ohe_names:
model.add_regressor(addreg)
dt_ohe = pd.concat([dt_regressors.drop(columns=self.mmm_data.mmmdata_spec.factor_vars), dt_ohe], axis=1)
mod_ohe = model.fit(dt_ohe)
dt_forecastRegressor = mod_ohe.predict(dt_ohe)
forecastRecurrence = dt_forecastRegressor.drop(
columns=[col for col in dt_forecastRegressor.columns if "_lower" in col or "_upper" in col]
)
for aggreg in self.mmm_data.mmmdata_spec.factor_vars:
oheRegNames = [col for col in forecastRecurrence.columns if col.startswith(f"{aggreg}_")]
get_reg = forecastRecurrence[oheRegNames].sum(axis=1)
dt_mod[aggreg] = (get_reg - get_reg.min()) / (get_reg.max() - get_reg.min())
else:
if self.mmm_data.mmmdata_spec.day_interval == 1:
warnings.warn(
"Currently, there's a known issue with prophet that may crash this use case.\n"
"Read more here: https://github.com/facebookexperimental/Robyn/issues/472"
)
mod = model.fit(dt_regressors)
forecastRecurrence = mod.predict(dt_regressors)

these = range(len(recurrence))
if use_trend:
dt_mod["trend"] = forecast["trend"].values
dt_mod["trend"] = forecastRecurrence["trend"].iloc[these].values
if use_season:
dt_mod["season"] = forecast["yearly"].values
if use_holiday and "holidays" in forecast.columns:
dt_mod["holiday"] = forecast["holidays"].values
if use_weekday:
dt_mod["weekday"] = forecast["weekly"].values
dt_mod["season"] = forecastRecurrence["yearly"].iloc[these].values
if use_monthly:
dt_mod["monthly"] = forecast["monthly"].values
dt_mod["monthly"] = forecastRecurrence["monthly"].iloc[these].values
if use_weekday:
dt_mod["weekday"] = forecastRecurrence["weekly"].iloc[these].values
if use_holiday:
dt_mod["holiday"] = forecastRecurrence["holidays"].iloc[these].values

return dt_mod

def _prepare_holidays_for_prophet(self, holidays_df: pd.DataFrame) -> pd.DataFrame:
# Assuming holidays_df has 'ds' and 'holiday' columns
prepared_holidays = holidays_df[["ds", "holiday"]].copy()
prepared_holidays["ds"] = pd.to_datetime(prepared_holidays["ds"])
return prepared_holidays
def _set_holidays(self, dt_transform: pd.DataFrame, dt_holidays: pd.DataFrame, interval_type: str) -> pd.DataFrame:
# Ensure 'ds' column is datetime
dt_transform["ds"] = pd.to_datetime(dt_transform["ds"])
dt_holidays["ds"] = pd.to_datetime(dt_holidays["ds"])

if interval_type == "day":
return dt_holidays
elif interval_type == "week":
week_start = dt_transform["ds"].dt.weekday[0]
holidays = dt_holidays.copy()
# Adjust to the start of the week
holidays["ds"] = (
holidays["ds"] - pd.to_timedelta(holidays["ds"].dt.weekday, unit="D") + pd.Timedelta(days=week_start)
)
holidays = (
holidays.groupby(["ds", "country", "year"])
.agg(holiday=("holiday", lambda x: ", ".join(x)), n=("holiday", "count"))
.reset_index()
)
return holidays
elif interval_type == "month":
if not all(dt_transform["ds"].dt.day == 1):
raise ValueError("Monthly data should have first day of month as datestamp, e.g.'2020-01-01'")
holidays = dt_holidays.copy()
holidays["ds"] = holidays["ds"].dt.to_period("M").dt.to_timestamp()
holidays = holidays.groupby(["ds", "country", "year"])["holiday"].agg(lambda x: ", ".join(x)).reset_index()
return holidays
else:
raise ValueError("Invalid interval_type. Must be 'day', 'week', or 'month'.")

def _apply_transformations(self, x: pd.Series, params: ChannelHyperparameters) -> pd.Series:
x_adstock = self._apply_adstock(x, params)
Expand Down
7 changes: 3 additions & 4 deletions python/src/robyn/visualization/feature_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,11 @@ def plot_spend_exposure(self, featurized_data: FeaturizedMMMData, channel: str)
Returns:
plt.Figure: A matplotlib Figure object containing the spend-exposure plot.
"""
if channel not in featurized_data.modNLS:
if channel not in featurized_data.modNLS["results"]:
raise ValueError(f"No spend-exposure data available for channel: {channel}")

model_data = featurized_data.modNLS[channel]
plot_data = model_data["plot"]
res = model_data["res"]
res = featurized_data.modNLS["results"][channel]
plot_data = featurized_data.modNLS["plots"][channel]

fig, ax = plt.subplots(figsize=(10, 6))

Expand Down
2 changes: 1 addition & 1 deletion python/src/tutorials/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys

# Load environment variables from .env file
load_dotenv()
load_dotenv(dotenv_path=".env.sample")
# Retrieve the ROBYN_BASE_PATH environment variable
base_path = os.getenv("ROBYN_BASE_PATH")
if not base_path:
Expand Down
Loading