Skip to content

Commit d1cab14

Browse files
committed
Fix jax/flax management
1 parent 2da19b8 commit d1cab14

File tree

4 files changed

+2
-7
lines changed

4 files changed

+2
-7
lines changed

keras/utils/jax_layer.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88
from keras.saving import serialization_lib
99
from keras.utils import shape_utils
1010
from keras.utils import tracking
11-
12-
try:
13-
import jax
14-
except ImportError:
15-
jax = None
11+
from keras.utils.module_utils import jax
1612

1713

1814
@keras_export("keras.layers.JaxLayer")

keras/utils/module_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ def __getattr__(self, name):
4040
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
4141
tensorflow_io = LazyModule("tensorflow_io")
4242
scipy = LazyModule("scipy")
43+
jax = LazyModule("jax")

requirements-tensorflow-cuda.txt

-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ torchvision>=0.16.0
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]
11-
flax
1211

1312
-r requirements-common.txt

requirements-torch-cuda.txt

-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ torchvision==0.17.1+cu121
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]
11-
flax
1211

1312
-r requirements-common.txt

0 commit comments

Comments
 (0)