From c2a08e0a464224d851c91ce0976ee64edd3b955f Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 4 Sep 2025 19:25:00 +0200 Subject: [PATCH 1/2] ensure compatibility --- numpyro/contrib/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index ab902cda2..92675e0e5 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -498,7 +498,7 @@ def apply_fn(params, *call_args, **call_kwargs): if mutable_holder: nnx.replace_by_pure_dict(mutable_state, mutable_holder["state"]) - model = nnx.merge(graph_def, params_state, mutable_state) + model = nnx.merge(graph_def, params_state, mutable_state, copy=True) model_call = model(*call_args, **call_kwargs) From 198c73f01cce3ce472c3d79c3dcee80435ae2274 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 5 Sep 2025 21:30:22 +0200 Subject: [PATCH 2/2] rm xfail --- test/contrib/test_module.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 3f851b896..0ccf909f9 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -385,9 +385,6 @@ def nnx_model_eager(x, y): @pytest.mark.parametrize( argnames="batchnorm", argvalues=[True, False], ids=["batchnorm", "no_batchnorm"] ) -@pytest.mark.xfail( - reason="Temporary marking to pass CI. Bug fixed in https://github.com/pyro-ppl/numpyro/pull/2067" -) def test_nnx_state_dropout_smoke(dropout, batchnorm): from flax import nnx