Profiler summary (Lightning-AI#1259)
* refactor and add types

* add Profiler summary

* fix imports

* Revert "refactor and add types"

This reverts commit b4c552f

* changelog

* revert rename

* fix test

* mute verbose
Borda authored and akarnachev committed Apr 3, 2020
1 parent b05a209 commit 68e3b3f
Showing 20 changed files with 113 additions and 59 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))

@@ -74,7 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
- Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019))
- Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912))
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -20,7 +20,7 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
import torch_xla.core.xla_model as xm
2 changes: 1 addition & 1 deletion pytorch_lightning/loggers/comet.py
@@ -28,7 +28,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CometLogger(LightningLoggerBase):
9 changes: 5 additions & 4 deletions pytorch_lightning/profiler/__init__.py
@@ -3,7 +3,7 @@
Built-in checks
----------------
---------------
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
@@ -20,7 +20,7 @@
- on_training_end
Enable simple profiling
-------------------------
-----------------------
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.
@@ -113,10 +113,11 @@ def custom_processing_step(self, data):
"""

from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler

__all__ = [
'Profiler',
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
]
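
For orientation, a minimal sketch of the renamed API as exported above (assuming a package build that includes this commit; the action name is arbitrary):

```python
import time

from pytorch_lightning.profiler import SimpleProfiler

# With no output_filename, the report goes to the package logger (log.info).
profiler = SimpleProfiler()

profiler.start("data_prep")
time.sleep(0.1)  # stand-in for real work being measured
profiler.stop("data_prep")

profiler.describe()  # writes profiler.summary() to every registered stream
```

Note that `Profiler` is dropped from `__all__` in favor of `SimpleProfiler`, and `BaseProfiler` is newly exported for subclassing.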
pytorch_lightning/profiler/{profiler.py → profilers.py}
@@ -1,5 +1,6 @@
import cProfile
import io
import os
import pstats
import time
from abc import ABC, abstractmethod
@@ -16,6 +17,18 @@ class BaseProfiler(ABC):
If you wish to write a custom profiler, you should inherit from this class.
"""

def __init__(self, output_streams: list = None):
"""
Params:
output_streams: callable or list of callables
"""
if output_streams:
if not isinstance(output_streams, (list, tuple)):
output_streams = [output_streams]
else:
output_streams = []
self.write_streams = output_streams

@abstractmethod
def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""
@@ -57,7 +70,12 @@ def profile_iterable(self, iterable, action_name: str) -> None:

def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
pass
for write in self.write_streams:
write(self.summary())

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""


class PassThroughProfiler(BaseProfiler):
@@ -67,25 +85,39 @@ class PassThroughProfiler(BaseProfiler):
"""

def __init__(self):
pass
super().__init__(output_streams=None)

def start(self, action_name: str) -> None:
pass

def stop(self, action_name: str) -> None:
pass

def summary(self) -> str:
return ""


class Profiler(BaseProfiler):
class SimpleProfiler(BaseProfiler):
"""
This profiler simply records the duration of actions (in seconds) and reports
the mean duration of each action and the total time spent over the entire training run.
"""

def __init__(self):
def __init__(self, output_filename: str = None):
"""
Params:
output_filename (str): optionally save profile results to file instead of printing
to std out when training is finished.
"""
self.current_actions = {}
self.recorded_durations = defaultdict(list)

self.output_fname = output_filename
self.output_file = open(self.output_fname, 'w') if self.output_fname else None

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name in self.current_actions:
raise ValueError(
@@ -103,20 +135,31 @@ def stop(self, action_name: str) -> None:
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)

def describe(self) -> None:
def summary(self) -> str:
output_string = "\n\nProfiler Report\n"

def log_row(action, mean, total):
return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"\n{'-' * 65}"
output_string += f"{os.linesep}{'-' * 65}"
for action, durations in self.recorded_durations.items():
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
)
output_string += "\n"
log.info(output_string)
output_string += os.linesep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class AdvancedProfiler(BaseProfiler):
@@ -136,9 +179,14 @@ def __init__(self, output_filename: str = None, line_count_restriction: float =
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
"""
self.profiled_actions = {}
self.output_filename = output_filename
self.line_count_restriction = line_count_restriction

self.output_fname = output_filename
self.output_file = open(self.output_fname, 'w') if self.output_fname else None

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
@@ -152,22 +200,28 @@ def stop(self, action_name: str) -> None:
)
pr.disable()

def describe(self) -> None:
self.recorded_stats = {}
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
self.recorded_stats[action_name] = s.getvalue()
if self.output_filename is not None:
# save to file
with open(self.output_filename, "w") as f:
for action, stats in self.recorded_stats.items():
f.write(f"Profile stats for: {action}")
f.write(stats)
else:
# log to standard out
output_string = "\nProfiler Report\n"
for action, stats in self.recorded_stats.items():
output_string += f"\nProfile stats for: {action}\n{stats}"
log.info(output_string)
recorded_stats[action_name] = s.getvalue()

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
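
Taken together, the new contract is: a custom profiler implements `start`, `stop`, and the newly abstract `summary`, while the base class's `describe` fans the summary out to the registered output streams. A hedged sketch (the `CallCountProfiler` name and its internals are invented for illustration, not part of this commit):

```python
from collections import defaultdict

from pytorch_lightning.profiler import BaseProfiler


class CallCountProfiler(BaseProfiler):
    """Toy profiler that only counts how often each action completes."""

    def __init__(self, output_filename: str = None):
        self._counts = defaultdict(int)
        # Mirror SimpleProfiler: stream to a file if given, else print.
        self._file = open(output_filename, 'w') if output_filename else None
        streams = [self._file.write] if self._file else [print]
        super().__init__(output_streams=streams)

    def start(self, action_name: str) -> None:
        pass  # nothing to record at start in this toy example

    def stop(self, action_name: str) -> None:
        self._counts[action_name] += 1

    def summary(self) -> str:
        lines = [f"{name}: {count} call(s)" for name, count in self._counts.items()]
        return "\n".join(["Call Count Report"] + lines)

    def describe(self) -> None:
        super().describe()  # write summary() to each registered stream
        if self._file:
            self._file.flush()
```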
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -6,7 +6,7 @@
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from apex import amp
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -122,7 +122,7 @@ def train_fx(trial_hparams, cluster_manager, _):
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from apex import amp
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_parts.py
@@ -344,7 +344,7 @@
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from apex import amp
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -135,7 +135,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
import torch_xla.distributed.parallel_loader as xla_pl
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -18,8 +18,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.profiler.profiler import BaseProfiler
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -33,7 +32,7 @@
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
@@ -364,7 +363,7 @@ def __init__(

# configure profiler
if profiler is True:
profiler = Profiler()
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()

# configure early stop callback
@@ -490,10 +489,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
(<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
<class 'NoneType'>),
None),
...
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
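
Given the resolution above, `profiler=True` now constructs a `SimpleProfiler`, and any `BaseProfiler` instance is passed through unchanged. A short usage sketch (`max_epochs` is just an example setting, unrelated to this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler

# profiler=True is shorthand for the built-in SimpleProfiler.
trainer = Trainer(profiler=True, max_epochs=1)

# Alternatively, pass a configured profiler; the report is then written
# to the given file instead of the logger.
advanced = AdvancedProfiler(output_filename="profile_report.txt")
trainer = Trainer(profiler=advanced, max_epochs=1)
```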
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/loggers/test_comet.py
@@ -8,7 +8,7 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel


2 changes: 1 addition & 1 deletion tests/models/test_amp.py
@@ -5,7 +5,7 @@

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
LightningTestModel,
)
2 changes: 1 addition & 1 deletion tests/models/test_gpu.py
@@ -11,7 +11,7 @@
parse_gpu_ids,
determine_root_gpu_device,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel

PRETEND_N_OF_GPUS = 16
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -8,7 +8,7 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
LightningTestModel,
LightningTestModelWithoutHyperparametersArg,
2 changes: 2 additions & 0 deletions tests/test_deprecated.py
@@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402
from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402

from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler # noqa: F402


class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
