From 5b0a79665c0fcf0d7543254d0c41667d28ece7b6 Mon Sep 17 00:00:00 2001 From: ferres Date: Tue, 23 Jul 2024 16:01:15 +0000 Subject: [PATCH] Revert "Evaluate the rv.shape directly" This reverts commit 0ca9f68b91921ab014d1ab2f96c695e0ed669a62. --- pymc_experimental/model/transforms/autoreparam.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/model/transforms/autoreparam.py b/pymc_experimental/model/transforms/autoreparam.py index 4c88788ef..75b9966c0 100644 --- a/pymc_experimental/model/transforms/autoreparam.py +++ b/pymc_experimental/model/transforms/autoreparam.py @@ -176,8 +176,12 @@ def vip_reparam_node( ) -> Tuple[ModelDeterministic, ModelNamed]: if not isinstance(node.op, RandomVariable | SymbolicRandomVariable): raise TypeError("Op should be RandomVariable type") - rv = node.default_output() - rv_shape = rv.shape.eval(mode="FAST_COMPILE") + _, size, *_ = node.inputs + eval_size = size.eval(mode="FAST_COMPILE") + if eval_size is not None: + rv_shape = tuple(eval_size) + else: + rv_shape = () lam_name = f"{name}::lam_logit__" _log.debug(f"Creating {lam_name} with shape of {rv_shape}") logit_lam_ = pytensor.shared(