diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 4c197e634..451dad9cb 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -320,6 +320,9 @@ def get_trace(): site["fn_name"] = _get_dist_name(site.pop("fn")) elif site["type"] == "deterministic": site["fn_name"] = "Deterministic" + elif site["type"] == "param": + # Remove lambda functions from param args to avoid jax.eval_shape issues + site.pop("args", None) return PytreeTrace(trace) # We use eval_shape to avoid any array computation. diff --git a/test/test_model_rendering.py b/test/test_model_rendering.py index a543add90..b100e1296 100644 --- a/test/test_model_rendering.py +++ b/test/test_model_rendering.py @@ -144,3 +144,42 @@ def model(): render_model(model, filename="graph.png") assert os.path.exists("graph.png") os.remove("graph.png") + + +def test_param_with_lambda_function(): + """ + Test that get_model_relations works when params are initialized with lambda functions. + Regression test for issue #2064. + """ + + def guide(): + numpyro.param("p", lambda _: 1.0) + + # This should not raise a TypeError about lambda functions not being valid JAX types + relations = get_model_relations(guide) + + # Verify the param is captured correctly + assert "p" in relations["param_constraint"] + assert relations["param_constraint"]["p"] == "" + assert relations["sample_sample"] == {} + assert relations["sample_param"] == {} + assert relations["sample_dist"] == {} + assert relations["observed"] == [] + + +def test_param_with_lambda_and_sample(): + """ + Test that get_model_relations works with both params (lambda) and sample sites. + """ + + def model(): + p = numpyro.param("p", lambda _: jnp.array([0.5, 0.5])) + numpyro.sample("x", dist.Categorical(p)) + + relations = get_model_relations(model) + + # Verify both param and sample are captured + assert "p" in relations["param_constraint"] + assert "x" in relations["sample_dist"] + assert relations["sample_dist"]["x"] == "CategoricalProbs" + assert "p" in relations["sample_param"]["x"]