Skip to content

Commit

Permalink
Merge branch 'release/1.2-dev' into refactor/train-phase
Browse files Browse the repository at this point in the history
  • Loading branch information
s-rog authored Jan 13, 2021
2 parents eb25b15 + 7e4d6cb commit aea5882
Show file tree
Hide file tree
Showing 42 changed files with 174 additions and 177 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci_dockers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
python_version: [3.6]
pytorch_version: [1.3, 1.7]
pytorch_version: [1.4, 1.7]
steps:
- name: Checkout
uses: actions/checkout@v2
Expand Down Expand Up @@ -74,7 +74,7 @@ jobs:
- python_version: 3.7
pytorch_version: 1.6
- python_version: 3.6
pytorch_version: 1.3
pytorch_version: 1.4
steps:
- name: Checkout
uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
matrix:
# os: [ubuntu-20.04]
python-version: [3.7]
pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8]
pytorch-version: [1.4, 1.5, 1.6, 1.7, 1.8]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down
17 changes: 0 additions & 17 deletions .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,6 @@ jobs:
open(fname, 'w').writelines(lines)
shell: python

# versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996)
- name: Adjust minimal for Python 3.8 and MacOS
if: matrix.requires == 'minimal' && (runner.os == 'macOS' || matrix.python-version == 3.8)
run : |
fname = 'requirements.txt'
req = open(fname).read().replace('torch>=1.3', 'torch>=1.4')
open(fname, 'w').write(req)
fname = 'requirements/examples.txt'
req = open(fname).read().replace('torchvision>=0.4.1', 'torchvision>=0.5.0')
open(fname, 'w').write(req)
fname = 'requirements/extra.txt'
req = open(fname).read().replace('torchtext>=0.3.1', 'torchtext>=0.5.0')
open(fname, 'w').write(req)
shell: python

- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
Expand Down
6 changes: 1 addition & 5 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ jobs:
fail-fast: false
matrix:
python_version: [3.6, 3.7, 3.8]
pytorch_version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8]
exclude:
# excludes PT 1.3 as it is missing on pypi
- python_version: 3.8
pytorch_version: 1.3
pytorch_version: [1.4, 1.5, 1.6, 1.7, 1.8]

steps:
- name: Checkout
Expand Down
6 changes: 1 addition & 5 deletions .github/workflows/release-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ jobs:
fail-fast: false
matrix:
python_version: [3.6, 3.7, 3.8]
pytorch_version: [1.3, 1.4, 1.5, 1.6, 1.7]
exclude:
# excludes PT 1.3 as it is missing on pypi
- python_version: 3.8
pytorch_version: 1.3
pytorch_version: [1.4, 1.5, 1.6, 1.7]
steps:
- name: Checkout
uses: actions/checkout@v2
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


- Set PyTorch 1.4 as min requirements, also for testing and examples `torchvision>=0.5` and `torchtext>=0.5` ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418))


- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))


Expand Down
19 changes: 9 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,15 @@ Lightning can automatically export to ONNX or TorchScript for those cases.
## Continuous Integration
<center>

