
Add teardown method to BaseProfiler. #6370

Merged: 10 commits, Mar 22, 2021
CHANGELOG.md (18 changes: 12 additions & 6 deletions)
@@ -14,8 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
+
+- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
@@ -37,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
+- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))
+
+
 - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
 
 
@@ -120,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

+- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
 
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
 
@@ -147,6 +153,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
+
+
+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
+
+
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
@@ -170,12 +182,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
 
 
-- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
-
-
-- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
-
-
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
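
For context, a minimal sketch of the kind of subclass the new hook enables. The class below is hypothetical and not part of this PR; it assumes the 1.2-era `BaseProfiler` interface (no required constructor arguments; `start`, `stop`, and `summary` to implement):

```python
import time

from pytorch_lightning.profiler import BaseProfiler


class DurationsProfiler(BaseProfiler):
    """Hypothetical profiler that owns a file and releases it in teardown()."""

    def __init__(self, path: str):
        super().__init__()
        self._file = open(path, "w")  # resource that must be released explicitly
        self._starts = {}

    def start(self, action_name: str) -> None:
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        elapsed = time.monotonic() - self._starts.pop(action_name)
        self._file.write(f"{action_name}: {elapsed:.6f}s\n")

    def summary(self) -> str:
        return f"durations written to {self._file.name}"

    def teardown(self) -> None:
        # Cleanup lives in an explicit hook the Trainer calls on exit,
        # instead of depending on when __del__ happens to run.
        if not self._file.closed:
            self._file.close()
```
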
pytorch_lightning/profiler/profilers.py (14 changes: 12 additions & 2 deletions)
@@ -55,6 +55,10 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""
 
+    def teardown(self) -> None:
+        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
+        pass
+
     @contextmanager
     def profile(self, action_name: str) -> None:
         """
@@ -214,11 +218,14 @@ def describe(self):
         if self.output_file:
             self.output_file.flush()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
 
+    def __del__(self):
+        self.teardown()
+
 
 class AdvancedProfiler(BaseProfiler):
     """
@@ -286,7 +293,10 @@ def describe(self):
         if self.output_file:
             self.output_file.flush()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
+
+    def __del__(self):
+        self.teardown()
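
Note the pattern in both profilers: `__del__` now just delegates to `teardown`, so garbage collection remains a safety net while callers get a deterministic hook. A sketch of driving the lifecycle by hand, outside any `Trainer` (the file name is illustrative):

```python
from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler(output_filename="profile_report.txt")  # illustrative path

with profiler.profile("data_loading"):
    pass  # code under measurement goes here

profiler.describe()  # flush the report to the output stream
profiler.teardown()  # close the stream explicitly instead of waiting for __del__
```
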
pytorch_lightning/profiler/pytorch.py (5 changes: 4 additions & 1 deletion)
@@ -297,7 +297,10 @@ def describe(self):
         if self.output_file:
             self.output_file.flush()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
+
+    def __del__(self):
+        self.teardown()
pytorch_lightning/trainer/trainer.py (1 change: 1 addition & 0 deletions)
@@ -1077,6 +1077,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
         else:
             state = None
 
+        self.profiler.teardown()
         self.teardown(stage=state)
         model.teardown(stage=state)
 
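
With that single call in place, the profiler's stream is closed deterministically whenever the trainer's teardown hooks run at the end of a stage, rather than whenever the interpreter happens to garbage-collect the profiler.
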
tests/test_profiler.py (18 changes: 18 additions & 0 deletions)
@@ -316,3 +316,21 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
         gpus=1,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
+def test_profiler_teardown(tmpdir, cls):
+    """
+    This test checks if profiler teardown method is called when trainer is exiting.
+    """
+    profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        profiler=profiler,
+    )
+    trainer.fit(model)
+
+    assert profiler.output_file.closed
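
For illustration only (not part of the PR's test suite), the same property can be checked without a `Trainer`, since `teardown` by itself is expected to close the stream. The helper below is hypothetical:

```python
import os

from pytorch_lightning.profiler import SimpleProfiler


def check_teardown_closes_stream(tmpdir):
    # Hypothetical trainer-free variant of the test above: teardown()
    # alone should close the profiler's output stream.
    profiler = SimpleProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
    assert not profiler.output_file.closed
    profiler.teardown()
    assert profiler.output_file.closed
```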