From dd401279fc8be01dea901153225e05bf38e44ac3 Mon Sep 17 00:00:00 2001 From: Mat Kelcey Date: Sun, 31 Jan 2021 20:40:04 +1100 Subject: [PATCH] jax.nn.functions has moved to jax.nn --- models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models.py b/models.py index a104f9c..323e6c9 100644 --- a/models.py +++ b/models.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from jax import random, lax, vmap from jax.nn.initializers import glorot_normal, he_normal -from jax.nn.functions import gelu +from jax.nn import gelu from functools import partial import objax from objax.variable import TrainVar, StateVar