Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions meridian/model/posterior_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class PosteriorMCMCSampler:

def __init__(self, meridian: "model.Meridian"):
self._meridian = meridian
self._joint_dist = None

@property
def model(self) -> "model.Meridian":
Expand Down Expand Up @@ -460,13 +461,15 @@ def joint_dist_unpinned():
return joint_dist_unpinned

def _get_joint_dist(self) -> backend.tfd.Distribution:
mmm = self.model
y = (
backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
if mmm.holdout_id is not None
else mmm.kpi_scaled
)
return self._get_joint_dist_unpinned().experimental_pin(y=y)
if self._joint_dist is None:
mmm = self.model
y = (
backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
if mmm.holdout_id is not None
else mmm.kpi_scaled
)
self._joint_dist = self._get_joint_dist_unpinned().experimental_pin(y=y)
return self._joint_dist

def __call__(
self,
Expand Down Expand Up @@ -560,9 +563,10 @@ def __call__(
traces = []
for n_chains_batch in n_chains_list:
try:
joint_dist = self._get_joint_dist()
mcmc = _xla_windowed_adaptive_nuts(
n_draws=n_burnin + n_keep,
joint_dist=self._get_joint_dist(),
joint_dist=joint_dist,
n_chains=n_chains_batch,
num_adaptation_steps=n_adapt,
current_state=current_state,
Expand Down
49 changes: 49 additions & 0 deletions meridian/model/posterior_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,6 +1665,55 @@ def test_sample_posterior_seed_int(self):
test_utils.assert_allequal(kwargs0["seed"], [123, 123])
test_utils.assert_allequal(kwargs1["seed"], [124, 124])

@parameterized.named_parameters(
dict(testcase_name="n_chains_is_list", n_chains_type=list),
dict(testcase_name="n_chains_is_int", n_chains_type=int),
)
def test_sample_posterior_joint_distribution_cached(self, n_chains_type):
n_chains = (
[self._N_CHAINS, self._N_CHAINS]
if n_chains_type == list
else self._N_CHAINS
)
self.enter_context(
mock.patch.object(
posterior_sampler,
"_xla_windowed_adaptive_nuts",
autospec=True,
return_value=collections.namedtuple(
"StatesAndTrace", ["all_states", "trace"]
)(
all_states=self.test_posterior_states_media_and_rf,
trace=self.test_trace,
),
)
)
mock_get_joint_dist_unpinned = self.enter_context(
mock.patch.object(
posterior_sampler.PosteriorMCMCSampler,
"_get_joint_dist_unpinned",
autospec=True,
return_value=mock.MagicMock(
experimental_pin=lambda y: y,
),
)
)
model_spec = spec.ModelSpec()
input_data = self.short_input_data_with_media_and_rf
meridian = model.Meridian(
input_data=input_data,
model_spec=model_spec,
)

meridian.sample_posterior(
n_chains=n_chains,
n_adapt=self._N_ADAPT,
n_burnin=self._N_BURNIN,
n_keep=self._N_KEEP,
seed=123,
)
mock_get_joint_dist_unpinned.assert_called_once()


if __name__ == "__main__":
absltest.main()