Skip to content

Commit 64ee9bf

Browse files
committed
Update Torch version everywhere and TF version on GPU.
Replacement for keras-team#21704
1 parent 22524e1 commit 64ee9bf

File tree

6 files changed

+15
-43
lines changed

6 files changed

+15
-43
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,32 +75,6 @@ jobs:
7575
files: apps-coverage.xml
7676
token: ${{ secrets.CODECOV_TOKEN }}
7777
fail_ci_if_error: false
78-
- name: Test integrations
79-
if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }}
80-
run: |
81-
python integration_tests/import_test.py
82-
python integration_tests/numerical_test.py
83-
- name: Test JAX-specific integrations
84-
if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }}
85-
run: |
86-
python integration_tests/jax_custom_fit_test.py
87-
- name: Test basic flow with NNX
88-
if: ${{ matrix.nnx_enabled == true }}
89-
env:
90-
KERAS_NNX_ENABLED: true
91-
run: |
92-
python integration_tests/import_test.py
93-
python integration_tests/basic_full_flow.py
94-
- name: Test TF-specific integrations
95-
if: ${{ matrix.backend == 'tensorflow'}}
96-
run: |
97-
python integration_tests/tf_distribute_training_test.py
98-
python integration_tests/tf_custom_fit_test.py
99-
- name: Test Torch-specific integrations
100-
if: ${{ matrix.backend == 'torch'}}
101-
run: |
102-
pytest integration_tests/torch_workflow_test.py
103-
python integration_tests/torch_custom_fit_test.py
10478
- name: Test with pytest
10579
if: ${{ matrix.nnx_enabled == false }}
10680
run: |
@@ -110,7 +84,7 @@ jobs:
11084
else
11185
IGNORE_ARGS=""
11286
fi
113-
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
87+
pytest keras/src/export/saved_model_test.py --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
11488
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
11589
- name: Codecov keras
11690
if: ${{ matrix.nnx_enabled == false }}

keras/src/export/onnx_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
7777
"backends."
7878
),
7979
)
80-
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
8180
@pytest.mark.skipif(
82-
testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI"
81+
testing.jax_uses_gpu()
82+
or testing.tensorflow_uses_gpu
83+
or testing.torch_uses_gpu(),
84+
reason="Fails on GPU",
8385
)
8486
class ExportONNXTest(testing.TestCase):
8587
@parameterized.named_parameters(

requirements-jax-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax with cuda support.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

requirements-tensorflow-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow with cuda support.
2-
tensorflow[and-cuda]~=2.18.1
2+
tensorflow[and-cuda]~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]

requirements-torch-cuda.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch with cuda support.
66
# - torch is pinned to a version that is compatible with torch-xla.
77
--extra-index-url https://download.pytorch.org/whl/cu121
8-
torch==2.6.0
9-
torch-xla==2.6.0;sys_platform != 'darwin'
8+
torch==2.8.0
9+
torch-xla==2.8.1;sys_platform != 'darwin'
1010

1111
# Jax cpu-only version (needed for testing).
1212
jax[cpu]

requirements.txt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
# Tensorflow.
22
tensorflow-cpu~=2.18.1;sys_platform != 'darwin'
33
tensorflow~=2.18.1;sys_platform == 'darwin'
4-
tf_keras
54
tf2onnx
65

76
# Torch.
87
--extra-index-url https://download.pytorch.org/whl/cpu
9-
torch==2.6.0;sys_platform != 'darwin'
10-
torch==2.6.0;sys_platform == 'darwin'
11-
torch-xla==2.6.0;sys_platform != 'darwin'
8+
torch==2.7.1
9+
torch-xla==2.7.0;sys_platform != 'darwin'
1210

1311
# Jax.
14-
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
15-
# Note that we test against the latest JAX on GPU.
1612
jax[cpu]==0.5.0
17-
flax
13+
1814
# Common deps.
1915
-r requirements-common.txt

0 commit comments

Comments
 (0)