From 85a2ad19435c26159443daeff54bc925ac312af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20V=C3=B6lgyes?= Date: Fri, 26 Feb 2021 10:17:53 +0100 Subject: [PATCH 01/10] Fix for incorrect detach/cpu calls (#6214) --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 530001e0be52d..68453811da203 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -736,9 +736,9 @@ def run_evaluation(self, max_batches=None, on_epoch=False): def track_output_for_epoch_end(self, outputs, output): if output is not None: if isinstance(output, Result): - output.detach() + output = output.detach() if self.move_metrics_to_cpu: - output.cpu() + output = output.cpu() elif isinstance(output, dict): output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu) elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu: From 1fe84f4c91e2f5084c6c93b0aac957e00dc83e31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20V=C3=B6lgyes?= Date: Fri, 26 Feb 2021 10:55:18 +0100 Subject: [PATCH 02/10] Fix incorrect use of detach(), to(), and cpu(), #6214 --- .../connectors/logger_connector/epoch_result_store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index c435204107775..a547144c8a6f3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -281,11 +281,11 @@ def cache_result(self) -> None: # attach capture batch_size Result.attach_batch_size(self._batch_size, hook_result) - hook_result.detach() + hook_result = hook_result.detach() if self.trainer.move_metrics_to_cpu: - hook_result.cpu() + hook_result = hook_result.cpu() elif self.trainer._distrib_type == DistributedType.DP: - hook_result.to(torch.device("cuda", self.trainer.root_gpu)) + hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu)) self._internals[fx_name].append(hook_result, info) From a77ef95542cf3984cc68cbc929b963db6c946224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20V=C3=B6lgyes?= Date: Fri, 26 Feb 2021 10:58:53 +0100 Subject: [PATCH 03/10] Fix incorrect use of detach() and cpu(), #6214 --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 57ad3f6b06d36..f4c6878970345 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -261,7 +261,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) if is_result_obj: - training_step_output.detach() + training_step_output = training_step_output.detach() else: training_step_output.batch_loss = training_step_output.batch_loss.detach() @@ -395,9 +395,9 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch): # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) - training_step_output_for_epoch_end.detach() + training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: - 
training_step_output_for_epoch_end.cpu() + training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() # what flows back into the system training_step_output = result From 48cfb7f5c93a6048daae00428cccbdb5949659cd Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Feb 2021 12:12:56 +0000 Subject: [PATCH 04/10] update pr --- pytorch_lightning/core/step_result.py | 3 +++ .../connectors/logger_connector/epoch_result_store.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 974974b032bec..dc51caf8e1f02 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -420,16 +420,19 @@ def detach(self): for k, v in self.items(): if isinstance(v, torch.Tensor): self.__setitem__(k, v.detach()) + return self def to(self, *args, **kwargs): """Move all self attributes to the given device.""" for k, v in self.items(): if isinstance(v, torch.Tensor): self.__setitem__(k, v.to(*args, **kwargs)) + return self def cpu(self): """Move all self attributes to CPU.""" self.to(torch.device("cpu")) + return self def __repr__(self): self_copy = self.copy() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a547144c8a6f3..c435204107775 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -281,11 +281,11 @@ def cache_result(self) -> None: # attach capture batch_size Result.attach_batch_size(self._batch_size, hook_result) - hook_result = hook_result.detach() + hook_result.detach() if self.trainer.move_metrics_to_cpu: - hook_result = hook_result.cpu() + hook_result.cpu() elif self.trainer._distrib_type == DistributedType.DP: - hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu)) + hook_result.to(torch.device("cuda", self.trainer.root_gpu)) self._internals[fx_name].append(hook_result, info) From c170e9d4baa95ad2a4525af5d35d5ebaf9d7db49 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Feb 2021 12:14:25 +0000 Subject: [PATCH 05/10] add typing --- pytorch_lightning/core/step_result.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index dc51caf8e1f02..8e99aa16085c7 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -416,20 +416,20 @@ def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_i return result - def detach(self): + def detach(self) -> 'Result': for k, v in self.items(): if isinstance(v, torch.Tensor): self.__setitem__(k, v.detach()) return self - def to(self, *args, **kwargs): + def to(self, *args, **kwargs) -> 'Result': """Move all self attributes to the given device.""" for k, v in self.items(): if isinstance(v, torch.Tensor): self.__setitem__(k, v.to(*args, **kwargs)) return self - def cpu(self): + def cpu(self) -> 'Result': """Move all self attributes to CPU.""" self.to(torch.device("cpu")) return self From f069c076f6dbffdab5d370a2754ac13de26eedc1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 26 Feb 2021 13:25:45 +0100 Subject: [PATCH 06/10] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1d4b1d6c983..6b642999dcd2a 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197)) +- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216)) + + ## [1.2.1] - 2021-02-23 ### Fixed From 393dae1051c37bb0e750ed2091471f2f8571d314 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 26 Feb 2021 13:32:49 +0100 Subject: [PATCH 07/10] more... --- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/core/step_result.py | 3 +-- pytorch_lightning/plugins/training_type/ddp.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 2 +- pytorch_lightning/plugins/training_type/horovod.py | 2 +- pytorch_lightning/plugins/training_type/single_device.py | 2 +- pytorch_lightning/plugins/training_type/single_tpu.py | 2 +- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- .../connectors/logger_connector/epoch_result_store.py | 6 +++--- tests/helpers/pipelines.py | 2 +- tests/models/test_restore.py | 4 ++-- tests/overrides/test_data_parallel.py | 3 +-- tests/plugins/test_deepspeed_plugin.py | 4 ++-- 14 files changed, 18 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 785a0cc8ddea8..721cfb4400dfb 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -32,7 +32,7 @@ def on_train_start(self) -> None: def on_train_end(self) -> None: # clean up memory - self.model.cpu() + self.model = self.model.cpu() with torch.cuda.device(self.root_device): torch.cuda.empty_cache() diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 8e99aa16085c7..f8d7a2ffe3a23 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -431,8 +431,7 @@ def to(self, *args, **kwargs) -> 'Result': def cpu(self) -> 'Result': """Move all self attributes to CPU.""" - self.to(torch.device("cpu")) - return self + return self.to(torch.device("cpu")) def __repr__(self): self_copy = self.copy() diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 80161d6e59b6b..a0acfb31c57fd 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -276,7 +276,7 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) - self.model.to(self.root_device) + self.model = self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ca25a6d8bc382..be9067e2b7bea 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -249,7 +249,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) - self.model.to(self.root_device) + self.model = self.model.to(self.root_device) def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: 
Optimizer, opt_idx: int): """Run before precision plugin executes backward""" diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index c2b16303e5d4e..4761540bea469 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -28,7 +28,7 @@ def __init__(self, parallel_devices: Optional[List[torch.device]]): def setup(self, model): # model needs to be moved to the device before it is wrapped - model.to(self.root_device) + model = model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) def reduce(self, tensor, *args, **kwargs): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index e940cb1d7229b..785c4eb9e0bb1 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -125,7 +125,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: def model_to_device(self): if self.on_gpu: torch.cuda.set_device(self.root_device) - self.model.to(self.root_device) + self.model = self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 5bf0597ed7f18..248838880b65b 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -42,7 +42,7 @@ def model_to_device(self) -> None: if self.on_gpu: torch.cuda.set_device(self.root_device) - self._model.to(self.root_device) + self._model = self._model.to(self.root_device) def connect(self, model: torch.nn.Module) -> torch.nn.Module: self._model = model diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 3ddfd98128787..2c7083b786a45 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -32,7 +32,7 @@ def connect(self, model: torch.nn.Module) -> torch.nn.Module: return self._model def model_to_device(self) -> None: - self._model.to(self.root_device) + self._model = self._model.to(self.root_device) def pre_dispatch(self) -> None: if isinstance(self.device, int): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 692a4426a6ad6..595aeb2ffa6b1 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -109,7 +109,7 @@ def __save_end_of_training_weights(self, model: LightningModule) -> None: self.save_spawn_weights(model) def model_to_device(self) -> None: - self._model.to(xm.xla_device()) + self._model = self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: rendezvous(f"pl.Trainer.{name}") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index c435204107775..a547144c8a6f3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -281,11 +281,11 @@ def cache_result(self) -> None: # attach capture batch_size 
Result.attach_batch_size(self._batch_size, hook_result) - hook_result.detach() + hook_result = hook_result.detach() if self.trainer.move_metrics_to_cpu: - hook_result.cpu() + hook_result = hook_result.cpu() elif self.trainer._distrib_type == DistributedType.DP: - hook_result.to(torch.device("cuda", self.trainer.root_gpu)) + hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu)) self._internals[fx_name].append(hook_result, info) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 403bcdfee8c1d..0f1ddf3fd718e 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -98,7 +98,7 @@ def run_model_test( @torch.no_grad() def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50): # run prediction on 1 batch - trained_model.cpu() + trained_model = trained_model.cpu() trained_model.eval() batch = next(iter(dataloader)) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 7d6c104abbd57..231ec8d2795da 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -252,7 +252,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): # run test set new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - pretrained_model.cpu() + pretrained_model = pretrained_model.cpu() dataloaders = model.test_dataloader() if not isinstance(dataloaders, list): @@ -300,7 +300,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir): # run test set new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - pretrained_model.cpu() + pretrained_model = pretrained_model.cpu() dataloaders = dm.test_dataloader() if not isinstance(dataloaders, list): diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 90bb6fac88457..128bb224110d0 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -145,8 +145,7 @@ def training_step(self, batch, batch_idx): output.update({"python scalar": 12.3}) return output - model = TestModel() - model.to(device) + model = TestModel().to(device) model.trainer = Mock() model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).to(device) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 903855fd2c0eb..1c2797e32ee16 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -26,7 +26,7 @@ def test_deepspeed_lightning_module(tmpdir): assert module.dtype == torch.half assert model.dtype == torch.half - module.to(torch.double) + module = module.to(torch.double) assert module.dtype == torch.double assert model.dtype == torch.double @@ -49,7 +49,7 @@ def test_deepspeed_lightning_module_precision(tmpdir): assert out.dtype == torch.half - module.to(torch.double) + module = module.to(torch.double) assert module.dtype == torch.double assert model.dtype == torch.double From 0e9f340f2fc80497aab1842d0977b59b7f2b5c8e Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Mar 2021 11:36:31 +0000 Subject: [PATCH 08/10] revert on module --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/horovod.py | 2 +- pytorch_lightning/plugins/training_type/single_device.py | 2 +- pytorch_lightning/plugins/training_type/single_tpu.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 4 ++-- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git 
a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a0acfb31c57fd..80161d6e59b6b 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -276,7 +276,7 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) - self.model = self.model.to(self.root_device) + self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index be9067e2b7bea..ca25a6d8bc382 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -249,7 +249,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: def model_to_device(self): if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) - self.model = self.model.to(self.root_device) + self.model.to(self.root_device) def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 785c4eb9e0bb1..e940cb1d7229b 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -125,7 +125,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: def model_to_device(self): if self.on_gpu: torch.cuda.set_device(self.root_device) - self.model = self.model.to(self.root_device) + self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 61893647cff55..d11ae87bed660 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -55,7 +55,7 @@ def model_to_device(self) -> None: if self.on_gpu: torch.cuda.set_device(self.root_device) - self._model = self._model.to(self.root_device) + self._model.to(self.root_device) def connect(self, model: torch.nn.Module) -> torch.nn.Module: self._model = model diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index fb64d685cb7e7..d3cbd0d6b5d79 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -45,7 +45,7 @@ def connect(self, model: torch.nn.Module) -> torch.nn.Module: return self._model def model_to_device(self) -> None: - self._model = self._model.to(self.root_device) + self._model.to(self.root_device) def pre_dispatch(self) -> None: if isinstance(self.device, int): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 1c2797e32ee16..903855fd2c0eb 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -26,7 +26,7 @@ def test_deepspeed_lightning_module(tmpdir): assert module.dtype == torch.half assert model.dtype == torch.half - module = module.to(torch.double) + module.to(torch.double) assert module.dtype == 
torch.double assert model.dtype == torch.double @@ -49,7 +49,7 @@ def test_deepspeed_lightning_module_precision(tmpdir): assert out.dtype == torch.half - module = module.to(torch.double) + module.to(torch.double) assert module.dtype == torch.double assert model.dtype == torch.double From aeb346a15cfbd90e4ce586922c93d9737363c549 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Mar 2021 11:38:41 +0000 Subject: [PATCH 09/10] update on comments --- pytorch_lightning/plugins/training_type/dp.py | 2 +- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- tests/helpers/pipelines.py | 2 +- tests/models/test_restore.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 4761540bea469..c2b16303e5d4e 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -28,7 +28,7 @@ def __init__(self, parallel_devices: Optional[List[torch.device]]): def setup(self, model): # model needs to be moved to the device before it is wrapped - model = model.to(self.root_device) + model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) def reduce(self, tensor, *args, **kwargs): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 43755df112d13..4232cba485414 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -122,7 +122,7 @@ def __save_end_of_training_weights(self, model: LightningModule) -> None: self.save_spawn_weights(model) def model_to_device(self) -> None: - self._model = self._model.to(xm.xla_device()) + self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: rendezvous(f"pl.Trainer.{name}") diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 0f1ddf3fd718e..403bcdfee8c1d 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -98,7 +98,7 @@ def run_model_test( @torch.no_grad() def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50): # run prediction on 1 batch - trained_model = trained_model.cpu() + trained_model.cpu() trained_model.eval() batch = next(iter(dataloader)) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 231ec8d2795da..7d6c104abbd57 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -252,7 +252,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): # run test set new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - pretrained_model = pretrained_model.cpu() + pretrained_model.cpu() dataloaders = model.test_dataloader() if not isinstance(dataloaders, list): @@ -300,7 +300,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir): # run test set new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - pretrained_model = pretrained_model.cpu() + pretrained_model.cpu() dataloaders = dm.test_dataloader() if not isinstance(dataloaders, list): From d6be9832ea31db6a578a8fbfd2a3c7d403e57616 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Mar 2021 11:39:38 +0000 Subject: [PATCH 10/10] revert changes on model --- pytorch_lightning/accelerators/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 
721cfb4400dfb..785a0cc8ddea8 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -32,7 +32,7 @@ def on_train_start(self) -> None: def on_train_end(self) -> None: # clean up memory - self.model = self.model.cpu() + self.model.cpu() with torch.cuda.device(self.root_device): torch.cuda.empty_cache()
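
Note on the semantics this series converges on -- a minimal sketch in plain
PyTorch (no Lightning imports; `ResultStub` below is illustrative, not the real
`pytorch_lightning.core.step_result.Result`):

import torch

# Tensor.detach()/cpu()/to() are functional: each returns a NEW tensor and
# leaves the caller untouched, so a bare `output.detach()` is a silent no-op --
# the bug reported in #6214 and fixed by patches 01-02.
t = torch.ones(2, requires_grad=True) * 2.0
t.detach()                       # return value discarded; `t` still tracks grad
assert t.requires_grad
t = t.detach()                   # correct: rebind the name to the new tensor
assert not t.requires_grad

# nn.Module.to()/cpu() behave differently: they move parameters in place and
# return `self`, so `model.to(device)` alone is already correct. That asymmetry
# is why patches 08-10 revert the module-level rebinds introduced in patch 07.
m = torch.nn.Linear(2, 2)
assert m.to("cpu") is m

# Patches 04-05 give Result's own detach()/to()/cpu() the same fluent contract
# (mutate in place, then `return self`), so both call styles end up equivalent:
class ResultStub(dict):
    def detach(self) -> 'ResultStub':
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self[k] = v.detach()
        return self

r = ResultStub(loss=t)
assert r.detach() is r           # `r.detach()` and `r = r.detach()` now agree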