
[bugfix] Reduce memory leaks #8490

Merged · 41 commits · Jul 21, 2021
Changes from 4 commits
aa89cc9
reduce memory leak
tchaton Jul 20, 2021
f57e21d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
341163b
update changelog
tchaton Jul 20, 2021
63eeaf0
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
f9ca9dc
Apply suggestions from code review
Borda Jul 20, 2021
1804975
resolve flake8
tchaton Jul 20, 2021
7a7b95f
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
7f06053
update on comments
tchaton Jul 20, 2021
b48e34d
resolve bug
tchaton Jul 20, 2021
eef89bc
update
tchaton Jul 20, 2021
13b335d
Undo whitespace changes
carmocca Jul 20, 2021
7fca3d8
remove bug
tchaton Jul 20, 2021
03e8faa
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
3fd83e3
resolve flake8
tchaton Jul 20, 2021
1d8b484
revert change
tchaton Jul 20, 2021
38ff815
update on comments
tchaton Jul 20, 2021
ef9a4cd
delete the ddp wrapper as it hold memory
tchaton Jul 20, 2021
9acb0c1
Merge branch 'master' into reduce_memory_leak
tchaton Jul 20, 2021
9ab40de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
000138e
resolve flake8
tchaton Jul 20, 2021
4e439f4
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
559436f
update on comments
tchaton Jul 20, 2021
8c9145d
update changelog
tchaton Jul 20, 2021
7c015cb
resolve test
tchaton Jul 20, 2021
70affbb
Update CHANGELOG
carmocca Jul 20, 2021
26bb10f
Refactor teardown
carmocca Jul 20, 2021
231b02b
Fix comment
carmocca Jul 20, 2021
0d1c365
Do it for non-gpu too
carmocca Jul 20, 2021
9077898
remove ref when the model is not a lightning_module
tchaton Jul 20, 2021
2ba1e9e
Fix import error
carmocca Jul 20, 2021
a8018df
Merge branch 'master' into reduce_memory_leak
tchaton Jul 20, 2021
666383c
move down
tchaton Jul 20, 2021
f16b8de
resolve bug
tchaton Jul 20, 2021
a915396
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
8c84391
resolve assignement
tchaton Jul 20, 2021
2d3223a
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
47c1ad3
update
tchaton Jul 21, 2021
9c347a2
move above
tchaton Jul 21, 2021
b26f98b
Fix device calls to support tpu training
kaushikb11 Jul 21, 2021
89a7033
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
kaushikb11 Jul 21, 2021
2719d03
Update todo
kaushikb11 Jul 21, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -493,6 +493,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed memory leaks on GPU by moving `optimizer_states` and `ResultCollection` extras to `cpu` ([#8490](https://github.com/PyTorchLightning/pytorch-lightning/pull/8490))


## [1.3.8] - 2021-07-01

### Fixed
13 changes: 13 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
@@ -13,11 +13,13 @@
# limitations under the License.
import logging
import os
from typing import Any, Dict, Mapping

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)
@@ -52,3 +54,14 @@ def set_nvidia_flags(local_rank: int) -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def teardown(self) -> None:
super().teardown()

for optimizer in self.optimizers:
for k, v in optimizer.state.items():
if isinstance(v, Mapping):
optimizer.state[k] = {
n: value.cpu() if isinstance(value, torch.Tensor) else value
for n, value in v.items()
}
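The `teardown` hook above moves every tensor in the optimizer state to the CPU so the GPU copies can be reclaimed once the trainer finishes. A minimal standalone sketch of the same pattern (the helper name is illustrative, and since `.cpu()` is a no-op for tensors already on the CPU, the sketch runs without a GPU):

```python
import torch
from collections.abc import Mapping


def move_optimizer_state_to_cpu(optimizer: torch.optim.Optimizer) -> None:
    """Replace every tensor held in the optimizer state with a CPU copy."""
    for param, state in optimizer.state.items():
        if isinstance(state, Mapping):
            optimizer.state[param] = {
                name: value.cpu() if isinstance(value, torch.Tensor) else value
                for name, value in state.items()
            }


# Build a tiny optimizer with real state by taking one step.
layer = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(layer.parameters(), lr=0.1)
layer(torch.randn(8, 4)).sum().backward()
opt.step()

move_optimizer_state_to_cpu(opt)
exp_avg = next(iter(opt.state.values()))["exp_avg"]
print(exp_avg.device.type)  # cpu
```

Rebuilding the inner dict (rather than mutating tensors in place) ensures no stale GPU tensor is still referenced by the state mapping.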
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -30,7 +30,7 @@
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -308,6 +308,9 @@ def _training_step(
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()

# free memory
del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

self._check_training_step_output(training_step_output)
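The `del step_kwargs` line works because CPython frees an object the moment its last strong reference disappears; without it, the batch would stay alive until the function returns. A small sketch of that mechanism using `weakref` (the `Batch` class is a stand-in, not Lightning code):

```python
import weakref


class Batch:
    """Stand-in for a large training batch held inside `step_kwargs`."""


step_kwargs = {"batch": Batch()}
alive = weakref.ref(step_kwargs["batch"])

# While `step_kwargs` holds the batch, it cannot be freed.
assert alive() is not None

# Dropping the last strong reference frees it immediately under CPython's
# reference counting -- the effect `del step_kwargs` has in `_training_step`.
del step_kwargs
print(alive() is None)  # True
```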
Original file line number Diff line number Diff line change
@@ -605,6 +605,9 @@ def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[t
if self.minimize is not None:
self.minimize = self.minimize.to(*args, **kwargs)
self._batch_size = self._batch_size.to(*args, **kwargs)

self['_extra'] = apply_to_collection(self.extra, torch.Tensor, to_, *args, **kwargs)

if 'device' in kwargs:
self.device = kwargs['device']
return self
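`apply_to_collection` applies a function to every element of a possibly nested collection that matches a given type, which is what lets the `to_` call above reach tensors buried inside the `extra` dict. A simplified sketch of the idea — not Lightning's actual implementation, which also handles namedtuples and dataclasses:

```python
import torch


def apply_to_collection(data, dtype, function, *args, **kwargs):
    """Recursively apply `function` to every element of `data` matching `dtype`."""
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, function, *args, **kwargs) for v in data)
    return data  # leave non-matching leaves untouched


extra = {"loss": torch.ones(2), "meta": {"scores": [torch.zeros(3)], "name": "step"}}
moved = apply_to_collection(extra, torch.Tensor, lambda t: t.to(torch.float64))
print(moved["meta"]["scores"][0].dtype)  # torch.float64
```

The same call with `lambda t: t.cpu()` is how nested result tensors get moved off the GPU.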
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1178,10 +1178,13 @@ def _call_teardown_hook(self, model: 'pl.LightningModule') -> None:

if self.datamodule is not None:
self.datamodule.teardown(stage=fn)

self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
model.teardown(stage=fn)

self._active_loop.teardown()

model._current_fx_name = None
model._current_dataloader_idx = None
# these could have become stale if metrics are defined in `setup`
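Calling `self._active_loop.teardown()` in the teardown hook matters because a loop object outlives the `fit` call and keeps references to cached results; dropping those references lets the memory be reclaimed between runs. A generic sketch of that pattern (the `Loop` and `Output` classes here are illustrative, not Lightning's loop code):

```python
import gc
import weakref


class Output:
    """Stand-in for a cached batch output."""


class Loop:
    def __init__(self) -> None:
        self._outputs = [Output() for _ in range(3)]

    def teardown(self) -> None:
        """Drop cached references so their memory can be reclaimed."""
        self._outputs = []


loop = Loop()
probe = weakref.ref(loop._outputs[0])
loop.teardown()
gc.collect()
print(probe() is None)  # True: the cached outputs were freed
```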
52 changes: 50 additions & 2 deletions tests/trainer/test_trainer.py
@@ -17,7 +17,8 @@
import pickle
import sys
from argparse import Namespace
from copy import deepcopy
from copy import deepcopy
from enum import Enum
from pathlib import Path
from unittest.mock import ANY, call, patch

@@ -36,9 +37,10 @@
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
@@ -1969,3 +1971,49 @@ def training_step(self, batch, batch_idx):
# simulate random failure in training_step on rank 0
with pytest.raises(DeadlockDetectedException, match="CustomException"):
trainer.fit(model)


@RunIf(min_gpu=1)
def test_multiple_trainer_constant_memory_allocated(tmpdir):
"""
This test ensures that calling the trainer several times does not increase the memory allocated.
"""

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self._example_input_array = torch.zeros((2, 32))

@property
def example_input_array(self):
return self._example_input_array

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

initial = torch.cuda.memory_allocated(0)

model = TestModel()
trainer_kwargs = dict(default_root_dir=tmpdir, fast_dev_run=True, gpus=1, accelerator="ddp")
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)

assert model._example_input_array.device == torch.device("cpu")
assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
assert trainer.optimizers[0].state

before = torch.cuda.memory_allocated(0)
deepcopy(trainer)
after = torch.cuda.memory_allocated(0)
torch.cuda.empty_cache()
assert before == after

trainer_2 = Trainer(**trainer_kwargs)
trainer_2.fit(model)
after_2 = torch.cuda.memory_allocated(0)

# todo: (tchaton) Still some memory leaks, could not find the source.
assert initial + 2048 == before + 1024 == after_2
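The test above brackets each operation with `torch.cuda.memory_allocated` and asserts the before/after counters match, which is the general shape of a leak check. The same bracketing idea can be sketched on the CPU with the standard library's `tracemalloc` (a hypothetical helper for illustration, not part of this PR):

```python
import tracemalloc


def allocated() -> int:
    """Currently traced Python allocations, analogous to torch.cuda.memory_allocated()."""
    current, _peak = tracemalloc.get_traced_memory()
    return current


tracemalloc.start()
before = allocated()

# An allocation that should not survive once its reference is dropped.
buffer = bytearray(1024 * 1024)
del buffer

after = allocated()
tracemalloc.stop()
leaked = after - before
print(leaked < 64 * 1024)  # True: the megabyte buffer was released, only noise remains
```

As in the GPU test, the assertion is on the difference between the brackets, with a small tolerance for allocator noise.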