Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add teardown method to BaseProfiler. #6370

Merged
merged 10 commits into from
Mar 22, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))

- Added teardown method to BaseProfiler to enable subclasses define post-profiling steps outside of __del__ method.

- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))

Expand Down
Binary file added a.out
Binary file not shown.
14 changes: 12 additions & 2 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

def teardown(self) -> None:
"""Execute arbitrary post-profiling tear-down steps as defined by subclass."""
pass

@contextmanager
def profile(self, action_name: str) -> None:
"""
Expand Down Expand Up @@ -214,11 +218,14 @@ def describe(self):
if self.output_file:
self.output_file.flush()

def __del__(self):
def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()


class AdvancedProfiler(BaseProfiler):
"""
Expand Down Expand Up @@ -286,7 +293,10 @@ def describe(self):
if self.output_file:
self.output_file.flush()

def __del__(self):
def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()
5 changes: 4 additions & 1 deletion pytorch_lightning/profiler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,10 @@ def describe(self):
if self.output_file:
self.output_file.flush()

def __del__(self):
def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def fit(
self.call_hook('on_fit_end')

# teardown
self.profiler.teardown()
self.call_teardown_hook(model)

if self.state != TrainerState.INTERRUPTED:
Expand Down
40 changes: 40 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,46 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
trainer.fit(model)


def test_profiler_teardown(tmpdir):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""
This test checks if profiler teardown method is called when
trainer is exiting.
"""
profilerSimple = SimpleProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
profilerAdvanced = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
profilerPytorch = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0)

carmocca marked this conversation as resolved.
Show resolved Hide resolved
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
profiler=profilerSimple,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)
trainer.fit(model)

assert profilerSimple.output_file.closed
carmocca marked this conversation as resolved.
Show resolved Hide resolved

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
profiler=profilerAdvanced,
)
trainer.fit(model)

assert profilerAdvanced.output_file.closed
carmocca marked this conversation as resolved.
Show resolved Hide resolved

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
profiler=profilerPytorch,
)
trainer.fit(model)

assert profilerPytorch.output_file.closed
carmocca marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
Expand Down