From 683ddfc19f825c5aa8d1598302346c8e1a459e08 Mon Sep 17 00:00:00 2001 From: Jeremy Zucker Date: Wed, 22 Nov 2023 12:50:58 -0500 Subject: [PATCH] Update effect_handlers.ipynb ```python print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5))) ``` prevents the following error message: ```python --------------------------------------------------------------------------- ValueError Traceback (most recent call last) File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter) 195 try: --> 196 log_p = site["fn"].log_prob( 197 site["value"], *site["args"], **site["kwargs"] 198 ) 199 except ValueError as e: File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value) 78 if self._validate_args: ---> 79 self._validate_sample(value) 80 # compute the variance File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value) 270 if not isinstance(value, torch.Tensor): --> 271 raise ValueError('The value argument to log_prob must be a Tensor') 273 event_dim_start = len(value.size()) - len(self._event_shape) ValueError: The value argument to log_prob must be a Tensor The above exception was the direct cause of the following exception: ValueError Traceback (most recent call last) Cell In[5], line 9 6 return _log_joint 8 scale_log_joint = make_log_joint(scale) ----> 9 print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)) Cell In[5], line 5, in make_log_joint.._log_joint(cond_data, *args, **kwargs) 3 conditioned_model = poutine.condition(model, data=cond_data) 4 trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs) ----> 5 return trace.log_prob_sum() File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:202, in Trace.log_prob_sum(self, site_filter) 200 _, exc_value, traceback = sys.exc_info() 201 shapes = self.format_shapes(last_site=site["name"]) --> 202 raise ValueError( 203 "Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format( 204 name, exc_value, shapes 205 ) 206 ).with_traceback(traceback) from e 207 log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum() 208 site["log_prob_sum"] = log_p File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter) 194 else: 195 try: --> 196 log_p = site["fn"].log_prob( 197 site["value"], *site["args"], **site["kwargs"] 198 ) 199 except ValueError as e: 200 _, exc_value, traceback = sys.exc_info() File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value) 77 def log_prob(self, value): 78 if self._validate_args: ---> 79 self._validate_sample(value) 80 # compute the variance 81 var = (self.scale ** 2) File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value) 257 """ 258 Argument validation for distribution methods such as `log_prob`, 259 `cdf` and `icdf`. The rightmost dimensions of a value to be (...) 268 distribution's batch and event shapes. 269 """ 270 if not isinstance(value, torch.Tensor): --> 271 raise ValueError('The value argument to log_prob must be a Tensor') 273 event_dim_start = len(value.size()) - len(self._event_shape) 274 if value.size()[event_dim_start:] != self._event_shape: ValueError: Error while computing log_prob_sum at site 'weight': The value argument to log_prob must be a Tensor Trace Shapes: Param Sites: Sample Sites: weight dist | value | ``` --- tutorial/source/effect_handlers.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/source/effect_handlers.ipynb b/tutorial/source/effect_handlers.ipynb index f680516b1d..e81a82ccee 100644 --- a/tutorial/source/effect_handlers.ipynb +++ b/tutorial/source/effect_handlers.ipynb @@ -99,7 +99,7 @@ " return _log_joint\n", "\n", "scale_log_joint = make_log_joint(scale)\n", - "print(scale_log_joint({\"measurement\": 9.5, \"weight\": 8.23}, 8.5))" + "print(scale_log_joint({\"measurement\": torch.tensor(9.5), \"weight\": torch.tensor(8.23)}, torch.tensor(8.5)))" ] }, {