Add teardown method to BaseProfiler. (#6370)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: ananthsub <[email protected]>

3 people authored Mar 22, 2021
1 parent 58c9fa7 commit e2e1de0

Showing 7 changed files with 54 additions and 17 deletions.
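
In short: `BaseProfiler` gains a `teardown()` hook so profilers can release resources (such as closing an output stream) at a well-defined point when the `Trainer` exits, rather than relying on `__del__` and garbage-collection timing. A minimal sketch of a subclass using the new hook — the `TraceWriterProfiler` class and its file handling below are hypothetical, not part of this commit:

    import time
    from pytorch_lightning.profiler import BaseProfiler


    class TraceWriterProfiler(BaseProfiler):
        """Hypothetical profiler that appends timestamped events to a file."""

        def __init__(self, path: str):
            self.trace_file = open(path, "w")
            super().__init__()

        def start(self, action_name: str) -> None:
            self.trace_file.write(f"{time.monotonic():.6f} start {action_name}\n")

        def stop(self, action_name: str) -> None:
            self.trace_file.write(f"{time.monotonic():.6f} stop {action_name}\n")

        def summary(self) -> str:
            return f"events written to {self.trace_file.name}"

        def teardown(self) -> None:
            # New extension point: runs when the Trainer exits,
            # instead of waiting for __del__ / garbage collection.
            if not self.trace_file.closed:
                self.trace_file.close()
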
1 change: 1 addition & 0 deletions .gitignore
@@ -157,3 +157,4 @@ tags
 data
 MNIST
 runs
+*traces*
18 changes: 12 additions & 6 deletions CHANGELOG.md
@@ -14,8 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
+
 - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
+
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
@@ -37,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
+- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))
+
+
 - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
 
 
@@ -120,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
 
+
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
 
@@ -147,6 +153,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
+
+
+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
+
+
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
@@ -170,12 +182,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
 
 
-- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
-
-
-- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
-
-
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
20 changes: 14 additions & 6 deletions pytorch_lightning/profiler/profilers.py
@@ -55,6 +55,10 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""
 
+    def teardown(self) -> None:
+        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
+        pass
+
     @contextmanager
     def profile(self, action_name: str) -> None:
         """
@@ -211,14 +215,16 @@ def log_row(action, mean, total):
     def describe(self):
         """Logs a profile report after the conclusion of the training run."""
         super().describe()
-        if self.output_file:
-            self.output_file.flush()
+        self.teardown()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
 
+    def __del__(self):
+        self.teardown()
+
 
 class AdvancedProfiler(BaseProfiler):
     """
@@ -283,10 +289,12 @@ def summary(self) -> str:
     def describe(self):
         """Logs a profile report after the conclusion of the training run."""
         super().describe()
-        if self.output_file:
-            self.output_file.flush()
+        self.teardown()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
+
+    def __del__(self):
+        self.teardown()
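
The pattern above, applied to both `SimpleProfiler` and `AdvancedProfiler`, moves stream closing out of `__del__`: `describe()` now ends by calling `teardown()`, which closes the output stream, and `__del__` merely delegates to `teardown()` as a safety net. A rough usage sketch of the resulting lifecycle — the output path here is illustrative:

    import os
    from pytorch_lightning.profiler import SimpleProfiler

    profiler = SimpleProfiler(output_filename=os.path.join("/tmp", "profiler.txt"))
    with profiler.profile("my_action"):
        pass  # work being timed

    profiler.describe()  # writes the report, then calls teardown()
    assert profiler.output_file.closed  # stream is closed deterministically
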
8 changes: 5 additions & 3 deletions pytorch_lightning/profiler/pytorch.py
@@ -294,10 +294,12 @@ def summary(self) -> str:
     def describe(self):
         """Logs a profile report after the conclusion of the training run."""
         super().describe()
-        if self.output_file:
-            self.output_file.flush()
+        self.teardown()
 
-    def __del__(self):
+    def teardown(self) -> None:
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
+
+    def __del__(self):
+        self.teardown()
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1077,6 +1077,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
         else:
             state = None
 
+        self.profiler.teardown()
         self.teardown(stage=state)
         model.teardown(stage=state)
 
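
Condensed, the shutdown order in `Trainer.call_teardown_hook` now runs the profiler cleanup first (shape inferred from the hunk above; surrounding logic elided):

    def call_teardown_hook(self, model: LightningModule) -> None:
        ...  # resolve `state` from the current running stage (elided)
        self.profiler.teardown()     # 1. close profiler streams
        self.teardown(stage=state)   # 2. Trainer-level teardown hook
        model.teardown(stage=state)  # 3. LightningModule teardown hook
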
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -140,6 +140,7 @@ def on_train_end(self):
         self.trainer.logger.finalize("success")
 
         # summarize profile results
+        # todo (tchaton) All ranks should call describe.
         if self.trainer.global_rank == 0:
             self.trainer.profiler.describe()
 
22 changes: 20 additions & 2 deletions tests/test_profiler.py
@@ -252,8 +252,8 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename):
     assert profiler.summary() is None
     assert set(profiler.profiled_actions.keys()) == set()
 
-    if use_output_filename:
-        profiler.describe()
+    # todo (tchaton) add support for all ranks
+    if use_output_filename and os.getenv("LOCAL_RANK") == "0":
         data = Path(profiler.output_fname).read_text()
         assert len(data) > 0
 
@@ -316,3 +316,21 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
         gpus=1,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
+def test_profiler_teardown(tmpdir, cls):
+    """
+    This test checks if profiler teardown method is called when trainer is exiting.
+    """
+    profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        profiler=profiler,
+    )
+    trainer.fit(model)
+
+    assert profiler.output_file.closed