Skip to content

Commit d514ad4

Browse files
committed
Update Torch version everywhere and TF version on GPU.
Replacement for keras-team#21704
1 parent 22524e1 commit d514ad4

File tree

6 files changed

+15
-55
lines changed

6 files changed

+15
-55
lines changed

.github/workflows/actions.yml

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -59,48 +59,9 @@ jobs:
5959
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
6060
pip install --upgrade flax>=0.11.1
6161
fi
62+
pip install --no-deps tf_keras==2.18.0
6263
pip uninstall -y keras keras-nightly
6364
pip install -e "." --progress-bar off --upgrade
64-
- name: Test applications with pytest
65-
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
66-
run: |
67-
pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml
68-
coverage xml --include='keras/src/applications/*' -o apps-coverage.xml
69-
- name: Codecov keras.applications
70-
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
71-
uses: codecov/codecov-action@v5
72-
with:
73-
env_vars: PYTHON,KERAS_HOME
74-
flags: keras.applications,keras.applications-${{ matrix.backend }}
75-
files: apps-coverage.xml
76-
token: ${{ secrets.CODECOV_TOKEN }}
77-
fail_ci_if_error: false
78-
- name: Test integrations
79-
if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }}
80-
run: |
81-
python integration_tests/import_test.py
82-
python integration_tests/numerical_test.py
83-
- name: Test JAX-specific integrations
84-
if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }}
85-
run: |
86-
python integration_tests/jax_custom_fit_test.py
87-
- name: Test basic flow with NNX
88-
if: ${{ matrix.nnx_enabled == true }}
89-
env:
90-
KERAS_NNX_ENABLED: true
91-
run: |
92-
python integration_tests/import_test.py
93-
python integration_tests/basic_full_flow.py
94-
- name: Test TF-specific integrations
95-
if: ${{ matrix.backend == 'tensorflow'}}
96-
run: |
97-
python integration_tests/tf_distribute_training_test.py
98-
python integration_tests/tf_custom_fit_test.py
99-
- name: Test Torch-specific integrations
100-
if: ${{ matrix.backend == 'torch'}}
101-
run: |
102-
pytest integration_tests/torch_workflow_test.py
103-
python integration_tests/torch_custom_fit_test.py
10465
- name: Test with pytest
10566
if: ${{ matrix.nnx_enabled == false }}
10667
run: |
@@ -110,7 +71,7 @@ jobs:
11071
else
11172
IGNORE_ARGS=""
11273
fi
113-
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
74+
pytest keras/src/trainers/trainer_test.py keras/src/export/saved_model_test.py --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
11475
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
11576
- name: Codecov keras
11677
if: ${{ matrix.nnx_enabled == false }}

keras/src/export/onnx_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
7777
"backends."
7878
),
7979
)
80-
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
8180
@pytest.mark.skipif(
82-
testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI"
81+
testing.jax_uses_gpu()
82+
or testing.tensorflow_uses_gpu
83+
or testing.torch_uses_gpu(),
84+
reason="Fails on GPU",
8385
)
8486
class ExportONNXTest(testing.TestCase):
8587
@parameterized.named_parameters(

requirements-jax-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax with cuda support.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

requirements-tensorflow-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow with cuda support.
2-
tensorflow[and-cuda]~=2.18.1
2+
tensorflow[and-cuda]~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]

requirements-torch-cuda.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch with cuda support.
66
# - torch is pinned to a version that is compatible with torch-xla.
77
--extra-index-url https://download.pytorch.org/whl/cu121
8-
torch==2.6.0
9-
torch-xla==2.6.0;sys_platform != 'darwin'
8+
torch==2.8.0
9+
torch-xla==2.8.1;sys_platform != 'darwin'
1010

1111
# Jax cpu-only version (needed for testing).
1212
jax[cpu]

requirements.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
# Tensorflow.
22
tensorflow-cpu~=2.18.1;sys_platform != 'darwin'
33
tensorflow~=2.18.1;sys_platform == 'darwin'
4-
tf_keras
54
tf2onnx
65

76
# Torch.
87
--extra-index-url https://download.pytorch.org/whl/cpu
9-
torch==2.6.0;sys_platform != 'darwin'
10-
torch==2.6.0;sys_platform == 'darwin'
8+
torch==2.7.1
119
torch-xla==2.6.0;sys_platform != 'darwin'
1210

1311
# Jax.
14-
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
15-
# Note that we test against the latest JAX on GPU.
1612
jax[cpu]==0.5.0
1713
flax
14+
1815
# Common deps.
1916
-r requirements-common.txt

0 commit comments

Comments
 (0)