Status: Closed
Labels: bug (Something isn't working)
Bug Description
`get_model_relations` fails when a parameter's initial value is given as a lambda function; see the reproduction steps below.
`get_model_relations` already contains logic to work around this problem for `sample` and `deterministic` sites, but not for `param` sites:
```python
# Work around an issue where jax.eval_shape does not work
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
# Here we will remove `fn` and store its name in the trace.
for name, site in trace.items():
    if site["type"] == "sample":
        site["fn_name"] = _get_dist_name(site.pop("fn"))
    elif site["type"] == "deterministic":
        site["fn_name"] = "Deterministic"
```
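The loop above could, in principle, gain a matching `param` branch. The sketch below is a hypothetical helper (`drop_callable_param_inits` is not part of numpyro) that replaces callable entries in a `param` site's `args` with `None`. It assumes, as the error path `[0]['p']['args'][0]` in the traceback suggests, that the init callable is stored in `site["args"]`, and it runs on a hand-built dict standing in for a real trace, so it needs neither numpyro nor JAX.

```python
from typing import Any, Dict


def drop_callable_param_inits(trace: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: replace callable init values at `param` sites with None.

    Mirrors the existing sample/deterministic workaround: values that are not
    valid JAX types (here, Python callables) are removed before the trace is
    returned from a function traced by jax.eval_shape.
    """
    for site in trace.values():
        if site["type"] == "param":
            # Keep argument positions stable; JAX treats None as an empty
            # pytree node, so it does not become an output leaf.
            site["args"] = tuple(None if callable(a) else a for a in site["args"])
    return trace


# Hand-built stand-in for a numpyro trace (assumption: only the keys used here).
trace = {
    "p": {"type": "param", "name": "p", "args": (lambda _: 1.0,), "value": 1.0},
    "x": {"type": "sample", "name": "x", "args": (), "value": 0.5},
}
drop_callable_param_inits(trace)
print(trace["p"]["args"])  # (None,)
```

Using `None` rather than deleting the argument keeps positional indices stable; whether other code in `get_model_relations` reads `param` args is not checked by this sketch.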
Steps to Reproduce
```python
import numpyro
from numpyro.infer.inspect import get_model_relations


def guide():
    # Initial value given as a callable that receives a PRNG key
    numpyro.param('p', lambda _: 1.)


get_model_relations(guide)
```
Error:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[14], line 1
----> 1 get_model_relations(g)
File ~/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:326, in get_model_relations(model, model_args, model_kwargs)
323 return PytreeTrace(trace)
325 # We use eval_shape to avoid any array computation.
--> 326 trace = jax.eval_shape(get_trace).trace
327 obs_sites = [
328 name
329 for name, site in trace.items()
330 if site["type"] == "sample" and site["is_observed"]
331 ]
332 sample_dist = {
333 name: site["fn_name"]
334 for name, site in trace.items()
335 if site["type"] in ["sample", "deterministic"]
336 }
[... skipping hidden 11 frame]
File ~/.venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py:2316, in _check_returned_jaxtypes(dbg, out_tracers)
2314 else:
2315 extra = ''
-> 2316 raise TypeError(
2317 f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
2318 f"value of type {type(x)}{extra}, which is not a valid JAX type") from None
TypeError: function get_trace at /home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:307 traced for jit returned a value of type <class 'function'> at output component [0]['p']['args'][0], which is not a valid JAX type
```
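To see why the lambda trips `jax.eval_shape`: every leaf of the traced function's output must be a valid JAX type, and a Python function is not a pytree container, so it ends up as a leaf. The minimal sketch below reproduces the `TypeError` without numpyro, using plain dicts as stand-ins for the trace (an assumption; numpyro actually returns a `PytreeTrace`), and shows that the same structure with `None` in place of the callable traces cleanly, since JAX treats `None` as an empty subtree.

```python
import jax


def trace_with_callable():
    # Mimics the failing trace: a param site whose args hold the init lambda.
    return {"p": {"args": (lambda key: 1.0,), "value": 1.0}}


def trace_without_callable():
    # Same structure, with the callable replaced by None.
    return {"p": {"args": (None,), "value": 1.0}}


try:
    jax.eval_shape(trace_with_callable)
    raised = False
except TypeError:
    # Raised because the lambda is an output leaf that is not a JAX type.
    raised = True

shapes = jax.eval_shape(trace_without_callable)
print(raised)  # True
print(shapes["p"]["args"])  # (None,)
```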
Expected Behavior
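`get_model_relations(guide)` should return the model relations without raising, with `param` sites whose initial value is a callable handled the way `sample` and `deterministic` sites already are.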