diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 7d4a7d4ed0..4d6a5e7dfb 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -7,7 +7,7 @@ on: types: [created] jobs: build: - name: Test the code with tf.keras + name: Test the code with Keras 2 runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -29,7 +29,8 @@ jobs: ${{ runner.os }}-pip- - name: Install dependencies run: | - pip install -r requirements.txt --progress-bar off + pip install -r requirements-common.txt --progress-bar off + pip install tensorflow-text==2.14 tensorflow==2.14 keras-core pip install --no-deps -e "." --progress-bar off - name: Test with pytest run: | @@ -38,7 +39,7 @@ jobs: run: | python pip_build.py --install && cd integration_tests && pytest . multibackend: - name: Test the code with Keras Core + name: Test the code with Keras 3 strategy: fail-fast: false matrix: diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index d4c6180405..2c1d0f40f1 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -79,6 +79,9 @@ def _check_for_loss_mismatch(self, loss): ) def compile(self, optimizer="rmsprop", loss=None, **kwargs): + # Temporarily disable jit compilation on torch. + if config.backend() == "torch": + kwargs["jit_compile"] = False self._check_for_loss_mismatch(loss) super().compile(optimizer=optimizer, loss=loss, **kwargs) diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 6fe72ed497..fefa7a3a0f 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -143,7 +143,9 @@ def call(self, x): return self.layer(x) model = TestModel(layer) - model.compile(optimizer="sgd", loss="mse", jit_compile=True) + # Temporarily disable jit compilation on torch backend. + jit_compile = config.backend() != "torch" + model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile) model.fit(input_data, output_data, verbose=0) if config.multi_backend(): diff --git a/requirements-common.txt b/requirements-common.txt index 44661e315a..5c9710de4b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,5 +1,4 @@ # Library deps. -keras-core>=0.1.6 dm-tree regex rich @@ -17,5 +16,3 @@ namex rouge-score sentencepiece tensorflow-datasets -# Breakage fix. -ml-dtypes==0.2.0 diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index bb115b14f8..f424af9cb7 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow>=2.14.0 -tensorflow-text>=2.14.0 +tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0. # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 4b2cf167ea..98c2746474 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,7 @@ # Tensorflow with cuda support. -tensorflow[and-cuda]>=2.14.0 -tensorflow-text>=2.14.0 +--extra-index-url https://pypi.nvidia.com +tf-nightly[and-cuda]==2.16.0.dev20231103 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0. # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 14e94dd862..eb147a3d38 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow>=2.14.0 -tensorflow-text>=2.14.0 +tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0. # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu118 diff --git a/requirements.txt b/requirements.txt index aa289402fd..a17dc717a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tensorflow>=2.14.0 -tensorflow-text>=2.14.0 +tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20231103 # Pin a working nightly until rc0. # Torch. --extra-index-url https://download.pytorch.org/whl/cpu