- 
                Notifications
    You must be signed in to change notification settings 
- Fork 271
Closed
Labels
bug (Something isn't working)
Description
Bug Description
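When replicating the "NNX and NumPyro Integration" example with Equinox as a drop-in replacement for NNX, the model can be fit and behaves as expected for inference. However, it cannot be visualized with render_model: a TypeError is raised from get_model_relations, where jax.eval_shape traces the model and one of the returned trace values is a function rather than a valid JAX type.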
Steps to Reproduce
import jax
import jax.numpy as jnp
from jax import random
import jax.tree_util as jtu
import equinox as eqx
import numpyro
from numpyro.contrib.module import eqx_module, random_eqx_module
import numpyro.distributions as dist
# Training inputs: n points evenly spaced on [1, pi], reshaped into a column
# of shape (n, 1) so each row can be fed to the networks one at a time.
n = 32 * 10
rng_key = random.PRNGKey(seed=42)
rng_key, rng_subkey = random.split(rng_key)
x = jnp.linspace(1, jnp.pi, n)
x_train = jnp.expand_dims(x, axis=-1)
class LocMLP(eqx.Module):
    """Three-layer multi-layer perceptron for the mean.

    Layer sizes are din -> dmid -> dmid -> dout, with a sigmoid after each of
    the first two linear layers and no activation on the output layer.
    """

    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    linear3: eqx.nn.Linear

    def __init__(self, din: int, dmid: int, dout: int, *, key):
        # Derive one independent PRNG key per layer from the caller's key.
        k_in, k_hidden, k_out = random.split(key, 3)
        self.linear1 = eqx.nn.Linear(din, dmid, key=k_in)
        self.linear2 = eqx.nn.Linear(dmid, dmid, key=k_hidden)
        self.linear3 = eqx.nn.Linear(dmid, dout, key=k_out)

    def __call__(self, x):
        hidden = jax.nn.sigmoid(self.linear1(x))
        hidden = jax.nn.sigmoid(self.linear2(hidden))
        return self.linear3(hidden)
class ScaleMLP(eqx.Module):
    """Single-layer network for the standard deviation.

    A 1 -> 1 linear layer followed by softplus, which maps the output to
    positive values so it can be used as a scale.
    """

    linear: eqx.nn.Linear

    def __init__(self, *, key) -> None:
        self.linear = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, x):
        return jax.nn.softplus(self.linear(x))
# Build the two Equinox networks, each from its own PRNG key.
rng_key, mu_key, sigma_key = random.split(rng_key, 3)
mu_nn_module = LocMLP(din=1, dmid=8, dout=1, key=mu_key)
sigma_nn_module = ScaleMLP(key=sigma_key)

# Inspect the parameter paths of the scale network. These strings are the keys
# used in the prior dict passed to random_eqx_module in `model`.
# (Notebook output was: ['linear.weight', 'linear.bias'])
[jtu.keystr(path)[1:] for path, _ in jtu.tree_leaves_with_path(sigma_nn_module)]
def model(x):
    """NumPyro model whose mean and scale are both produced by Equinox networks.

    The mean network is registered with ``eqx_module``. The scale network is
    registered with ``random_eqx_module``, using the priors below keyed by the
    parameter paths of ``sigma_nn_module``. Because the scale is computed from
    ``x``, the observation noise varies with the input.
    """
    loc_net = eqx_module("mu_nn", mu_nn_module)
    scale_priors = {
        "linear.weight": dist.HalfNormal(scale=1),
        "linear.bias": dist.Normal(loc=0, scale=1),
    }
    scale_net = random_eqx_module("sigma_nn", sigma_nn_module, prior=scale_priors)

    # Apply each network row-by-row over the batch, then drop singleton dims.
    loc_out = jax.vmap(loc_net)(x)
    mu = numpyro.deterministic("mu", loc_out.squeeze())
    scale_out = jax.vmap(scale_net)(x)
    sigma = numpyro.deterministic("sigma", scale_out.squeeze())

    num_obs = x.shape[0]
    with numpyro.plate("data", num_obs):
        numpyro.sample("likelihood", dist.Normal(loc=mu, scale=sigma))
# Everything above works as expected (the model can be fit and used for
# inference). Rendering the model graph below raises the TypeError shown next.
numpyro.render_model(
    model,
    model_args=(x_train,),
    render_params=True,
    render_distributions=True,
)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[15], line 1
----> 1 numpyro.render_model(
      2     model=model,
      3     model_args=(x_train,),
      4     render_distributions=True,
      5     render_params=True,
      6 )
File numpyro\infer\inspect.py:626, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params)
    603 def render_model(
    604     model,
    605     model_args=None,
   (...)    609     render_params=False,
    610 ):
    611     """
    612     Wrap all functions needed to automatically render a model.
    613 
   (...)    624     :param bool render_params: Whether to show params in the plot.
    625     """
--> 626     relations = get_model_relations(
    627         model,
    628         model_args=model_args,
    629         model_kwargs=model_kwargs,
    630     )
    631     graph_spec = generate_graph_specification(relations, render_params=render_params)
    632     graph = render_graph(graph_spec, render_distributions=render_distributions)
File numpyro\infer\inspect.py:326, in get_model_relations(model, model_args, model_kwargs)
    323     return PytreeTrace(trace)
    325 # We use eval_shape to avoid any array computation.
--> 326 trace = jax.eval_shape(get_trace).trace
    327 obs_sites = [
    328     name
    329     for name, site in trace.items()
    330     if site["type"] == "sample" and site["is_observed"]
    331 ]
    332 sample_dist = {
    333     name: site["fn_name"]
    334     for name, site in trace.items()
    335     if site["type"] in ["sample", "deterministic"]
    336 }
    [... skipping hidden 11 frame]
File jax\_src\interpreters\partial_eval.py:2407, in _check_returned_jaxtypes(dbg, out_tracers)
   2405 else:
   2406   extra = ''
-> 2407 raise TypeError(
   2408 f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a "
   2409 f"value of type {type(x)}{extra}, which is not a valid JAX type") from None
TypeError: function get_trace at numpyro\infer\inspect.py:307 traced for jit returned a value of type <class 'function'>, which is not a valid JAX type
Expected Behavior
The rendered graph should be identical to the one in the NNX example. The only difference is that the sigma_nn weight parameter is named 'linear.weight' instead of 'linear.kernel'.
Metadata
Metadata
Assignees
Labels
bug (Something isn't working)