-
Notifications
You must be signed in to change notification settings - Fork 270
Closed
Labels
bugSomething isn't workingSomething isn't working
Description
Bug Description
When using `random_nnx_module`, if the `nn_module` stores parameters inside a list (like the layers of an MLP), numpyro's `_update_params` breaks because each element of the list is given its integer index as its name. This causes the `flatten_name` line to fail, because `".".join([prefix, name])` receives an int where it expects a string (`TypeError: sequence item 1: expected str instance, int found`). This should be an easy fix.
Traceback
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[174], line 45
43 kernel = NUTS(model)
44 mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
---> 45 mcmc.run(random.PRNGKey(0), X=X, y=y)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:708, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
706 states, last_state = _laxmap(partial_map_fn, map_args)
707 elif self.chain_method == "parallel":
--> 708 states, last_state = pmap(partial_map_fn)(map_args)
709 elif callable(self.chain_method):
710 states, last_state = self.chain_method(partial_map_fn)(map_args)
[... skipping hidden 14 frame]
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:465, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
463 # Check if _sample_fn is None, then we need to initialize the sampler.
464 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 465 new_init_state = self.sampler.init(
466 rng_key,
467 self.num_warmup,
468 init_params,
469 model_args=args,
470 model_kwargs=kwargs,
471 )
472 init_state = new_init_state if init_state is None else init_state
473 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:751, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
746 # vectorized
747 else:
748 rng_key, rng_key_init_model = jnp.swapaxes(
749 vmap(random.split)(rng_key), 0, 1
750 )
--> 751 init_params = self._init_state(
752 rng_key_init_model, model_args, model_kwargs, init_params
753 )
754 if self._potential_fn and init_params is None:
755 raise ValueError(
756 "Valid value of `init_params` must be provided with `potential_fn`."
757 )
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:695, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
688 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
689 if self._model is not None:
690 (
691 new_init_params,
692 potential_fn,
693 postprocess_fn,
694 model_trace,
--> 695 ) = initialize_model(
696 rng_key,
697 self._model,
698 dynamic_args=True,
699 init_strategy=self._init_strategy,
700 model_args=model_args,
701 model_kwargs=model_kwargs,
702 forward_mode_differentiation=self._forward_mode_differentiation,
703 )
704 if init_params is None:
705 init_params = new_init_params
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:688, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
678 model_kwargs = {} if model_kwargs is None else model_kwargs
679 substituted_model = substitute(
680 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
681 substitute_fn=init_strategy,
682 )
683 (
684 inv_transforms,
685 replay_model,
686 has_enumerate_support,
687 model_trace,
--> 688 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
690 for name, site in model_trace.items():
691 if (
692 site["type"] == "sample"
693 and isinstance(site["fn"], dist.Delta)
694 and not site["is_observed"]
695 ):
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:482, in _get_model_transforms(model, model_args, model_kwargs)
480 def _get_model_transforms(model, model_args=(), model_kwargs=None):
481 model_kwargs = {} if model_kwargs is None else model_kwargs
--> 482 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
483 inv_transforms = {}
484 # model code may need to be replayed in the presence of deterministic sites
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
184 """
185 Run the wrapped callable and return the recorded trace.
186
(...)
189 :return: `OrderedDict` containing the execution trace.
190 """
--> 191 self(*args, **kwargs)
192 return self.trace
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
119 return self
120 with self:
--> 121 return self.fn(*args, **kwargs)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
119 return self
120 with self:
--> 121 return self.fn(*args, **kwargs)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:846, in seed.__call__(self, *args, **kwargs)
842 cloned_seeded_fn = seed(
843 self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
844 )
845 cloned_seeded_fn.stateful = True
--> 846 return cloned_seeded_fn.__call__(*args, **kwargs)
847 return super().__call__(*args, **kwargs)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:847, in seed.__call__(self, *args, **kwargs)
845 cloned_seeded_fn.stateful = True
846 return cloned_seeded_fn.__call__(*args, **kwargs)
--> 847 return super().__call__(*args, **kwargs)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
119 return self
120 with self:
--> 121 return self.fn(*args, **kwargs)
Cell In[174], line 35, in model(X, y)
33 def model(X, y=None):
34 sigma= numpyro.sample("sigma", dist.HalfNormal(2.5))
---> 35 nn = random_nnx_module(
36 "nn",
37 MLP(X.shape[1], 1, hidden_layers=[8,8], rngs=nnx.Rngs(0)),
38 prior={"bias": dist.Cauchy(), "kernel":dist.Normal()}
39 )
40 mu = nn(X).squeeze(-1)
41 return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:566, in random_nnx_module(name, nn_module, prior)
563 new_params = deepcopy(params)
565 with numpyro.handlers.scope(prefix=name):
--> 566 _update_params(params, new_params, prior)
568 return partial(apply_fn, new_params, *other_args, **keywords)
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:238, in _update_params(params, new_params, prior, prefix)
236 assert not isinstance(prior, dict) or flatten_name not in prior
237 new_item = new_params[name]
--> 238 _update_params(item, new_item, prior, prefix=flatten_name)
239 elif (not isinstance(prior, dict)) or flatten_name in prior:
240 if isinstance(params[name], ParamShape):
File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:234, in _update_params(params, new_params, prior, prefix)
230 """
231 A helper to recursively set prior to new_params.
232 """
233 for name, item in params.items():
--> 234 flatten_name = ".".join([prefix, name]) if prefix else name
235 if isinstance(item, dict):
236 assert not isinstance(prior, dict) or flatten_name not in prior
TypeError: sequence item 1: expected str instance, int found
Steps to Reproduce
Steps to reproduce the behavior.
import jax
import numpy as np
import numpyro
from flax import nnx
from jax import random
from numpyro import distributions as dist
from numpyro.contrib.module import random_nnx_module
from numpyro.infer import MCMC, NUTS, Predictive
class MLP(nnx.Module):
    """Multi-layer perceptron that keeps its linear layers in a plain Python list.

    The activation is applied after every layer except the last one.
    """

    def __init__(self, din, dout, hidden_layers, *, rngs, activation=jax.nn.relu):
        """Build one ``nnx.Linear`` per consecutive pair of layer widths.

        :param din: input feature dimension.
        :param dout: output dimension.
        :param hidden_layers: list of hidden-layer widths.
        :param rngs: ``nnx.Rngs`` forwarded to each ``nnx.Linear``.
        :param activation: callable applied after every non-final layer.
        """
        self.activation = activation
        # The layers are deliberately stored in a list that is filled by
        # ``append``; list-stored layers are what the bug report is about.
        self.layers = []
        # Full sequence of widths: input, hidden..., output.
        widths = [din] + hidden_layers + [dout]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(nnx.Linear(fan_in, fan_out, rngs=rngs))

    def __call__(self, x):
        """Apply the hidden layers with activation, then the final layer without."""
        for hidden in self.layers[:-1]:
            x = self.activation(hidden(x))
        # Final layer, no activation.
        return self.layers[-1](x)
# NOTE: the original report used `N` and `rng` without defining them; the
# values below are placeholders so that the snippet runs on its own.
N = 100
rng = np.random.default_rng(0)

# Synthetic linear-regression data: y = 1.2 + X @ beta + Normal(0, 0.1) noise.
X = rng.normal(size=(N, 2))
beta = np.array([0.2, 0.8])
# `np.dot` is used because `jnp` was never imported in the report; X and beta
# are both NumPy arrays, so the result is the same.
y = rng.normal(1.2 + np.dot(X, beta), 0.1)
def model(X, y=None):
    """Regression model whose mean is an MLP with priors placed on its parameters.

    Samples a noise scale ``sigma``, wraps an ``MLP`` with ``random_nnx_module``,
    and registers ``y`` as a Normal sample site observed at ``y`` when given.
    """
    sigma= numpyro.sample("sigma", dist.HalfNormal(2.5))
    # This is the call that raises the TypeError shown in the traceback.
    nn = random_nnx_module(
        "nn",
        MLP(X.shape[1], 1, hidden_layers=[8,8], rngs=nnx.Rngs(0)),
        # Priors keyed by "bias" and "kernel"; the report expects these to be
        # applied to the parameters of every layer held in the list.
        prior={"bias": dist.Cauchy(), "kernel":dist.Normal()}
    )
    # Drop the trailing output dimension of size 1.
    mu = nn(X).squeeze(-1)
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
# Run NUTS with two chains. The `mcmc.run` call is where the traceback above
# starts: the error surfaces while the sampler initializes the model.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), X=X, y=y)
Expected Behavior
A clear and concise description of what you expected to happen.
A model using random_nnx_module or random_eqx_module should fit without error when the neural-net module stores its layers in a list, and the supplied priors should be passed through correctly to the parameters of those layers.
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working