Skip to content

Commit f0685b2

Browse files
committed
Test against Keras 3 (keras-team#1273)
1 parent 8b9cdec commit f0685b2

File tree

8 files changed

+19
-15
lines changed

8 files changed

+19
-15
lines changed

.github/workflows/actions.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77
types: [created]
88
jobs:
99
build:
10-
name: Test the code with tf.keras
10+
name: Test the code with Keras 2
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v2
@@ -29,7 +29,8 @@ jobs:
2929
${{ runner.os }}-pip-
3030
- name: Install dependencies
3131
run: |
32-
pip install -r requirements.txt --progress-bar off
32+
pip install -r requirements-common.txt --progress-bar off
33+
pip install tensorflow-text==2.14 tensorflow==2.14 keras-core
3334
pip install --no-deps -e "." --progress-bar off
3435
- name: Test with pytest
3536
run: |
@@ -38,7 +39,7 @@ jobs:
3839
run: |
3940
python pip_build.py --install && cd integration_tests && pytest .
4041
multibackend:
41-
name: Test the code with Keras Core
42+
name: Test the code with Keras 3
4243
strategy:
4344
fail-fast: false
4445
matrix:

keras_nlp/models/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def _check_for_loss_mismatch(self, loss):
7979
)
8080

8181
def compile(self, optimizer="rmsprop", loss=None, **kwargs):
82+
# Temporarily disable jit compilation on torch.
83+
if config.backend() == "torch":
84+
kwargs["jit_compile"] = False
8285
self._check_for_loss_mismatch(loss)
8386
super().compile(optimizer=optimizer, loss=loss, **kwargs)
8487

keras_nlp/tests/test_case.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def call(self, x):
138138
return self.layer(x)
139139

140140
model = TestModel(layer)
141-
model.compile(optimizer="sgd", loss="mse", jit_compile=True)
141+
# Temporarily disable jit compilation on torch backend.
142+
jit_compile = config.backend() != "torch"
143+
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
142144
model.fit(input_data, output_data, verbose=0)
143145

144146
if config.keras_3():

requirements-common.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Library deps.
2-
keras-core>=0.1.6
32
dm-tree
43
regex
54
rich
@@ -17,5 +16,3 @@ namex
1716
rouge-score
1817
sentencepiece
1918
tensorflow-datasets
20-
# Breakage fix.
21-
ml-dtypes==0.2.0

requirements-jax-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Tensorflow cpu-only version.
2-
tensorflow>=2.14.0
3-
tensorflow-text>=2.14.0
2+
tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0.
3+
tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0.
44

55
# Torch cpu-only version.
66
--extra-index-url https://download.pytorch.org/whl/cpu

requirements-tensorflow-cuda.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Tensorflow with cuda support.
2-
tensorflow[and-cuda]>=2.14.0
3-
tensorflow-text>=2.14.0
2+
--extra-index-url https://pypi.nvidia.com
3+
tf-nightly[and-cuda]==2.16.0.dev20231103 # Pin a working nightly until rc0.
4+
tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0.
45

56
# Torch cpu-only version.
67
--extra-index-url https://download.pytorch.org/whl/cpu

requirements-torch-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Tensorflow cpu-only version.
2-
tensorflow>=2.14.0
3-
tensorflow-text>=2.14.0
2+
tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0.
3+
tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0.
44

55
# Torch with cuda support.
66
--extra-index-url https://download.pytorch.org/whl/cu118

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Tensorflow.
2-
tensorflow>=2.14.0
3-
tensorflow-text>=2.14.0
2+
tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0.
3+
tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0.
44

55
# Torch.
66
--extra-index-url https://download.pytorch.org/whl/cpu

0 commit comments

Comments
 (0)