| System / PyTorch ver. | 1.3 (min. req.)* | 1.4 | 1.5 | 1.6 | 1.7 (latest) | 1.8 (nightly) |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
| Linux py3.7 [GPUs**] | - | - | - | [![GPUs Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - | - |
| Linux py3.{6,7} [TPUs***] | - | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - |
| Linux py3.{6,7} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
| OSX py3.{6,7,8} | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
| Windows py3.{6,7,8} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |

- _\* `torch>=1.4` is the minimal pytorch version for Python 3.8_
| System / PyTorch ver. | 1.4 (min. req.)* | 1.5 | 1.6 | 1.7 (latest) | 1.8 (nightly) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
| Linux py3.7 [GPUs**] | - | - | [![GPUs Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - | - |
| Linux py3.{6,7} [TPUs***] | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) |
| Linux py3.{6,7} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
| OSX py3.{6,7,8} | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |
| Windows py3.{6,7,8} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - |

- _\** tests run on two NVIDIA K80_
- _\*** tests run on Google GKE TPUv2/3_
- _TPU w/ py3.6/py3.7 means we support Colab and Kaggle env._
Expand Down
1 change: 0 additions & 1 deletion dockers/base-conda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 --build-arg PYTORCH_CHANNEL=pytorch
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 --build-arg PYTORCH_CHANNEL=pytorch
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 --build-arg PYTORCH_CHANNEL=pytorch
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.3 --build-arg PYTORCH_CHANNEL=pytorch

ARG CUDNN_VERSION=8
ARG CUDA_VERSION=10.2
Expand Down
1 change: 0 additions & 1 deletion dockers/base-cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 --build-arg CUDA_VERSION=10.2
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 --build-arg CUDA_VERSION=10.2
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 --build-arg CUDA_VERSION=10.1
# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.3 --build-arg CUDA_VERSION=10.1

ARG CUDNN_VERSION=8
ARG CUDA_VERSION=10.2
Expand Down
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies:
- python>=3.6
- pip>20.1
- numpy>=1.16.4
- pytorch>=1.3,<1.8
- pytorch>=1.4
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
Expand All @@ -38,10 +38,10 @@ dependencies:
- scikit-learn>=0.20.0
- matplotlib>=3.1.1
- omegaconf>=2.0.0
- torchtext>=0.3.1
- torchtext>=0.5

# Examples
- torchvision>=0.4.1,<0.9.0
- torchvision>=0.5

- pip:
- test-tube>=0.7.5
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ skip_glob = [
"tests/backends/*",
"tests/base/*",
"tests/callbacks/*",
"tests/checkpointing/*",
"tests/core/*",
"tests/loggers/*",
"tests/metrics/*",
Expand Down
37 changes: 24 additions & 13 deletions pytorch_lightning/accelerators/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,21 @@ def select_accelerator(self):
# ----------------------------------
# choose an accelerator for the user
# ----------------------------------
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
use_slurm_ddp = (
self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
and self.trainer.is_slurm_managing_tasks
)

# torchelastic or general non_slurm ddp
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
use_torchelastic_ddp = (
self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed
)

use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
use_ddp_cpu_spawn = (
self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
and self.trainer._device_type == DeviceType.CPU
)

use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks
Expand All @@ -204,8 +211,9 @@ def select_accelerator(self):

cluster_env = self._select_environment()

# TODO: clean-up this branching as most just select class and uses the very same arguments
# choose the appropriate accelerator backend
if self.trainer.use_ddp2:
if self.trainer._distrib_type == DistributedType.DDP2:
accelerator_backend = accelerators.DDP2Accelerator(
self.trainer,
cluster_env,
Expand Down Expand Up @@ -240,7 +248,7 @@ def select_accelerator(self):
self.trainer.plugin_connector.ddp_plugin
)

elif use_ddp_spawn:
elif self.trainer._distrib_type == DistributedType.DDP_SPAWN:
accelerator_backend = accelerators.DDPSpawnAccelerator(
self.trainer,
nprocs=self.trainer.num_processes,
Expand All @@ -263,16 +271,16 @@ def select_accelerator(self):
ddp_plugin=self.trainer.plugin_connector.ddp_plugin
)

elif self.trainer.use_dp:
elif self.trainer._distrib_type == DistributedType.DP:
accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env)

elif self.trainer.use_horovod:
elif self.trainer._distrib_type == DistributedType.HOROVOD:
accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env)

elif self.trainer.use_single_gpu:
elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1:
accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env)

elif self.trainer.use_tpu:
elif self.trainer._device_type == DeviceType.TPU:
accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env)

elif self.trainer.distributed_backend is None:
Expand Down Expand Up @@ -347,13 +355,16 @@ def set_distributed_mode(self):
self._set_horovod_backend()

# throw error to force user ddp or ddp2 choice
if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
_ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
)

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
rank_zero_info(
f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}'
)
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

Expand All @@ -366,7 +377,7 @@ def _set_horovod_backend(self):

# Initialize Horovod to get rank / size info
hvd.init()
if self.trainer.on_gpu:
if self.trainer._device_type == DeviceType.GPU:
# Horovod assigns one local GPU per process
self.trainer.root_gpu = hvd.local_rank()

Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/horovod_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_only

if _HOROVOD_AVAILABLE:
Expand All @@ -46,7 +46,7 @@ def setup(self, model):
# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

if torch.cuda.is_available() and self.trainer.on_gpu:
if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU:
# Horovod: pin GPU to local rank
assert self.trainer.root_gpu == hvd.local_rank()
torch.cuda.set_device(self.trainer.root_gpu)
Expand Down Expand Up @@ -116,7 +116,7 @@ def train(self):
return results

def _step(self, model_step: Callable, args):
if self.trainer.on_gpu:
if self.trainer._device_type == DeviceType.GPU:
args[0] = self.batch_to_device(args[0], hvd.local_rank())

if self.trainer.amp_backend == AMPType.NATIVE:
Expand All @@ -141,7 +141,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
optimizer.synchronize()

def on_train_epoch_end(self, outputs):
hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)
hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1)

def barrier(self, name: Optional[str] = None):
hvd.join()
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Dict, List, Tuple

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_only, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

Expand Down Expand Up @@ -104,7 +104,7 @@ def on_train_start(self, trainer, *args, **kwargs):
'Cannot use GPUStatsMonitor callback with Trainer that has no logger.'
)

if not trainer.on_gpu:
if trainer._device_type != DeviceType.GPU:
raise MisconfigurationException(
'You are using GPUStatsMonitor but are not running on GPU'
f' since gpus attribute in Trainer is set to {trainer.gpus}.'
Expand Down
Loading

0 comments on commit aea5882

Please sign in to comment.