
get_model_relations fails when parameters are set with lambda functions #2064

@mochar

Description


Bug Description

`get_model_relations` fails when a parameter's initial value is given as a lambda function; see the reproduction steps below.

`get_model_relations` already contains a workaround for this problem for `sample` and `deterministic` sites, but not for `param` sites:

        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for name, site in trace.items():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
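The same idea could plausibly be extended to `param` sites. Below is a minimal sketch, assuming (as the `[0]['p']['args'][0]` path in the traceback suggests) that the offending value is the callable stored in the param site's `args`. `strip_param_init_fns` is a hypothetical helper, not part of NumPyro, and the plain-dict trace only imitates NumPyro's trace layout.

```python
from typing import Any, Dict


def strip_param_init_fns(trace: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: drop callable init values from param sites.

    Assumes each site is a plain dict with "type" and "args" keys, and that a
    callable argument is the init function passed to numpyro.param.
    """
    for site in trace.values():
        if site["type"] != "param":
            continue
        args = tuple(site.get("args", ()))
        # Replace callables with None, which JAX treats as an empty pytree and
        # therefore a valid output for jax.eval_shape; other args are kept.
        site["args"] = tuple(None if callable(a) else a for a in args)
    return trace


# Usage on a toy trace that mirrors the failing guide from this issue
trace = {
    "p": {"type": "param", "args": (lambda _: 1.0,)},
    "x": {"type": "sample", "args": ()},
}
strip_param_init_fns(trace)
print(trace["p"]["args"])  # (None,)
```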

Steps to Reproduce

import numpyro
from numpyro.infer.inspect import get_model_relations

def guide():
    # The initial value is a callable (it receives a PRNG key), not an array
    numpyro.param('p', lambda _: 1.)

get_model_relations(guide)  # raises TypeError

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[14], line 1
----> 1 get_model_relations(g)

File ~/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:326, in get_model_relations(model, model_args, model_kwargs)
    323     return PytreeTrace(trace)
    325 # We use eval_shape to avoid any array computation.
--> 326 trace = jax.eval_shape(get_trace).trace
    327 obs_sites = [
    328     name
    329     for name, site in trace.items()
    330     if site["type"] == "sample" and site["is_observed"]
    331 ]
    332 sample_dist = {
    333     name: site["fn_name"]
    334     for name, site in trace.items()
    335     if site["type"] in ["sample", "deterministic"]
    336 }

    [... skipping hidden 11 frame]

File ~/.venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py:2316, in _check_returned_jaxtypes(dbg, out_tracers)
   2314 else:
   2315   extra = ''
-> 2316 raise TypeError(
   2317 f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
   2318 f"value of type {type(x)}{extra}, which is not a valid JAX type") from None

TypeError: function get_trace at /home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:307 traced for jit returned a value of type <class 'function'> at output component [0]['p']['args'][0], which is not a valid JAX type
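The error originates in JAX rather than NumPyro: `jax.eval_shape` can only return pytrees whose leaves are valid JAX types, and a Python function is not one. Below is a minimal NumPyro-free sketch of that behaviour, assuming a recent JAX release. The nested dict only imitates the `[0]['p']['args'][0]` path in the error, and the function names are illustrative.

```python
import jax
import jax.numpy as jnp


def returns_array_leaf():
    # An array leaf is a valid JAX output, so eval_shape can abstract it
    return {"p": {"args": (jnp.ones(()),)}}


def returns_callable_leaf():
    # A Python function as a pytree leaf is not a valid JAX type
    return {"p": {"args": (lambda _: 1.0,)}}


shape_ok = jax.eval_shape(returns_array_leaf)
print(shape_ok["p"]["args"][0].shape)  # ()

try:
    jax.eval_shape(returns_callable_leaf)
    failed = False
except TypeError as err:
    failed = True
    print(err)
```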

Expected Behavior

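`get_model_relations` should handle `param` sites whose initial value is a callable, just as it handles `sample` and `deterministic` sites, and return the model relations instead of raising a `TypeError`.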

Metadata

Assignees: no one assigned
Labels: bug (Something isn't working)
Type: no type
Projects: no projects
Milestone: no milestone
Relationships: none yet
Development: no branches or pull requests