From d1cbd09c47f582d91e6ab0f214dba8352c476512 Mon Sep 17 00:00:00 2001 From: yixiaoer Date: Thu, 8 Feb 2024 22:13:31 +0800 Subject: [PATCH] Update optax-101.ipynb --- docs/optax-101.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/optax-101.ipynb b/docs/optax-101.ipynb index 0c7db36a..df218708 100644 --- a/docs/optax-101.ipynb +++ b/docs/optax-101.ipynb @@ -306,7 +306,7 @@ "}\n", "\n", "\n", - "def net(x: jnp.ndarray, params: jnp.ndarray) -\u003e jnp.ndarray:\n", + "def net(x: jnp.ndarray, params: optax.Params) -\u003e jnp.ndarray:\n", " x = jnp.dot(x, params['hidden'])\n", " x = jax.nn.relu(x)\n", " x = jnp.dot(x, params['output'])\n",