refactor and add types
Borda committed Mar 27, 2020
1 parent 61177cd commit b4c552f
Showing 7 changed files with 151 additions and 131 deletions.
123 changes: 123 additions & 0 deletions pytorch_lightning/debugging/__init__.py
@@ -0,0 +1,123 @@
"""
Profiling your training run can help you understand if there are any bottlenecks in your code.
Built-in checks
----------------
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
- on_epoch_start
- on_epoch_end
- on_batch_start
- tbptt_split_batch
- model_forward
- model_backward
- on_after_backward
- optimizer_step
- on_batch_end
- training_step_end
- on_training_end

Enable simple profiling
-----------------------

If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.

.. code-block:: python

    trainer = Trainer(..., profiler=True)

The profiler's results will be printed at the completion of a training `fit()`.

.. code-block:: python

    Profiler Report

    Action              |  Mean duration (s)  |  Total time (s)
    -----------------------------------------------------------
    on_epoch_start      |  5.993e-06          |  5.993e-06
    get_train_batch     |  0.0087412          |  16.398
    on_batch_start      |  5.0865e-06         |  0.0095372
    model_forward       |  0.0017818          |  3.3408
    model_backward      |  0.0018283          |  3.4282
    on_after_backward   |  4.2862e-06         |  0.0080366
    optimizer_step      |  0.0011072          |  2.0759
    on_batch_end        |  4.5202e-06         |  0.0084753
    on_epoch_end        |  3.919e-06          |  3.919e-06
    on_train_end        |  5.449e-06          |  5.449e-06

Advanced Profiling
------------------

If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfile_ to provide a report of time spent on *each* function called within your code.

.. _cProfile: https://docs.python.org/3/library/profile.html#module-cProfile

.. code-block:: python

    profiler = AdvancedProfiler()
    trainer = Trainer(..., profiler=profiler)

The profiler's results will be printed at the completion of a training `fit()`. This profiler
report can be quite long, so you can also specify an `output_filename` to save the report instead
of logging it to the output in your terminal. The output below shows the profiling for the action
`get_train_batch`.

.. code-block:: python

    Profiler Report

    Profile stats for: get_train_batch
            4869394 function calls (4863767 primitive calls) in 18.893 seconds
    Ordered by: cumulative time
    List reduced from 76 to 10 due to restriction <10>

    ncalls     tottime  percall  cumtime  percall  filename:lineno(function)
    3752/1876    0.011    0.000   18.887    0.010  {built-in method builtins.next}
    1876         0.008    0.000   18.877    0.010  dataloader.py:344(__next__)
    1876         0.074    0.000   18.869    0.010  dataloader.py:383(_next_data)
    1875         0.012    0.000   18.721    0.010  fetch.py:42(fetch)
    1875         0.084    0.000   18.290    0.010  fetch.py:44(<listcomp>)
    60000        1.759    0.000   18.206    0.000  mnist.py:80(__getitem__)
    60000        0.267    0.000   13.022    0.000  transforms.py:68(__call__)
    60000        0.182    0.000    7.020    0.000  transforms.py:93(__call__)
    60000        1.651    0.000    6.839    0.000  functional.py:42(to_tensor)
    60000        0.260    0.000    5.734    0.000  transforms.py:167(__call__)
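
To write this report to disk instead of printing it, pass a file path; a minimal sketch
(the filename is illustrative):

.. code-block:: python

    profiler = AdvancedProfiler(output_filename='profiler_report.txt')
    trainer = Trainer(..., profiler=profiler)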

You can also reference this profiler in your LightningModule to profile specific actions of interest.
If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler`,
which lets you skip profiling without making any code changes. Each profiler has a
method `profile()` which returns a context manager. Simply pass in the name of the action you want
to track, and the profiler will record performance for code executed within this context.

.. code-block:: python

    from pytorch_lightning.profiler import Profiler, PassThroughProfiler

    class MyModel(LightningModule):
        def __init__(self, hparams, profiler=None):
            super().__init__()
            self.hparams = hparams
            self.profiler = profiler or PassThroughProfiler()

        def custom_processing_step(self, data):
            with self.profiler.profile('my_custom_action'):
                # custom processing step
                return data

    profiler = Profiler()
    model = MyModel(hparams, profiler)
    trainer = Trainer(profiler=profiler, max_epochs=1)
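
If the model is constructed without a profiler, the `PassThroughProfiler` default makes
`profile()` a no-op, so the same code runs unprofiled; a minimal sketch:

.. code-block:: python

    # no profiler given: custom_processing_step executes without any recording
    model = MyModel(hparams)
    trainer = Trainer(max_epochs=1)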
"""

from pytorch_lightning.debugging.profilers import BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler

__all__ = [
'BaseProfiler',
'Profiler',
'AdvancedProfiler',
'PassThroughProfiler',
]
File renamed without changes.
23 changes: 14 additions & 9 deletions pytorch_lightning/debugging/profilers.py
@@ -17,15 +17,15 @@ class BaseProfiler(ABC):
     """

     @abstractmethod
-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""

     @abstractmethod
-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""

     @contextmanager
-    def profile(self, action_name):
+    def profile(self, action_name: str) -> None:
         """
         Yields a context manager to encapsulate the scope of a profiled action.
@@ -82,18 +82,23 @@ class Profiler(BaseProfiler):
     the mean duration of each action and the total time spent over the entire training run.
     """

-    def __init__(self):
+    def __init__(self, output_filename: str = None):
+        """
+        :param output_filename (str): optionally save profile results to file instead of printing
+            to std out when training is finished.
+        """
         self.current_actions = {}
+        self.output_filename = output_filename
         self.recorded_durations = defaultdict(list)

-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         if action_name in self.current_actions:
             raise ValueError(
                 f"Attempted to start {action_name} which has already started."
             )
         self.current_actions[action_name] = time.monotonic()

-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         end_time = time.monotonic()
         if action_name not in self.current_actions:
             raise ValueError(
@@ -103,7 +108,7 @@ def stop(self, action_name):
         duration = end_time - start_time
         self.recorded_durations[action_name].append(duration)

-    def describe(self):
+    def describe(self) -> None:
         output_string = "\n\nProfiler Report\n"

         def log_row(action, mean, total):
@@ -138,12 +143,12 @@ def __init__(self, output_filename=None, line_count_restriction=1.0):
         self.output_filename = output_filename
         self.line_count_restriction = line_count_restriction

-    def start(self, action_name):
+    def start(self, action_name: str) -> None:
         if action_name not in self.profiled_actions:
             self.profiled_actions[action_name] = cProfile.Profile()
         self.profiled_actions[action_name].enable()

-    def stop(self, action_name):
+    def stop(self, action_name: str) -> None:
         pr = self.profiled_actions.get(action_name)
         if pr is None:
             raise ValueError(  # pragma: no-cover
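For reference, the newly typed `start()`/`stop()` pair is what `BaseProfiler.profile()` drives
under the hood; a minimal usage sketch against the renamed module (the `load_batch` helper is
hypothetical):

    from pytorch_lightning.debugging import Profiler

    profiler = Profiler()
    with profiler.profile('load_batch'):  # entering the block calls start('load_batch')
        batch = load_batch()              # hypothetical helper; any code here is timed
    # leaving the block calls stop('load_batch') and records the duration
    profiler.describe()                   # prints the aggregated report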
127 changes: 9 additions & 118 deletions pytorch_lightning/profiler/__init__.py
@@ -1,122 +1,13 @@
"""
.. warning:: `profiler` package has been renamed to `debugging` since v0.7.2.
    The deprecated module name will be removed in v0.9.0.
"""

-from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
+import warnings

+warnings.warn("`profiler` package has been renamed to `debugging` since v0.7.2."
+              " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

-__all__ = [
-    'Profiler',
-    'AdvancedProfiler',
-    'PassThroughProfiler',
-]
+from pytorch_lightning.debugging import (  # noqa: E402
+    BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler
+)
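The re-export above keeps the old path importable while the `warnings.warn` call fires at
import time; a minimal sketch of both paths:

    # deprecated path: still works until v0.9.0, triggers the DeprecationWarning above
    from pytorch_lightning.profiler import Profiler, AdvancedProfiler

    # preferred path after this commit
    from pytorch_lightning.debugging import Profiler, AdvancedProfiler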
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -17,9 +17,8 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.debugging import BaseProfiler, Profiler, PassThroughProfiler
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.profiler import Profiler, PassThroughProfiler
-from pytorch_lightning.profiler.profiler import BaseProfiler
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
2 changes: 2 additions & 0 deletions tests/test_deprecated.py
@@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
     from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
     from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402

+    from pytorch_lightning.profiler import Profiler, AdvancedProfiler  # noqa: F402
+

 class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
4 changes: 2 additions & 2 deletions tests/test_profiler.py
@@ -1,10 +1,10 @@
 import tempfile
 import time
 from pathlib import Path

 import numpy as np
 import pytest
-from pytorch_lightning.profiler import AdvancedProfiler, Profiler

+from pytorch_lightning.debugging import AdvancedProfiler, Profiler

 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001

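`PROFILER_OVERHEAD_MAX_TOLERANCE` bounds the profiler's per-call bookkeeping cost; a sketch
of the kind of overhead check the test module can build on it (an assumed test shape, since
the test bodies are not shown in this diff):

    import numpy as np

    from pytorch_lightning.debugging import Profiler

    PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001  # value from the module above

    def test_profiler_overhead_is_negligible():
        profiler = Profiler()
        for _ in range(100):
            with profiler.profile("no-op"):
                pass  # empty body: any recorded time is pure profiler overhead
        durations = np.array(profiler.recorded_durations["no-op"])
        assert durations.mean() < PROFILER_OVERHEAD_MAX_TOLERANCE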
