From 52ec2b1b0e28df643f58042ac57297cab0e1fc8f Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Fri, 27 Mar 2020 11:39:45 +0100 Subject: [PATCH 1/8] refactor and add types --- pytorch_lightning/debugging/__init__.py | 123 +++++++++++++++++ .../debugging.py => debugging/exceptions.py} | 0 .../profiler.py => debugging/profilers.py} | 7 +- pytorch_lightning/profiler/__init__.py | 127 ++---------------- pytorch_lightning/trainer/trainer.py | 3 +- tests/test_deprecated.py | 2 + tests/test_profiler.py | 4 +- 7 files changed, 143 insertions(+), 123 deletions(-) create mode 100644 pytorch_lightning/debugging/__init__.py rename pytorch_lightning/{utilities/debugging.py => debugging/exceptions.py} (100%) rename pytorch_lightning/{profiler/profiler.py => debugging/profilers.py} (95%) diff --git a/pytorch_lightning/debugging/__init__.py b/pytorch_lightning/debugging/__init__.py new file mode 100644 index 0000000000000..24779f0716edc --- /dev/null +++ b/pytorch_lightning/debugging/__init__.py @@ -0,0 +1,123 @@ +""" +Profiling your training run can help you understand if there are any bottlenecks in your code. + + +Built-in checks +---------------- + +PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: + +- on_epoch_start +- on_epoch_end +- on_batch_start +- tbptt_split_batch +- model_forward +- model_backward +- on_after_backward +- optimizer_step +- on_batch_end +- training_step_end +- on_training_end + +Enable simple profiling +------------------------- + +If you only wish to profile the standard actions, you can set `profiler=True` when constructing +your `Trainer` object. + +.. code-block:: python + + trainer = Trainer(..., profiler=True) + +The profiler's results will be printed at the completion of a training `fit()`. + +.. code-block:: python + + Profiler Report + + Action | Mean duration (s) | Total time (s) + ----------------------------------------------------------------- + on_epoch_start | 5.993e-06 | 5.993e-06 + get_train_batch | 0.0087412 | 16.398 + on_batch_start | 5.0865e-06 | 0.0095372 + model_forward | 0.0017818 | 3.3408 + model_backward | 0.0018283 | 3.4282 + on_after_backward | 4.2862e-06 | 0.0080366 + optimizer_step | 0.0011072 | 2.0759 + on_batch_end | 4.5202e-06 | 0.0084753 + on_epoch_end | 3.919e-06 | 3.919e-06 + on_train_end | 5.449e-06 | 5.449e-06 + + +Advanced Profiling +-------------------- + +If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. +This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. + +.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile + +.. code-block:: python + + profiler = AdvancedProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`get_train_batch`. + +.. 
code-block:: python + + Profiler Report + + Profile stats for: get_train_batch + 4869394 function calls (4863767 primitive calls) in 18.893 seconds + Ordered by: cumulative time + List reduced from 76 to 10 due to restriction <10> + ncalls tottime percall cumtime percall filename:lineno(function) + 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} + 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) + 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) + 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) + 1875 0.084 0.000 18.290 0.010 fetch.py:44() + 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) + 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) + 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) + 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) + 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + +You can also reference this profiler in your LightningModule to profile specific actions of interest. +If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` +which will allow you to skip profiling without having to make any code changes. Each profiler has a +method `profile()` which returns a context handler. Simply pass in the name of your action that you want +to track and the profiler will record performance for code executed within this context. + +.. code-block:: python + + from pytorch_lightning.profiler import Profiler, PassThroughProfiler + + class MyModel(LightningModule): + def __init__(self, hparams, profiler=None): + self.hparams = hparams + self.profiler = profiler or PassThroughProfiler() + + def custom_processing_step(self, data): + with profiler.profile('my_custom_action'): + # custom processing step + return data + + profiler = Profiler() + model = MyModel(hparams, profiler) + trainer = Trainer(profiler=profiler, max_epochs=1) + +""" + +from pytorch_lightning.debugging.profilers import BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler + +__all__ = [ + 'BaseProfiler', + 'Profiler', + 'AdvancedProfiler', + 'PassThroughProfiler', +] diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/debugging/exceptions.py similarity index 100% rename from pytorch_lightning/utilities/debugging.py rename to pytorch_lightning/debugging/exceptions.py diff --git a/pytorch_lightning/profiler/profiler.py b/pytorch_lightning/debugging/profilers.py similarity index 95% rename from pytorch_lightning/profiler/profiler.py rename to pytorch_lightning/debugging/profilers.py index e2a6e5b200d56..bb05a8f394a43 100644 --- a/pytorch_lightning/profiler/profiler.py +++ b/pytorch_lightning/debugging/profilers.py @@ -82,8 +82,13 @@ class Profiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self): + def __init__(self, output_filename: str = None): + """ + :param output_filename (str): optionally save profile results to file instead of printing + to std out when training is finished. + """ self.current_actions = {} + self.output_filename = output_filename self.recorded_durations = defaultdict(list) def start(self, action_name: str) -> None: diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 20eece9ad66cb..69ed003f068bf 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -1,122 +1,13 @@ """ -Profiling your training run can help you understand if there are any bottlenecks in your code. 
- - -Built-in checks ----------------- - -PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: - -- on_epoch_start -- on_epoch_end -- on_batch_start -- tbptt_split_batch -- model_forward -- model_backward -- on_after_backward -- optimizer_step -- on_batch_end -- training_step_end -- on_training_end - -Enable simple profiling -------------------------- - -If you only wish to profile the standard actions, you can set `profiler=True` when constructing -your `Trainer` object. - -.. code-block:: python - - trainer = Trainer(..., profiler=True) - -The profiler's results will be printed at the completion of a training `fit()`. - -.. code-block:: python - - Profiler Report - - Action | Mean duration (s) | Total time (s) - ----------------------------------------------------------------- - on_epoch_start | 5.993e-06 | 5.993e-06 - get_train_batch | 0.0087412 | 16.398 - on_batch_start | 5.0865e-06 | 0.0095372 - model_forward | 0.0017818 | 3.3408 - model_backward | 0.0018283 | 3.4282 - on_after_backward | 4.2862e-06 | 0.0080366 - optimizer_step | 0.0011072 | 2.0759 - on_batch_end | 4.5202e-06 | 0.0084753 - on_epoch_end | 3.919e-06 | 3.919e-06 - on_train_end | 5.449e-06 | 5.449e-06 - - -Advanced Profiling --------------------- - -If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. -This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. - -.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile - -.. code-block:: python - - profiler = AdvancedProfiler() - trainer = Trainer(..., profiler=profiler) - -The profiler's results will be printed at the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. The output below shows the profiling for the action -`get_train_batch`. - -.. code-block:: python - - Profiler Report - - Profile stats for: get_train_batch - 4869394 function calls (4863767 primitive calls) in 18.893 seconds - Ordered by: cumulative time - List reduced from 76 to 10 due to restriction <10> - ncalls tottime percall cumtime percall filename:lineno(function) - 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} - 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) - 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) - 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) - 1875 0.084 0.000 18.290 0.010 fetch.py:44() - 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) - 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) - 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) - 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) - 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) - -You can also reference this profiler in your LightningModule to profile specific actions of interest. -If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` -which will allow you to skip profiling without having to make any code changes. Each profiler has a -method `profile()` which returns a context handler. Simply pass in the name of your action that you want -to track and the profiler will record performance for code executed within this context. - -.. 
code-block:: python - - from pytorch_lightning.profiler import Profiler, PassThroughProfiler - - class MyModel(LightningModule): - def __init__(self, hparams, profiler=None): - self.hparams = hparams - self.profiler = profiler or PassThroughProfiler() - - def custom_processing_step(self, data): - with profiler.profile('my_custom_action'): - # custom processing step - return data - - profiler = Profiler() - model = MyModel(hparams, profiler) - trainer = Trainer(profiler=profiler, max_epochs=1) - +.. warning:: `profiler` package has been renamed to `debugging` since v0.6.0. + The deprecated module name will be removed in v0.8.0. """ -from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler +import warnings + +warnings.warn("`profiler` package has been renamed to `debugging` since v0.7.2." + " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) -__all__ = [ - 'Profiler', - 'AdvancedProfiler', - 'PassThroughProfiler', -] +from pytorch_lightning.debugging import ( # noqa: E402 + BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler +) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6b4b461af5dae..84a27c2bcf88b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -17,9 +17,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.debugging import BaseProfiler, Profiler, PassThroughProfiler from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiler import Profiler, PassThroughProfiler -from pytorch_lightning.profiler.profiler import BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index a3b087c718166..5002d23427d53 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports(): from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 + from pytorch_lightning.profiler import Profiler, AdvancedProfiler # noqa: F402 + class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index e60476bc59a2e..7c7fb3d8f5b40 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,10 +1,10 @@ -import tempfile import time from pathlib import Path import numpy as np import pytest -from pytorch_lightning.profiler import AdvancedProfiler, Profiler + +from pytorch_lightning.debugging import AdvancedProfiler, Profiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 From 960f43d46d7b833ab4767c1fd2eb262b6273f8c3 Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Fri, 27 Mar 2020 12:32:37 +0100 Subject: [PATCH 2/8] add Profiler summary --- pytorch_lightning/debugging/profilers.py | 91 +++++++++++++++++------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/debugging/profilers.py b/pytorch_lightning/debugging/profilers.py index bb05a8f394a43..8607561693c5e 100644 --- a/pytorch_lightning/debugging/profilers.py +++ b/pytorch_lightning/debugging/profilers.py @@ -1,5 +1,6 @@ import cProfile import io +import os import pstats import time from abc import ABC, abstractmethod @@ -16,6 +17,18 @@ class BaseProfiler(ABC): If you wish to write a custom profiler, you should inherit from this class. """ + def __init__(self, output_streams: list = None): + """ + Params: + output_streams: callable or list of callables + """ + if output_streams: + if not isinstance(output_streams, (list, tuple)): + output_streams = [output_streams] + else: + output_streams = [] + self.write_streams = output_streams + @abstractmethod def start(self, action_name: str) -> None: """Defines how to start recording an action.""" @@ -57,7 +70,12 @@ def profile_iterable(self, iterable, action_name: str) -> None: def describe(self) -> None: """Logs a profile report after the conclusion of the training run.""" - pass + for write in self.write_streams: + write(self.summary()) + + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" class PassThroughProfiler(BaseProfiler): @@ -67,7 +85,7 @@ class PassThroughProfiler(BaseProfiler): """ def __init__(self): - pass + super().__init__(output_streams=None) def start(self, action_name: str) -> None: pass @@ -75,6 +93,9 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pass + def summary(self) -> str: + return "" + class Profiler(BaseProfiler): """ @@ -84,13 +105,21 @@ class Profiler(BaseProfiler): def __init__(self, output_filename: str = None): """ - :param output_filename (str): optionally save profile results to file instead of printing - to std out when training is finished. + Params: + output_filename (str): optionally save profile results to file instead of printing + to std out when training is finished. 
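+ + Example (an illustrative sketch of using ``output_filename``; the report file name, model and Trainer shown here are placeholders, not part of this change):: + + # write the summary report to a file instead of logging it to stdout + profiler = Profiler(output_filename="profiler_report.txt") + trainer = Trainer(profiler=profiler) + trainer.fit(model)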
""" self.current_actions = {} - self.output_filename = output_filename self.recorded_durations = defaultdict(list) + self.output_fname = output_filename + self.output_file = open(self.output_fname, 'w') if self.output_fname else None + + streaming_out = [log.info] + if self.output_file: + streaming_out.append(self.output_file.write) + super().__init__(output_streams=streaming_out) + def start(self, action_name: str) -> None: if action_name in self.current_actions: raise ValueError( @@ -108,20 +137,26 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def describe(self) -> None: + def summary(self) -> str: output_string = "\n\nProfiler Report\n" def log_row(action, mean, total): - return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"\n{'-' * 65}" + output_string += f"{os.linesep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row( action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}", ) - output_string += "\n" - log.info(output_string) + output_string += os.linesep + return output_string + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.flush() + self.output_file.close() class AdvancedProfiler(BaseProfiler): @@ -141,9 +176,16 @@ def __init__(self, output_filename: str = None, line_count_restriction: float = or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) """ self.profiled_actions = {} - self.output_filename = output_filename self.line_count_restriction = line_count_restriction + self.output_fname = output_filename + self.output_file = open(self.output_fname, 'w') if self.output_fname else None + + streaming_out = [log.info] + if self.output_file: + streaming_out.append(self.output_file.write) + super().__init__(output_streams=streaming_out) + def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -157,22 +199,23 @@ def stop(self, action_name: str) -> None: ) pr.disable() - def describe(self) -> None: + def summary(self) -> str: self.recorded_stats = {} for action_name, pr in self.profiled_actions.items(): s = io.StringIO() ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) self.recorded_stats[action_name] = s.getvalue() - if self.output_filename is not None: - # save to file - with open(self.output_filename, "w") as f: - for action, stats in self.recorded_stats.items(): - f.write(f"Profile stats for: {action}") - f.write(stats) - else: - # log to standard out - output_string = "\nProfiler Report\n" - for action, stats in self.recorded_stats.items(): - output_string += f"\nProfile stats for: {action}\n{stats}" - log.info(output_string) + + # log to standard out + output_string = f"{os.linesep}Profiler Report{os.linesep}" + for action, stats in self.recorded_stats.items(): + output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" + + return output_string + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.flush() + self.output_file.close() From 086ea21656cf47040a620f907d2a04f1f782b04e Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Fri, 27 Mar 2020 12:41:14 +0100 Subject: [PATCH 3/8] fix imports --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/debugging/__init__.py | 2 ++ pytorch_lightning/debugging/profilers.py | 6 +++--- pytorch_lightning/loggers/comet.py | 2 +- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/distrib_data_parallel.py | 2 +- pytorch_lightning/trainer/distrib_parts.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/loggers/test_comet.py | 2 +- tests/models/test_amp.py | 2 +- tests/models/test_gpu.py | 2 +- tests/models/test_restore.py | 2 +- tests/trainer/test_dataloaders.py | 2 +- tests/trainer/test_trainer.py | 2 +- 15 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 92dc78782b930..f2c9258b2e9d1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException try: import torch_xla.core.xla_model as xm diff --git a/pytorch_lightning/debugging/__init__.py b/pytorch_lightning/debugging/__init__.py index 24779f0716edc..54ae3ed60bd9c 100644 --- a/pytorch_lightning/debugging/__init__.py +++ b/pytorch_lightning/debugging/__init__.py @@ -113,9 +113,11 @@ def custom_processing_step(self, data): """ +from pytorch_lightning.debugging.exceptions import MisconfigurationException from pytorch_lightning.debugging.profilers import BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler __all__ = [ + 'MisconfigurationException', 'BaseProfiler', 'Profiler', 'AdvancedProfiler', diff --git a/pytorch_lightning/debugging/profilers.py b/pytorch_lightning/debugging/profilers.py index 8607561693c5e..2ad3e9eaa1347 100644 --- a/pytorch_lightning/debugging/profilers.py +++ b/pytorch_lightning/debugging/profilers.py @@ -200,16 +200,16 @@ def stop(self, action_name: str) -> None: pr.disable() def summary(self) -> str: - self.recorded_stats = {} + recorded_stats = {} for action_name, pr in self.profiled_actions.items(): s = io.StringIO() ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) - self.recorded_stats[action_name] = s.getvalue() + recorded_stats[action_name] = s.getvalue() # log to standard out output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in self.recorded_stats.items(): + for action, stats in recorded_stats.items(): output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" return output_string diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 0109f9dbdd7a3..e35136b2f2508 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,7 +28,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException class CometLogger(LightningLoggerBase): diff --git a/pytorch_lightning/trainer/data_loading.py 
b/pytorch_lightning/trainer/data_loading.py index 039c7ee588b45..59614707719b0 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -6,7 +6,7 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 182b5404b3b6e..90ee5b80eb27a 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -122,7 +122,7 @@ def train_fx(trial_hparams, cluster_manager, _): import torch from pytorch_lightning import _logger as log from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 4718407a5f72f..befd838906a87 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -344,7 +344,7 @@ LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d720f81e6ebbe..2b820203b5d0d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -135,7 +135,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException try: import torch_xla.distributed.parallel_loader as xla_pl diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f01e2294c4884..1742f9dc3cd87 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningMean try: diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 1aaf4cb7fd62f..a28e75b067636 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -8,7 +8,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import LightningTestModel diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index eea3bf98653c1..dd1449a4b9f38 100644 --- 
a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -5,7 +5,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import ( LightningTestModel, ) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 0ee77351b1665..7ac8941b54de8 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -11,7 +11,7 @@ parse_gpu_ids, determine_root_gpu_device, ) -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import LightningTestModel PRETEND_N_OF_GPUS = 16 diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 07434274c28de..a4a018f371253 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -8,7 +8,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import ( LightningTestModel, LightningTestModelWithoutHyperparametersArg, diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 6d2332cdf538b..f3cd9af62846c 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -2,7 +2,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import ( TestModelBase, LightningTestModel, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 295ea3bdce036..94fb992c37a69 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,7 +14,7 @@ ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.debugging import MisconfigurationException from tests.base import ( TestModelBase, DictHparamsModel, From 77d905ccf2458bf346dd753895c188bac636f07a Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Fri, 27 Mar 2020 14:11:04 +0100 Subject: [PATCH 4/8] Revert "refactor and add types" This reverts commit b4c552fa --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loggers/comet.py | 2 +- pytorch_lightning/profiler/__init__.py | 13 ------------- .../{debugging => profiling}/__init__.py | 4 +--- .../{debugging => profiling}/profilers.py | 0 pytorch_lightning/trainer/__init__.py | 6 +++--- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/distrib_data_parallel.py | 2 +- pytorch_lightning/trainer/distrib_parts.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 2 +- .../{debugging => utilities}/exceptions.py | 0 tests/loggers/test_comet.py | 2 +- tests/models/test_amp.py | 2 +- tests/models/test_gpu.py | 2 +- tests/models/test_restore.py | 2 +- tests/test_deprecated.py | 2 -- tests/test_profiler.py | 4 ++-- tests/trainer/test_dataloaders.py | 2 +- tests/trainer/test_trainer.py | 2 +- 21 files changed, 21 insertions(+), 38 deletions(-) delete mode 100644 pytorch_lightning/profiler/__init__.py rename pytorch_lightning/{debugging => profiling}/__init__.py (95%) rename pytorch_lightning/{debugging => profiling}/profilers.py (100%) rename pytorch_lightning/{debugging => utilities}/exceptions.py (100%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f2c9258b2e9d1..8cec41da2b6fd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: import torch_xla.core.xla_model as xm diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index e35136b2f2508..ee9d65a73cbf1 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,7 +28,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException class CometLogger(LightningLoggerBase): diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py deleted file mode 100644 index 69ed003f068bf..0000000000000 --- a/pytorch_lightning/profiler/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -.. warning:: `profiler` package has been renamed to `debugging` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -import warnings - -warnings.warn("`profiler` package has been renamed to `debugging` since v0.7.2." 
- " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.debugging import ( # noqa: E402 - BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler -) diff --git a/pytorch_lightning/debugging/__init__.py b/pytorch_lightning/profiling/__init__.py similarity index 95% rename from pytorch_lightning/debugging/__init__.py rename to pytorch_lightning/profiling/__init__.py index 54ae3ed60bd9c..81629f289476c 100644 --- a/pytorch_lightning/debugging/__init__.py +++ b/pytorch_lightning/profiling/__init__.py @@ -113,11 +113,9 @@ def custom_processing_step(self, data): """ -from pytorch_lightning.debugging.exceptions import MisconfigurationException -from pytorch_lightning.debugging.profilers import BaseProfiler, Profiler, AdvancedProfiler, PassThroughProfiler +from pytorch_lightning.profiling.profilers import Profiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler __all__ = [ - 'MisconfigurationException', 'BaseProfiler', 'Profiler', 'AdvancedProfiler', diff --git a/pytorch_lightning/debugging/profilers.py b/pytorch_lightning/profiling/profilers.py similarity index 100% rename from pytorch_lightning/debugging/profilers.py rename to pytorch_lightning/profiling/profilers.py diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 5c75bfc414073..992e1a03bbe2c 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -611,15 +611,15 @@ def on_train_end(self): # default used by the Trainer trainer = Trainer(process_position=0) -profiler -^^^^^^^^ +profiling +^^^^^^^^^ To profile individual steps during training and assist in identifying bottlenecks. See the `profiler documentation `_. for more details. Example:: - from pytorch_lightning.profiler import Profiler, AdvancedProfiler + from pytorch_lightning.profiling import Profiler, AdvancedProfiler # default used by the Trainer trainer = Trainer(profiler=None) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 59614707719b0..83b59d21a7d24 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -6,7 +6,7 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 90ee5b80eb27a..95b9a6197454e 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -122,7 +122,7 @@ def train_fx(trial_hparams, cluster_manager, _): import torch from pytorch_lightning import _logger as log from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index befd838906a87..fc6007b75d7cd 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -344,7 +344,7 @@ LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.debugging import MisconfigurationException +from 
pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2b820203b5d0d..4fef97f3c8b89 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -135,7 +135,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: import torch_xla.distributed.parallel_loader as xla_pl diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 84a27c2bcf88b..a07d2c3a98f05 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -17,8 +17,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.debugging import BaseProfiler, Profiler, PassThroughProfiler from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.profiling import Profiler, PassThroughProfiler, BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -32,7 +32,7 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningMean try: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1742f9dc3cd87..02abae9b47ada 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningMean try: diff --git a/pytorch_lightning/debugging/exceptions.py b/pytorch_lightning/utilities/exceptions.py similarity index 100% rename from pytorch_lightning/debugging/exceptions.py rename to pytorch_lightning/utilities/exceptions.py diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index a28e75b067636..771ca3b6e78c5 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -8,7 +8,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import LightningTestModel diff --git a/tests/models/test_amp.py 
b/tests/models/test_amp.py index dd1449a4b9f38..a51ea938bd5bd 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -5,7 +5,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( LightningTestModel, ) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 7ac8941b54de8..5de29f647dc86 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -11,7 +11,7 @@ parse_gpu_ids, determine_root_gpu_device, ) -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import LightningTestModel PRETEND_N_OF_GPUS = 16 diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index a4a018f371253..d1f5e6ecac5d1 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -8,7 +8,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( LightningTestModel, LightningTestModelWithoutHyperparametersArg, diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 5002d23427d53..a3b087c718166 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -57,8 +57,6 @@ def test_tbd_remove_in_v0_9_0_module_imports(): from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 - from pytorch_lightning.profiler import Profiler, AdvancedProfiler # noqa: F402 - class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 7c7fb3d8f5b40..1835bf5491d06 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,10 +1,10 @@ +import tempfile import time from pathlib import Path import numpy as np import pytest - -from pytorch_lightning.debugging import AdvancedProfiler, Profiler +from pytorch_lightning.profiling import AdvancedProfiler, Profiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f3cd9af62846c..fd6f05cc92b6b 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -2,7 +2,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( TestModelBase, LightningTestModel, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 94fb992c37a69..307365223dc11 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,7 +14,7 @@ ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin -from pytorch_lightning.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( TestModelBase, DictHparamsModel, From 442246849bfb3f763a2a281fa975727c02dbba60 Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Fri, 27 Mar 2020 16:10:35 +0100 Subject: [PATCH 5/8] changelog --- CHANGELOG.md | 2 +- pytorch_lightning/profiler/__init__.py | 12 ++++++++++++ pytorch_lightning/profiling/__init__.py | 4 ++-- pytorch_lightning/profiling/profilers.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/test_deprecated.py | 2 ++ tests/test_profiler.py | 4 ++-- 7 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 pytorch_lightning/profiler/__init__.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 590095fa7f2cc..59c5a340dc6c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950)) - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) - Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941)) -- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) +- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) - Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041)) - Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019)) - Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), ) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py new file mode 100644 index 0000000000000..adcd74b1dd958 --- /dev/null +++ b/pytorch_lightning/profiler/__init__.py @@ -0,0 +1,12 @@ +""" +.. warning:: `profiler` package has been renamed to `profiling` since v0.7.2 and will be removed in v0.9.0 +""" + +import warnings + +warnings.warn("`profiler` package has been renamed to `profiling` since v0.7.2." 
+ " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) + +from pytorch_lightning.profiling.profilers import ( # noqa: F403 + SimpleProfiler, AdvancedProfiler, BaseProfiler, PassThroughProfiler +) diff --git a/pytorch_lightning/profiling/__init__.py b/pytorch_lightning/profiling/__init__.py index 81629f289476c..97e4306fba177 100644 --- a/pytorch_lightning/profiling/__init__.py +++ b/pytorch_lightning/profiling/__init__.py @@ -113,11 +113,11 @@ def custom_processing_step(self, data): """ -from pytorch_lightning.profiling.profilers import Profiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler +from pytorch_lightning.profiling.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler __all__ = [ 'BaseProfiler', - 'Profiler', + 'SimpleProfiler', 'AdvancedProfiler', 'PassThroughProfiler', ] diff --git a/pytorch_lightning/profiling/profilers.py b/pytorch_lightning/profiling/profilers.py index 2ad3e9eaa1347..b2d2e480cc661 100644 --- a/pytorch_lightning/profiling/profilers.py +++ b/pytorch_lightning/profiling/profilers.py @@ -97,7 +97,7 @@ def summary(self) -> str: return "" -class Profiler(BaseProfiler): +class SimpleProfiler(BaseProfiler): """ This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a07d2c3a98f05..9a67e1560be45 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,7 +18,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiling import Profiler, PassThroughProfiler, BaseProfiler +from pytorch_lightning.profiling import SimpleProfiler, PassThroughProfiler, BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -363,7 +363,7 @@ def __init__( # configure profiler if profiler is True: - profiler = Profiler() + profiler = SimpleProfiler() self.profiler = profiler or PassThroughProfiler() # configure early stop callback diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index a3b087c718166..378c0f915b78d 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports(): from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 + from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler # noqa: F402 + class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 1835bf5491d06..1e8e62cc5b6bf 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from pytorch_lightning.profiling import AdvancedProfiler, Profiler +from pytorch_lightning.profiling import AdvancedProfiler, SimpleProfiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 @@ -25,7 +25,7 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = Profiler() + profiler = SimpleProfiler() return profiler From 
8d8a70e378cc9c8f3a403fcc9b43d13dae1fad3e Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Sat, 28 Mar 2020 16:57:11 +0100 Subject: [PATCH 6/8] revert rename --- CHANGELOG.md | 1 + pytorch_lightning/profiler/__init__.py | 127 ++++++++++++++++-- .../{profiling => profiler}/profilers.py | 0 pytorch_lightning/profiling/__init__.py | 123 ----------------- pytorch_lightning/trainer/__init__.py | 6 +- pytorch_lightning/trainer/trainer.py | 2 +- tests/test_profiler.py | 2 +- 7 files changed, 125 insertions(+), 136 deletions(-) rename pytorch_lightning/{profiling => profiler}/profilers.py (100%) delete mode 100644 pytorch_lightning/profiling/__init__.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 59c5a340dc6c7..02110f6e68ecc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211)) - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) +- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259)) - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280)) ### Changed diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index adcd74b1dd958..683baccafa858 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -1,12 +1,123 @@ """ -.. warning:: `profiler` package has been renamed to `profiling` since v0.7.2 and will be removed in v0.9.0 -""" +Profiling your training run can help you understand if there are any bottlenecks in your code. + + +Built-in checks +--------------- + +PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: + +- on_epoch_start +- on_epoch_end +- on_batch_start +- tbptt_split_batch +- model_forward +- model_backward +- on_after_backward +- optimizer_step +- on_batch_end +- training_step_end +- on_training_end + +Enable simple profiling +----------------------- + +If you only wish to profile the standard actions, you can set `profiler=True` when constructing +your `Trainer` object. + +.. code-block:: python + + trainer = Trainer(..., profiler=True) + +The profiler's results will be printed at the completion of a training `fit()`. + +.. code-block:: python + + Profiler Report + + Action | Mean duration (s) | Total time (s) + ----------------------------------------------------------------- + on_epoch_start | 5.993e-06 | 5.993e-06 + get_train_batch | 0.0087412 | 16.398 + on_batch_start | 5.0865e-06 | 0.0095372 + model_forward | 0.0017818 | 3.3408 + model_backward | 0.0018283 | 3.4282 + on_after_backward | 4.2862e-06 | 0.0080366 + optimizer_step | 0.0011072 | 2.0759 + on_batch_end | 4.5202e-06 | 0.0084753 + on_epoch_end | 3.919e-06 | 3.919e-06 + on_train_end | 5.449e-06 | 5.449e-06 + + +Advanced Profiling +-------------------- -import warnings +If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. 
+This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. + +.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile + +.. code-block:: python + + profiler = AdvancedProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`get_train_batch`. + +.. code-block:: python + + Profiler Report + + Profile stats for: get_train_batch + 4869394 function calls (4863767 primitive calls) in 18.893 seconds + Ordered by: cumulative time + List reduced from 76 to 10 due to restriction <10> + ncalls tottime percall cumtime percall filename:lineno(function) + 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} + 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) + 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) + 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) + 1875 0.084 0.000 18.290 0.010 fetch.py:44() + 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) + 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) + 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) + 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) + 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + +You can also reference this profiler in your LightningModule to profile specific actions of interest. +If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` +which will allow you to skip profiling without having to make any code changes. Each profiler has a +method `profile()` which returns a context handler. Simply pass in the name of your action that you want +to track and the profiler will record performance for code executed within this context. + +.. code-block:: python + + from pytorch_lightning.profiler import Profiler, PassThroughProfiler + + class MyModel(LightningModule): + def __init__(self, hparams, profiler=None): + self.hparams = hparams + self.profiler = profiler or PassThroughProfiler() + + def custom_processing_step(self, data): + with profiler.profile('my_custom_action'): + # custom processing step + return data + + profiler = Profiler() + model = MyModel(hparams, profiler) + trainer = Trainer(profiler=profiler, max_epochs=1) + +""" -warnings.warn("`profiler` package has been renamed to `profiling` since v0.7.2." 
- " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) +from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler -from pytorch_lightning.profiling.profilers import ( # noqa: F403 - SimpleProfiler, AdvancedProfiler, BaseProfiler, PassThroughProfiler -) +__all__ = [ + 'BaseProfiler', + 'SimpleProfiler', + 'AdvancedProfiler', + 'PassThroughProfiler', +] diff --git a/pytorch_lightning/profiling/profilers.py b/pytorch_lightning/profiler/profilers.py similarity index 100% rename from pytorch_lightning/profiling/profilers.py rename to pytorch_lightning/profiler/profilers.py diff --git a/pytorch_lightning/profiling/__init__.py b/pytorch_lightning/profiling/__init__.py deleted file mode 100644 index 97e4306fba177..0000000000000 --- a/pytorch_lightning/profiling/__init__.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Profiling your training run can help you understand if there are any bottlenecks in your code. - - -Built-in checks ----------------- - -PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: - -- on_epoch_start -- on_epoch_end -- on_batch_start -- tbptt_split_batch -- model_forward -- model_backward -- on_after_backward -- optimizer_step -- on_batch_end -- training_step_end -- on_training_end - -Enable simple profiling -------------------------- - -If you only wish to profile the standard actions, you can set `profiler=True` when constructing -your `Trainer` object. - -.. code-block:: python - - trainer = Trainer(..., profiler=True) - -The profiler's results will be printed at the completion of a training `fit()`. - -.. code-block:: python - - Profiler Report - - Action | Mean duration (s) | Total time (s) - ----------------------------------------------------------------- - on_epoch_start | 5.993e-06 | 5.993e-06 - get_train_batch | 0.0087412 | 16.398 - on_batch_start | 5.0865e-06 | 0.0095372 - model_forward | 0.0017818 | 3.3408 - model_backward | 0.0018283 | 3.4282 - on_after_backward | 4.2862e-06 | 0.0080366 - optimizer_step | 0.0011072 | 2.0759 - on_batch_end | 4.5202e-06 | 0.0084753 - on_epoch_end | 3.919e-06 | 3.919e-06 - on_train_end | 5.449e-06 | 5.449e-06 - - -Advanced Profiling --------------------- - -If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. -This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. - -.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile - -.. code-block:: python - - profiler = AdvancedProfiler() - trainer = Trainer(..., profiler=profiler) - -The profiler's results will be printed at the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. The output below shows the profiling for the action -`get_train_batch`. - -.. 
code-block:: python - - Profiler Report - - Profile stats for: get_train_batch - 4869394 function calls (4863767 primitive calls) in 18.893 seconds - Ordered by: cumulative time - List reduced from 76 to 10 due to restriction <10> - ncalls tottime percall cumtime percall filename:lineno(function) - 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} - 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) - 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) - 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) - 1875 0.084 0.000 18.290 0.010 fetch.py:44() - 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) - 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) - 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) - 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) - 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) - -You can also reference this profiler in your LightningModule to profile specific actions of interest. -If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` -which will allow you to skip profiling without having to make any code changes. Each profiler has a -method `profile()` which returns a context handler. Simply pass in the name of your action that you want -to track and the profiler will record performance for code executed within this context. - -.. code-block:: python - - from pytorch_lightning.profiler import Profiler, PassThroughProfiler - - class MyModel(LightningModule): - def __init__(self, hparams, profiler=None): - self.hparams = hparams - self.profiler = profiler or PassThroughProfiler() - - def custom_processing_step(self, data): - with profiler.profile('my_custom_action'): - # custom processing step - return data - - profiler = Profiler() - model = MyModel(hparams, profiler) - trainer = Trainer(profiler=profiler, max_epochs=1) - -""" - -from pytorch_lightning.profiling.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler - -__all__ = [ - 'BaseProfiler', - 'SimpleProfiler', - 'AdvancedProfiler', - 'PassThroughProfiler', -] diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 992e1a03bbe2c..5c75bfc414073 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -611,15 +611,15 @@ def on_train_end(self): # default used by the Trainer trainer = Trainer(process_position=0) -profiling -^^^^^^^^^ +profiler +^^^^^^^^ To profile individual steps during training and assist in identifying bottlenecks. See the `profiler documentation `_. for more details. 
Example:: - from pytorch_lightning.profiling import Profiler, AdvancedProfiler + from pytorch_lightning.profiler import Profiler, AdvancedProfiler # default used by the Trainer trainer = Trainer(profiler=None) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9a67e1560be45..879f753f68ab1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,7 +18,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiling import SimpleProfiler, PassThroughProfiler, BaseProfiler +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 1e8e62cc5b6bf..6c0fc103acc24 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from pytorch_lightning.profiling import AdvancedProfiler, SimpleProfiler +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 From 987df24f6b383050b332b31141e8bf4d0bdf9be7 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Sun, 29 Mar 2020 22:59:29 +0200 Subject: [PATCH 7/8] fix test --- pytorch_lightning/profiler/profilers.py | 14 ++++++++++++-- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/test_profiler.py | 13 +++++-------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index b2d2e480cc661..330b88e7444a4 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -152,10 +152,15 @@ def log_row(action, mean, total): output_string += os.linesep return output_string + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + def __del__(self): """Close profiler's stream.""" if self.output_file: - self.output_file.flush() self.output_file.close() @@ -214,8 +219,13 @@ def summary(self) -> str: return output_string + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + def __del__(self): """Close profiler's stream.""" if self.output_file: - self.output_file.flush() self.output_file.close() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 879f753f68ab1..33de9075ca562 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -489,10 +489,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: ('print_nan_grads', (,), False), ('process_position', (,), 0), ('profiler', - (, + (, ), None), - ... + ... 
""" trainer_default_params = inspect.signature(cls).parameters name_type_default = [] diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 6c0fc103acc24..ae5dc3eb36dee 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,4 +1,4 @@ -import tempfile +import os import time from pathlib import Path @@ -30,8 +30,8 @@ def simple_profiler(): @pytest.fixture -def advanced_profiler(): - profiler = AdvancedProfiler() +def advanced_profiler(tmpdir): + profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) return profiler @@ -168,12 +168,9 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): # record at least one event with advanced_profiler.profile("test"): pass - # log to stdout + # log to stdout and print to file advanced_profiler.describe() - # print to file - advanced_profiler.output_filename = Path(tmpdir, "profiler.txt") - advanced_profiler.describe() - data = Path(advanced_profiler.output_filename).read_text() + data = Path(advanced_profiler.output_fname).read_text() assert len(data) > 0 From a17f217bd538081c81adf3aa29d5bd1dd12e978d Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Tue, 31 Mar 2020 13:39:10 +0200 Subject: [PATCH 8/8] mute verbose --- pytorch_lightning/profiler/profilers.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 330b88e7444a4..6f6aa959ac451 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -115,9 +115,7 @@ def __init__(self, output_filename: str = None): self.output_fname = output_filename self.output_file = open(self.output_fname, 'w') if self.output_fname else None - streaming_out = [log.info] - if self.output_file: - streaming_out.append(self.output_file.write) + streaming_out = [self.output_file.write] if self.output_file else [log.info] super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: @@ -186,9 +184,7 @@ def __init__(self, output_filename: str = None, line_count_restriction: float = self.output_fname = output_filename self.output_file = open(self.output_fname, 'w') if self.output_fname else None - streaming_out = [log.info] - if self.output_file: - streaming_out.append(self.output_file.write) + streaming_out = [self.output_file.write] if self.output_file else [log.info] super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: