Skip to content

random_nnx_module (and potentially random_eqx_module) can't handle models with list layers #2022

@kylejcaron

Description

@kylejcaron

Bug Description

When using random_nnx_module, if the nn_module has parameters stored within a list (like the layers in an MLP), numpyro's _update_params will break because each element of the list is assigned an integer index as its name. This causes the flatten_name line to fail when it tries to join a string with an int. This should be an easy fix.

Traceback
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[174], line 45
     43 kernel = NUTS(model)
     44 mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
---> 45 mcmc.run(random.PRNGKey(0), X=X, y=y)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:708, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    706     states, last_state = _laxmap(partial_map_fn, map_args)
    707 elif self.chain_method == "parallel":
--> 708     states, last_state = pmap(partial_map_fn)(map_args)
    709 elif callable(self.chain_method):
    710     states, last_state = self.chain_method(partial_map_fn)(map_args)

    [... skipping hidden 14 frame]

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:465, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    463 # Check if _sample_fn is None, then we need to initialize the sampler.
    464 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 465     new_init_state = self.sampler.init(
    466         rng_key,
    467         self.num_warmup,
    468         init_params,
    469         model_args=args,
    470         model_kwargs=kwargs,
    471     )
    472     init_state = new_init_state if init_state is None else init_state
    473 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:751, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    746 # vectorized
    747 else:
    748     rng_key, rng_key_init_model = jnp.swapaxes(
    749         vmap(random.split)(rng_key), 0, 1
    750     )
