2 changes: 1 addition & 1 deletion numpyro/contrib/module.py
@@ -231,7 +231,7 @@ def _update_params(params, new_params, prior, prefix=""):
A helper to recursively set prior to new_params.
"""
for name, item in params.items():
flatten_name = ".".join([prefix, name]) if prefix else name
flatten_name = ".".join([str(prefix), str(name)]) if prefix else str(name)
if isinstance(item, dict):
assert not isinstance(prior, dict) or flatten_name not in prior
new_item = new_params[name]
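The `str()` coercion matters once a module stores its sub-layers in a Python list: the flattened parameter tree then carries integer keys for the list indices, and `".".join` raises `TypeError` on non-strings. A minimal standalone repro of the failure mode (illustrative sketch, not NumPyro's actual traversal code):

```python
# Parameter tree as produced when a module holds a list of layers:
# the list positions show up as integer keys.
params = {"layers": {0: {"bias": 0.0}, 1: {"bias": 0.0}}}

def flatten_names(params, prefix=""):
    names = []
    for name, item in params.items():
        # Without str(), ".".join(["layers", 0]) raises TypeError.
        flat = ".".join([str(prefix), str(name)]) if prefix else str(name)
        if isinstance(item, dict):
            names.extend(flatten_names(item, prefix=flat))
        else:
            names.append(flat)
    return names

print(flatten_names(params))  # ['layers.0.bias', 'layers.1.bias']
```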
103 changes: 103 additions & 0 deletions test/contrib/test_module.py
@@ -480,6 +480,53 @@ def model(data, labels=None):
assert "nn/w" in samples


@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
def test_random_nnx_module_mcmc_sequence_params():
from flax import nnx

class MLP(nnx.Module):
def __init__(self, din, dout, hidden_layers, *, rngs, activation=jax.nn.relu):
self.activation = activation
self.layers = []
layer_dims = [din] + hidden_layers + [dout]
for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
self.layers.append(nnx.Linear(in_dim, out_dim, rngs=rngs))

def __call__(self, x):
for layer in self.layers[:-1]:
x = self.activation(layer(x))
return self.layers[-1](x)

N, dim = 3000, 3
data = random.normal(random.PRNGKey(0), (N, dim))
true_coefs = np.arange(1.0, dim + 1.0)
logits = np.sum(true_coefs * data, axis=-1)
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

rng_key = random.PRNGKey(0)
nn_module = MLP(
din=dim, dout=1, hidden_layers=[8, 8], rngs=nnx.Rngs(params=rng_key)
)

def prior(name, shape):
return dist.Cauchy() if name == "bias" else dist.Normal()

def model(data, labels=None):
# Use the pre-initialized module with eager initialization
nn = random_nnx_module("nn", nn_module, prior=prior)
logits = nn(data).squeeze(-1)
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=1, num_samples=1, progress_bar=False)
mcmc.run(random.PRNGKey(0), data, labels)
samples = mcmc.get_samples()

# check both layers have parameters in the samples
assert "nn/layers.0.bias" in samples
assert "nn/layers.1.bias" in samples


@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
def test_eqx_module():
import equinox as eqx
@@ -606,3 +653,59 @@ def model(data, labels=None):
samples = mcmc.get_samples()
assert "nn/bias" in samples
assert "nn/weight" in samples


@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
def test_random_eqx_module_mcmc_sequence_params():
import equinox as eqx

class MLP(eqx.Module):
layers: list

def __init__(
self,
in_size: int,
out_size: int,
hidden_layers: list[int],
key: jax.random.PRNGKey,
):
            # One key per linear layer: len(hidden_layers) + 1 layers in total.
            keys = jax.random.split(key, len(hidden_layers) + 1)

            # Create all linear layers
            self.layers = []
layer_dims = [in_size] + list(hidden_layers) + [out_size]
for i, (in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
self.layers.append(eqx.nn.Linear(in_dim, out_dim, key=keys[i]))

def __call__(self, x):
for layer in self.layers[:-1]:
x = jax.nn.relu(layer(x))
return self.layers[-1](x) # Final layer, no activation

N, dim = 3000, 3
data = random.normal(random.PRNGKey(0), (N, dim))
true_coefs = np.arange(1.0, dim + 1.0)
logits = np.sum(true_coefs * data, axis=-1)
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

rng_key = random.PRNGKey(0)
nn_module = MLP(in_size=dim, out_size=1, hidden_layers=[8, 8], key=rng_key)

def prior(name, shape):
return dist.Cauchy() if name == "bias" else dist.Normal()

def model(data, labels=None):
# Use the pre-initialized module with eager initialization
nn = random_eqx_module("nn", nn_module, prior=prior)
logits = jax.vmap(nn)(data).squeeze(-1)
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=1, num_samples=1, progress_bar=False)
mcmc.run(random.PRNGKey(0), data, labels)
samples = mcmc.get_samples()

# check both layers have parameters in the samples
assert "nn/layers[0].bias" in samples
assert "nn/layers[1].bias" in samples