Skip to content

Commit 2337b5b

Browse files
Update: PyTorch v2.6.0 (#1093)
- Add torch 2.6.0 to CI - Remove torch 2.2.2 - Update torch install instructions, as they no longer provide conda packages - Add test for new default of weights_only - Update pickle file test artifact (explained in #1092) - Update some comments - Conditionally install triton 3.1 for torch < 2.6
1 parent be93b77 commit 2337b5b

File tree

6 files changed

+42
-18
lines changed

6 files changed

+42
-18
lines changed

Diff for: .github/workflows/testing.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
fail-fast: false # don't cancel all jobs when one fails
2222
matrix:
2323
python_version: ['3.9', '3.10', '3.11', '3.12']
24-
torch_version: ['2.2.2+cpu', '2.3.1+cpu', '2.4.1+cpu', '2.5.1+cpu']
24+
torch_version: ['2.3.1+cpu', '2.4.1+cpu', '2.5.1+cpu', '2.6.0+cpu']
2525
os: [ubuntu-latest]
2626

2727
steps:
@@ -32,13 +32,18 @@ jobs:
3232
python-version: ${{ matrix.python_version }}
3333
- name: Install dependencies
3434
# TODO remove numpy version constraint once we no longer support PyTorch < 2.3
35+
# TODO remove triton 3.1 install if torch < 2.6 no longer supported, see #1093
3536
run: |
3637
python -m pip install --upgrade pip
3738
python -m pip install -r requirements-dev.txt
3839
python -m pip install -r requirements.txt
3940
python -m pip install --force-reinstall -U "numpy<2.0.0"
4041
python -m pip install pytest-pretty
4142
python -m pip install torch==${{ matrix.torch_version }} -f https://download.pytorch.org/whl/torch
43+
TORCH_VERSION_MAJOR_MINOR=$(python -c "import torch; v=torch.__version__.split('+')[0]; print('.'.join(v.split('.')[:2]))")
44+
if [[ $(echo "$TORCH_VERSION_MAJOR_MINOR < 2.6" | bc -l) -eq 1 ]]; then
45+
python -m pip install "triton==3.1"
46+
fi
4247
python -m pip list
4348
- name: Install skorch
4449
run: |

Diff for: README.rst

+5-8
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ To install skorch from source using conda, proceed as follows:
177177
178178
git clone https://github.com/skorch-dev/skorch.git
179179
cd skorch
180-
conda create -n skorch-env python=3.10
180+
conda create -n skorch-env python=3.12
181181
conda activate skorch-env
182-
conda install -c pytorch pytorch
182+
python -m pip install torch
183183
python -m pip install -r requirements.txt
184184
python -m pip install .
185185
@@ -189,9 +189,9 @@ If you want to help developing, run:
189189
190190
git clone https://github.com/skorch-dev/skorch.git
191191
cd skorch
192-
conda create -n skorch-env python=3.10
192+
conda create -n skorch-env python=3.12
193193
conda activate skorch-env
194-
conda install -c pytorch pytorch
194+
python -m pip install torch
195195
python -m pip install -r requirements.txt
196196
python -m pip install -r requirements-dev.txt
197197
python -m pip install -e .
@@ -239,10 +239,10 @@ instructions for PyTorch, visit the `PyTorch website
239239
<http://pytorch.org/>`__. skorch officially supports the last four
240240
minor PyTorch versions, which currently are:
241241

242-
- 2.2.2
243242
- 2.3.1
244243
- 2.4.1
245244
- 2.5.1
245+
- 2.6.0
246246

247247
However, that doesn't mean that older versions don't work, just that
248248
they aren't tested. Since skorch mostly relies on the stable part of
@@ -252,9 +252,6 @@ In general, running this to install PyTorch should work:
252252

253253
.. code:: bash
254254
255-
# using conda:
256-
conda install pytorch pytorch-cuda -c pytorch
257-
# using pip
258255
python -m pip install torch
259256
260257
==================

Diff for: docs/user/installation.rst

+1-4
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ instructions for PyTorch, visit the `PyTorch website
9898
<http://pytorch.org/>`__. skorch officially supports the last four
9999
minor PyTorch versions, which currently are:
100100

101-
- 2.2.2
102101
- 2.3.1
103102
- 2.4.1
104103
- 2.5.1
104+
- 2.6.0
105105

106106
However, that doesn't mean that older versions don't work, just that
107107
they aren't tested. Since skorch mostly relies on the stable part of
@@ -111,7 +111,4 @@ In general, running this to install PyTorch should work:
111111

112112
.. code:: bash
113113
114-
# using conda:
115-
conda install pytorch pytorch-cuda -c pytorch
116-
# using pip
117114
python -m pip install torch

Diff for: skorch/tests/net_cuda.pkl

-389 Bytes
Binary file not shown.

Diff for: skorch/tests/test_net.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -3075,9 +3075,8 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
30753075
):
30763076
# Same test as
30773077
# test_torch_load_kwargs_auto_weights_only_false_when_load_params but
3078-
# without monkeypatching get_default_torch_load_kwargs. There is no
3079-
# corresponding test for >= 2.6.0 since it's not clear yet if the switch
3080-
# will be made in that version.
3078+
# without monkeypatching get_default_torch_load_kwargs. The default is
3079+
# weights_only=False.
30813080
# See discussion in 1063.
30823081
from skorch._version import Version
30833082

@@ -3098,6 +3097,32 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
30983097
del call_kwargs['map_location'] # we're not interested in that
30993098
assert call_kwargs == expected_kwargs
31003099

3100+
def test_torch_load_kwargs_auto_weights_true_pytorch_ge_2_6(
3101+
self, net_cls, module_cls, monkeypatch, tmp_path
3102+
):
3103+
# Same test as
3104+
# test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6 but
3105+
# with weights_only=True, since it's the new default
3106+
# See discussion in 1063.
3107+
from skorch._version import Version
3108+
3109+
# TODO remove once torch 2.5.0 is no longer supported
3110+
if Version(torch.__version__) < Version('2.6.0'):
3111+
pytest.skip("Test only for torch >= 2.6.0")
3112+
3113+
net = net_cls(module_cls).initialize()
3114+
net.save_params(f_params=tmp_path / 'params.pkl')
3115+
state_dict = net.module_.state_dict()
3116+
expected_kwargs = {"weights_only": True}
3117+
3118+
mock_torch_load = Mock(return_value=state_dict)
3119+
monkeypatch.setattr(torch, "load", mock_torch_load)
3120+
net.load_params(f_params=tmp_path / 'params.pkl')
3121+
3122+
call_kwargs = mock_torch_load.call_args_list[0].kwargs
3123+
del call_kwargs['map_location'] # we're not interested in that
3124+
assert call_kwargs == expected_kwargs
3125+
31013126
def test_torch_load_kwargs_forwarded_to_torch_load_unpickle(
31023127
self, net_cls, module_cls, monkeypatch, tmp_path
31033128
):

Diff for: skorch/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -775,10 +775,10 @@ def get_default_torch_load_kwargs():
775775
"""Returns the kwargs passed to torch.load that correspond to the current
776776
torch version.
777777
778-
The plan is to switch from weights_only=False to True in PyTorch version
779-
2.6.0, but depending on what happens, this may require updating.
778+
PyTorch switches from weights_only=False to True in version 2.6.0.
780779
781780
"""
781+
# TODO: Remove once PyTorch 2.5 is no longer supported
782782
version_torch = Version(torch.__version__)
783783
version_default_switch = Version('2.6.0')
784784
if version_torch >= version_default_switch:

0 commit comments

Comments
 (0)