--> 751 init_params = self._init_state(
    752     rng_key_init_model, model_args, model_kwargs, init_params
    753 )
    754 if self._potential_fn and init_params is None:
    755     raise ValueError(
    756         "Valid value of `init_params` must be provided with `potential_fn`."
    757     )

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:695, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    688 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    689     if self._model is not None:
    690         (
    691             new_init_params,
    692             potential_fn,
    693             postprocess_fn,
    694             model_trace,
--> 695         ) = initialize_model(
    696             rng_key,
    697             self._model,
    698             dynamic_args=True,
    699             init_strategy=self._init_strategy,
    700             model_args=model_args,
    701             model_kwargs=model_kwargs,
    702             forward_mode_differentiation=self._forward_mode_differentiation,
    703         )
    704         if init_params is None:
    705             init_params = new_init_params

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:688, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    678 model_kwargs = {} if model_kwargs is None else model_kwargs
    679 substituted_model = substitute(
    680     seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
    681     substitute_fn=init_strategy,
    682 )
    683 (
    684     inv_transforms,
    685     replay_model,
    686     has_enumerate_support,
    687     model_trace,
--> 688 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    690 for name, site in model_trace.items():
    691     if (
    692         site["type"] == "sample"
    693         and isinstance(site["fn"], dist.Delta)
    694         and not site["is_observed"]
    695     ):

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/infer/util.py:482, in _get_model_transforms(model, model_args, model_kwargs)
    480 def _get_model_transforms(model, model_args=(), model_kwargs=None):
    481     model_kwargs = {} if model_kwargs is None else model_kwargs
--> 482     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    483     inv_transforms = {}
    484     # model code may need to be replayed in the presence of deterministic sites

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
    183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
    184     """
    185     Run the wrapped callable and return the recorded trace.
    186 
   (...)
    189     :return: `OrderedDict` containing the execution trace.
    190     """
--> 191     self(*args, **kwargs)
    192     return self.trace

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:846, in seed.__call__(self, *args, **kwargs)
    842     cloned_seeded_fn = seed(
    843         self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
    844     )
    845     cloned_seeded_fn.stateful = True
--> 846     return cloned_seeded_fn.__call__(*args, **kwargs)
    847 return super().__call__(*args, **kwargs)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/handlers.py:847, in seed.__call__(self, *args, **kwargs)
    845     cloned_seeded_fn.stateful = True
    846     return cloned_seeded_fn.__call__(*args, **kwargs)
--> 847 return super().__call__(*args, **kwargs)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

Cell In[174], line 35, in model(X, y)
     33 def model(X, y=None):
     34     sigma= numpyro.sample("sigma", dist.HalfNormal(2.5))
---> 35     nn = random_nnx_module(
     36         "nn", 
     37         MLP(X.shape[1], 1, hidden_layers=[8,8], rngs=nnx.Rngs(0)),
     38         prior={"bias": dist.Cauchy(), "kernel":dist.Normal()}
     39     )
     40     mu = nn(X).squeeze(-1)
     41     return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:566, in random_nnx_module(name, nn_module, prior)
    563 new_params = deepcopy(params)
    565 with numpyro.handlers.scope(prefix=name):
--> 566     _update_params(params, new_params, prior)
    568 return partial(apply_fn, new_params, *other_args, **keywords)

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:238, in _update_params(params, new_params, prior, prefix)
    236     assert not isinstance(prior, dict) or flatten_name not in prior
    237     new_item = new_params[name]
--> 238     _update_params(item, new_item, prior, prefix=flatten_name)
    239 elif (not isinstance(prior, dict)) or flatten_name in prior:
    240     if isinstance(params[name], ParamShape):

File ~/repo/.venv/lib/python3.11/site-packages/numpyro/contrib/module.py:234, in _update_params(params, new_params, prior, prefix)
    230 """
    231 A helper to recursively set prior to new_params.
    232 """
    233 for name, item in params.items():
--> 234     flatten_name = ".".join([prefix, name]) if prefix else name
    235     if isinstance(item, dict):
    236         assert not isinstance(prior, dict) or flatten_name not in prior

TypeError: sequence item 1: expected str instance, int found

Steps to Reproduce

Steps to reproduce the behavior.

from flax import nnx
import jax
from jax import random
import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.contrib.module import random_nnx_module
from numpyro.infer import MCMC, NUTS, Predictive


class MLP(nnx.Module):
    """Multi-layer perceptron whose ``nnx.Linear`` layers live in a plain list.

    The layers are deliberately stored in a Python list (``self.layers``);
    that list-valued attribute is the structure this report is about.
    """

    def __init__(self, din, dout, hidden_layers, *, rngs, activation=jax.nn.relu):
        """Build one ``nnx.Linear`` per consecutive pair of layer widths.

        :param din: input feature dimension.
        :param dout: output dimension.
        :param hidden_layers: list of hidden-layer widths.
        :param rngs: ``nnx.Rngs`` forwarded to every ``nnx.Linear``.
        :param activation: callable applied after every layer except the last.
        """
        self.activation = activation
        self.layers = []

        # Full width sequence: input, hidden widths, output.
        widths = [din] + hidden_layers
        widths.append(dout)

        # Each adjacent pair (fan_in, fan_out) defines one linear layer.
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.layers.append(nnx.Linear(fan_in, fan_out, rngs=rngs))

    def __call__(self, x):
        """Apply the hidden layers with activation, then the final layer without."""
        out = x
        for hidden in self.layers[:-1]:
            out = self.activation(hidden(out))
        # The last layer is linear only: no activation on the output.
        final = self.layers[-1]
        return final(out)


# Synthetic linear-regression data: y = 1.2 + X @ beta + Gaussian noise (sd 0.1).
# `rng` and `N` were never defined in the snippet as posted; the seed and the
# sample size below are arbitrary choices that make the reproduction runnable.
rng = np.random.default_rng(0)
N = 100

X = rng.normal(size=(N, 2))
beta = np.array([0.2, 0.8])
# Plain NumPy matmul: the snippet never imported `jnp`, and NumPy suffices here.
y = rng.normal(1.2 + X @ beta, 0.1)


def model(X, y=None):
    """NumPyro regression model with an MLP wrapped by ``random_nnx_module``.

    :param X: 2-D feature array; ``X.shape[1]`` sets the MLP input width.
    :param y: observed targets, or ``None`` to sample ``y`` instead.
    :return: the value of the ``"y"`` sample site.
    """
    noise_scale = numpyro.sample("sigma", dist.HalfNormal(2.5))

    # Two hidden layers of width 8, scalar output.
    network = MLP(X.shape[1], 1, hidden_layers=[8, 8], rngs=nnx.Rngs(0))
    # Priors keyed by parameter name (presumably matched against each layer's
    # "bias" / "kernel" entries -- this lookup is where the report fails).
    weight_priors = {"bias": dist.Cauchy(), "kernel": dist.Normal()}
    bayesian_nn = random_nnx_module("nn", network, prior=weight_priors)

    # Drop the trailing singleton output axis.
    mean = bayesian_nn(X).squeeze(-1)
    return numpyro.sample("y", dist.Normal(mean, noise_scale), obs=y)

# Fit the model with NUTS: 2 chains, 1000 warmup + 1000 posterior draws each.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
# This call raises the TypeError shown in the traceback: the model is traced
# during sampler initialisation, which reaches _update_params.
mcmc.run(random.PRNGKey(0), X=X, y=y)

Expected Behavior

A clear and concise description of what you expected to happen.

A model using random_nnx_module or random_eqx_module should fit without error when the neural net module stores its layers in a list, and the supplied priors should be applied correctly to those layers.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions