Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numpyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def get_trace():
site["fn_name"] = _get_dist_name(site.pop("fn"))
elif site["type"] == "deterministic":
site["fn_name"] = "Deterministic"
elif site["type"] == "param":
# Remove lambda functions from param args to avoid jax.eval_shape issues
site.pop("args", None)
return PytreeTrace(trace)

# We use eval_shape to avoid any array computation.
Expand Down
39 changes: 39 additions & 0 deletions test/test_model_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,42 @@ def model():
render_model(model, filename="graph.png")
assert os.path.exists("graph.png")
os.remove("graph.png")


def test_param_with_lambda_function():
"""
Test that get_model_relations works when params are initialized with lambda functions.
Regression test for issue #2064.
"""

def guide():
numpyro.param("p", lambda _: 1.0)

# This should not raise a TypeError about lambda functions not being valid JAX types
relations = get_model_relations(guide)

# Verify the param is captured correctly
assert "p" in relations["param_constraint"]
assert relations["param_constraint"]["p"] == ""
assert relations["sample_sample"] == {}
assert relations["sample_param"] == {}
assert relations["sample_dist"] == {}
assert relations["observed"] == []


def test_param_with_lambda_and_sample():
"""
Test that get_model_relations works with both params (lambda) and sample sites.
"""

def model():
p = numpyro.param("p", lambda _: jnp.array([0.5, 0.5]))
numpyro.sample("x", dist.Categorical(p))

relations = get_model_relations(model)

# Verify both param and sample are captured
assert "p" in relations["param_constraint"]
assert "x" in relations["sample_dist"]
assert relations["sample_dist"]["x"] == "CategoricalProbs"
assert "p" in relations["sample_param"]["x"]