Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/src/robyn/modeling/ridge/ridge_data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional, Tuple
import pandas as pd
import numpy as np
from scipy.signal import lfilter


class RidgeDataBuilder:
Expand Down Expand Up @@ -166,12 +167,11 @@ def _hyper_collector(
return hyper_collect

def _geometric_adstock(self, x: pd.Series, theta: float) -> pd.Series:
# print(f"Before adstock: {x.head()}")
y = x.copy()
for i in range(1, len(x)):
y.iloc[i] += theta * y.iloc[i - 1]
# print(f"After adstock: {y.head()}")
return y

x_array = x.values
# Use lfilter to efficiently compute the geometric transformation
y = lfilter([1], [1, -theta], x_array)
return pd.Series(y, index=x.index)

def _hill_transformation(
self, x: pd.Series, alpha: float, gamma: float
Expand Down
31 changes: 31 additions & 0 deletions python/tests/unit/modeling/ridge/test_adstock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import numpy as np
import pandas as pd
from scipy.signal import lfilter
from robyn.modeling.ridge.ridge_data_builder import RidgeDataBuilder


def original_geometric_adstock(x: pd.Series, theta: float) -> pd.Series:
y = x.copy()
for i in range(1, len(x)):
y.iloc[i] += theta * y.iloc[i - 1]
return y


@pytest.mark.parametrize("theta", [0, 0.5, 0.8, 1])
def test_geometric_adstock(theta):
x = pd.Series(np.random.rand(10_000)) # Random test data

# Instantiate the RidgeDataBuilder object (without requiring real data)
dummy_data = None
ridge_builder = RidgeDataBuilder(dummy_data, dummy_data)

# Call the actual function from the RidgeDataBuilder instance
optimized = ridge_builder._geometric_adstock(x, theta)

# Compute the expected output using the original function
original = original_geometric_adstock(x, theta)

assert np.allclose(
original, optimized, atol=1e-6
), f"Mismatch found for theta={theta}"