Skip to content

Commit

Permalink
make gpus=str in Trainer consistent with command line parsing of stri…
Browse files Browse the repository at this point in the history
…ng (#6388)

* string gpu input

* update docs

* deprecation warning

* Revert "update docs"

This reverts commit c5f3893.

* deprecation

* add changelog

* update parser

* update warning

* implement v1.5 behavior ahead of time

* formatting

* set accelerator in test to avoid different warning

* add warning

* remove todo warn

* Update pytorch_lightning/utilities/device_parser.py

Co-authored-by: Kaushik B <[email protected]>

* resolve flake8

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: tchaton <[email protected]>
  • Loading branch information
4 people authored May 4, 2021
1 parent 2a20102 commit a6aa1a0
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `LightningModule.datamodule` getter and setter methods; access them through `Trainer.datamodule` instead ([#7168](https://github.com/PyTorchLightning/pytorch-lightning/pull/7168))


- Deprecated the use of `Trainer(gpus="i")` (string) for selecting the i-th GPU; from v1.5 this will set the number of GPUs instead of the index ([#6388](https://github.com/PyTorchLightning/pytorch-lightning/pull/6388))

### Removed


Expand Down
6 changes: 5 additions & 1 deletion docs/source/advanced/multi_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,17 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
+---------------+-----------+---------------------+---------------------------------+
| "0" | str | [0] | GPU 0 |
+---------------+-----------+---------------------+---------------------------------+
| "3" | str | [3] | GPU 3 |
| "3" | str | [3] | GPU 3 (will change in v1.5) |
+---------------+-----------+---------------------+---------------------------------+
| "1, 3" | str | [1, 3] | GPUs 1 and 3 |
+---------------+-----------+---------------------+---------------------------------+
| "-1" | str | [0, 1, 2, ...] | all available GPUs |
+---------------+-----------+---------------------+---------------------------------+

.. warning::
The behavior for :code:`gpus="3"` (str) will change. Currently it selects the GPU with index 3, but will
select the first 3 GPUs from v1.5.

.. note::

When specifying number of gpus as an integer ``gpus=k``, setting the trainer flag
Expand Down
32 changes: 24 additions & 8 deletions pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import Any, List, MutableSequence, Optional, Tuple, Union

import torch

from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version


def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
Expand Down Expand Up @@ -66,9 +68,12 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
if gpus is None or isinstance(gpus, int) and gpus == 0:
return None

if _compare_version("pytorch_lightning", operator.ge, "1.5") and isinstance(gpus, str) and gpus.strip() == "0":
# TODO: in v1.5 combine this with the above if statement
return None

# We know user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.

gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
Expand Down Expand Up @@ -107,13 +112,24 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
if isinstance(s, str):
if s == '-1':
return -1
else:
return [int(x.strip()) for x in s.split(',') if len(x) > 0]
else:
if not isinstance(s, str):
return s
if s == '-1':
return -1
elif ',' in s:
return [int(x.strip()) for x in s.split(',') if len(x) > 0]
else:
num_gpus = int(s.strip())
if _compare_version("pytorch_lightning", operator.lt, "1.5"):
rank_zero_warn(
f"Parsing of the Trainer argument gpus='{s}' (string) will change in the future."
" In the current version of Lightning, this will select"
f" CUDA device with index {num_gpus}, but from v1.5 it will select gpus"
f" {list(range(num_gpus))} (same as gpus={s} (int)).",
DeprecationWarning,
)
return [num_gpus]
return num_gpus


def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
Expand Down
27 changes: 25 additions & 2 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""
import operator
import os
from typing import Any, Dict
from unittest import mock
Expand All @@ -26,8 +27,10 @@
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.imports import _compare_version
from tests.deprecated_api import no_deprecated_call
from tests.helpers import BoringModel, BoringDataModule
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.utils import no_warning_call


Expand All @@ -48,7 +51,7 @@ def test_v1_5_0_model_checkpoint_save_function():


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_v1_5_0_wandb_unused_sync_step(tmpdir):
def test_v1_5_0_wandb_unused_sync_step(_):
with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
WandbLogger(sync_step=True)

Expand Down Expand Up @@ -382,6 +385,26 @@ def test_v1_5_0_lighting_module_grad_norm(tmpdir):
model.grad_norm(2)


@pytest.mark.xfail(
condition=_compare_version("pytorch_lightning", operator.ge, "1.5"),
reason="parsing of string will change in v1.5",
)
@mock.patch('torch.cuda.device_count', return_value=4)
def test_v1_5_0_trainer_gpus_str_parsing(*_):
# TODO: when removing this, make sure docs in docs/advanced/multi-gpu.rst reflect the new
# behavior regarding GPU selection. Ping @awaelchli if unsure.
with pytest.deprecated_call(match=r"Parsing of the Trainer argument gpus='3' .* will change."):
Trainer(gpus="3", accelerator="ddp_spawn")

with pytest.deprecated_call(match=r"Parsing of the Trainer argument gpus='3' .* will change."):
gpus = device_parser.parse_gpu_ids("3")
assert gpus == [3]

with pytest.deprecated_call(match=r"Parsing of the Trainer argument gpus='0' .* will change."):
gpus = device_parser.parse_gpu_ids("0")
assert gpus == [0]


def test_v1_5_0_datamodule_setter():
model = BoringModel()
datamodule = BoringDataModule()
Expand Down
7 changes: 5 additions & 2 deletions tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from collections import namedtuple
from unittest.mock import patch

Expand All @@ -22,12 +23,14 @@
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

PL_VERSION_LT_1_5 = _compare_version("pytorch_lightning", operator.lt, "1.5")
PRETEND_N_OF_GPUS = 16


Expand Down Expand Up @@ -171,8 +174,8 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu):
pytest.param([0], [0]),
pytest.param([1, 3], [1, 3]),
pytest.param((1, 3), [1, 3]),
pytest.param('0', [0]),
pytest.param('3', [3]),
pytest.param('0', None, marks=pytest.mark.skipif(PL_VERSION_LT_1_5, reason="available from v1.5")),
pytest.param('3', [0, 1, 2], marks=pytest.mark.skipif(PL_VERSION_LT_1_5, reason="available from v1.5")),
pytest.param('1, 3', [1, 3]),
pytest.param('2,', [2]),
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
Expand Down

0 comments on commit a6aa1a0

Please sign in to comment.