diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py
index acc81e2cb..acd3a94a6 100644
--- a/numpyro/contrib/module.py
+++ b/numpyro/contrib/module.py
@@ -231,7 +231,7 @@ def _update_params(params, new_params, prior, prefix=""):
     A helper to recursively set prior to new_params.
     """
     for name, item in params.items():
-        flatten_name = ".".join([prefix, name]) if prefix else name
+        flatten_name = ".".join([str(prefix), str(name)]) if prefix else str(name)
         if isinstance(item, dict):
             assert not isinstance(prior, dict) or flatten_name not in prior
             new_item = new_params[name]
diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py
index 6685e9580..629f6fe33 100644
--- a/test/contrib/test_module.py
+++ b/test/contrib/test_module.py
@@ -480,6 +480,53 @@ def model(data, labels=None):
     assert "nn/w" in samples
 
 
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+def test_random_nnx_module_mcmc_sequence_params():
+    from flax import nnx
+
+    class MLP(nnx.Module):
+        def __init__(self, din, dout, hidden_layers, *, rngs, activation=jax.nn.relu):
+            self.activation = activation
+            self.layers = []
+            layer_dims = [din] + hidden_layers + [dout]
+            for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
+                self.layers.append(nnx.Linear(in_dim, out_dim, rngs=rngs))
+
+        def __call__(self, x):
+            for layer in self.layers[:-1]:
+                x = self.activation(layer(x))
+            return self.layers[-1](x)
+
+    N, dim = 3000, 3
+    data = random.normal(random.PRNGKey(0), (N, dim))
+    true_coefs = np.arange(1.0, dim + 1.0)
+    logits = np.sum(true_coefs * data, axis=-1)
+    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
+
+    rng_key = random.PRNGKey(0)
+    nn_module = MLP(
+        din=dim, dout=1, hidden_layers=[8, 8], rngs=nnx.Rngs(params=rng_key)
+    )
+
+    def prior(name, shape):
+        return dist.Cauchy() if name == "bias" else dist.Normal()
+
+    def model(data, labels=None):
+        # Use the pre-initialized module with eager initialization
+        nn = random_nnx_module("nn", nn_module, prior=prior)
+        logits = nn(data).squeeze(-1)
+        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+
+    nuts_kernel = NUTS(model)
+    mcmc = MCMC(nuts_kernel, num_warmup=1, num_samples=1, progress_bar=False)
+    mcmc.run(random.PRNGKey(0), data, labels)
+    samples = mcmc.get_samples()
+
+    # check both layers have parameters in the samples
+    assert "nn/layers.0.bias" in samples
+    assert "nn/layers.1.bias" in samples
+
+
 @pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
 def test_eqx_module():
     import equinox as eqx
@@ -606,3 +653,59 @@ def model(data, labels=None):
     samples = mcmc.get_samples()
     assert "nn/bias" in samples
     assert "nn/weight" in samples
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+def test_random_eqx_module_mcmc_sequence_params():
+    import equinox as eqx
+
+    class MLP(eqx.Module):
+        layers: list
+
+        def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            hidden_layers: list[int],
+            key: jax.random.PRNGKey,
+        ):
+            # One key per linear layer (hidden layers plus the output layer)
+            keys = jax.random.split(key, len(hidden_layers) + 1)
+
+            # Create all linear layers
+            self.layers = []
+            layer_dims = [in_size] + list(hidden_layers) + [out_size]
+            for i, (in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
+                self.layers.append(eqx.nn.Linear(in_dim, out_dim, key=keys[i]))
+
+        def __call__(self, x):
+            for layer in self.layers[:-1]:
+                x = jax.nn.relu(layer(x))
+            return self.layers[-1](x)  # Final layer, no activation
+
+    N, dim = 3000, 3
+    data = random.normal(random.PRNGKey(0), (N, dim))
+    true_coefs = np.arange(1.0, dim + 1.0)
+    logits = np.sum(true_coefs * data, axis=-1)
+    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
+
+    rng_key = random.PRNGKey(0)
+    nn_module = MLP(in_size=dim, out_size=1, hidden_layers=[8, 8], key=rng_key)
+
+    def prior(name, shape):
+        return dist.Cauchy() if name == "bias" else dist.Normal()
+
+    def model(data, labels=None):
+        # Use the pre-initialized module with eager initialization
+        nn = random_eqx_module("nn", nn_module, prior=prior)
+        logits = jax.vmap(nn)(data).squeeze(-1)
+        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+
+    nuts_kernel = NUTS(model)
+    mcmc = MCMC(nuts_kernel, num_warmup=1, num_samples=1, progress_bar=False)
+    mcmc.run(random.PRNGKey(0), data, labels)
+    samples = mcmc.get_samples()
+
+    # check both layers have parameters in the samples
+    assert "nn/layers[0].bias" in samples
+    assert "nn/layers[1].bias" in samples
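
Why the str(...) coercion in _update_params matters: when a module keeps its submodules in a plain list (as both new tests do via self.layers), the nested parameter dict contains integer keys, so the old ".".join([prefix, name]) raises a TypeError on the non-string element. Below is a minimal standalone sketch of that behaviour; flatten_names is a hypothetical helper that mirrors only the name-flattening recursion of _update_params, and the example params dict is an assumed shape, not taken from the diff.

    # Assumed nested params with integer keys, as produced by list-valued submodules.
    params = {"layers": {0: {"bias": 1.0}, 1: {"bias": 2.0}}}

    def flatten_names(params, prefix=""):
        # Hypothetical helper: same join/recursion pattern as _update_params,
        # kept only to show that str(...) makes integer keys joinable.
        names = []
        for name, item in params.items():
            flatten_name = ".".join([str(prefix), str(name)]) if prefix else str(name)
            if isinstance(item, dict):
                names.extend(flatten_names(item, prefix=flatten_name))
            else:
                names.append(flatten_name)
        return names

    assert flatten_names(params) == ["layers.0.bias", "layers.1.bias"]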