From 43b20b46a632cab938fb4c9171f1d950a3f87342 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 09:48:00 -0500 Subject: [PATCH 01/23] initial implementation --- pytorch_lightning/utilities/profiler.py | 162 ++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 pytorch_lightning/utilities/profiler.py diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py new file mode 100644 index 0000000000000..32c545d1111e2 --- /dev/null +++ b/pytorch_lightning/utilities/profiler.py @@ -0,0 +1,162 @@ +""" +Profiling your training run can help you understand if there are any bottlenecks in your code. + +PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: +- enumerate # TODO +- exaples # TODO +- here # TODO + +""" + + +from contextlib import contextmanager +from collections import defaultdict +import time +import numpy as np +import cProfile +import pstats +import io +from abc import ABC, abstractmethod + + +class BaseProfiler(ABC): + @abstractmethod + def start(self, action_name): + """ + defines how to start recording an action + """ + pass + + @abstractmethod + def stop(self, action_name): + """ + defines how to record the duration once an action is complete + """ + pass + + @contextmanager + def profile(self, action_name): + """ + yields a context manager to encapsulate the scope of a profiled action + + with self.profile('load training data'): + # load training data code + + the profiler will start once you've entered the context and automatically stop + once you exit the code block + """ + self.start(action_name) + yield action_name + self.stop(action_name) + + def describe(self): + """ + prints a report after the conclusion of the profiled training run + """ + pass + + +class Profiler(BaseProfiler): + """ + this profiler simply records the duration of actions (in seconds) and reports + the mean and standard deviation of each action duration over the entire training run + """ + + def __init__(self): + self.current_actions = {} + self.recorded_durations = defaultdict(list) + + def start(self, action_name): + if action_name in self.current_actions: + raise ValueError( + f"Attempted to start {action_name} which has already started." + ) + self.current_actions[action_name] = time.monotonic() + + def stop(self, action_name): + end_time = time.monotonic() + if action_name not in self.current_actions: + raise ValueError( + f"Attempting to stop recording an action ({action_name}) which was never started." 
+ ) + start_time = self.current_actions.pop(action_name) + duration = end_time - start_time + self.recorded_durations[action_name].append(duration) + + def describe(self): + def print_row(action, mean, std_dev): + print(f"{action}\t|\t{mean:.4}\t|\t{std_dev:.4}") + + print_row("Action", "Mean duration", "Std deviation") + print("-" * 40) + for action, durations in self.recorded_durations.items(): + print_row(action, np.mean(durations), np.std(durations)) + + +class AdvancedProfiler(BaseProfiler): + """ + this profiler uses Python's cProfiler to record more detailed information about + time spent in each function call recorded during a given action + """ + def __init__(self): + self.profiled_actions = {} + + def start(self, action_name): + if action_name not in self.profiled_actions: + self.profiled_actions[action_name] = cProfile.Profile() + self.profiled_actions[action_name].enable() + + def stop(self, action_name): + pr = self.profiled_actions.get(action_name) + if pr is None: + raise ValueError( + f"Attempting to stop recording an action ({action_name}) which was never started." + ) + pr.disable() + + def describe(self, line_count_restriction=1.0): + self.recorded_stats = {} + for action_name, pr in self.profiled_actions.items(): + s = io.StringIO() + sortby = pstats.SortKey.CUMULATIVE + ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby) + ps.print_stats(line_count_restriction) + self.recorded_stats[action_name] = s.getvalue() + for action, stats in self.recorded_stats.items(): + print(f"Profile stats for: {action}") + print(stats) + + +if __name__ == '__main__.py': + + p = Profiler() + + with p.profile("test"): + time.sleep(5) + + with p.profile("test"): + time.sleep(2) + + with p.profile("test"): + time.sleep(4) + + with p.profile("ok"): + time.sleep(1) + + p.describe() + + ap = AdvancedProfiler() + + with ap.profile("test"): + time.sleep(5) + + with ap.profile("test"): + time.sleep(2) + + with ap.profile("test"): + time.sleep(4) + + with ap.profile("ok"): + time.sleep(1) + + ap.describe() From 903af112f12fef3c15884e8d1b202ec672e82f7b Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 11:54:01 -0500 Subject: [PATCH 02/23] formatting, pass through profiler, docstring --- pytorch_lightning/utilities/profiler.py | 99 ++++++++++++++++++------- 1 file changed, 74 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 32c545d1111e2..5eaa1be3dc73b 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -2,10 +2,29 @@ Profiling your training run can help you understand if there are any bottlenecks in your code. PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: -- enumerate # TODO -- exaples # TODO -- here # TODO - +- on_epoch_start +- on_epoch_end +- on_batch_start +- tbptt_split_batch +- model_forward +- model_backward +- on_after_backward +- optimizer_step +- on_batch_end +- training_end + +If you only wish to profile the standard actions, you can construct a Profiler object and simply +pass it into the Trainer. + +.. code-block:: python + profiler = Profiler() + trainer = Trainer(..., profiler=profiler) + +You can also reference this profiler to profiler any arbitrary code. + +.. 
code-block:: python + with profiler.profile('my_custom_action'): + my_custom_action() """ @@ -56,6 +75,21 @@ def describe(self): pass +class PassThroughProfiler(BaseProfiler): + """ + this can be used when you don't want to profile your runs + """ + + def __init__(self): + pass + + def start(self): + pass + + def stop(self): + pass + + class Profiler(BaseProfiler): """ this profiler simply records the duration of actions (in seconds) and reports @@ -85,12 +119,12 @@ def stop(self, action_name): def describe(self): def print_row(action, mean, std_dev): - print(f"{action}\t|\t{mean:.4}\t|\t{std_dev:.4}") + print(f"{action:<20s}\t| {mean:<15}\t| {std_dev:<15}") - print_row("Action", "Mean duration", "Std deviation") - print("-" * 40) + print_row("Action", "Mean duration (s)", "Std dev.") + print("-" * 60) for action, durations in self.recorded_durations.items(): - print_row(action, np.mean(durations), np.std(durations)) + print_row(action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}") class AdvancedProfiler(BaseProfiler): @@ -98,6 +132,7 @@ class AdvancedProfiler(BaseProfiler): this profiler uses Python's cProfiler to record more detailed information about time spent in each function call recorded during a given action """ + def __init__(self): self.profiled_actions = {} @@ -127,36 +162,50 @@ def describe(self, line_count_restriction=1.0): print(stats) -if __name__ == '__main__.py': +if __name__ == "__main__.py": p = Profiler() - with p.profile("test"): + with p.profile("context handler"): time.sleep(5) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 - with p.profile("test"): - time.sleep(2) - - with p.profile("test"): - time.sleep(4) - - with p.profile("ok"): + with p.profile("context handler"): time.sleep(1) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 + + p.start("manual") + time.sleep(5) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 + p.stop("manual") p.describe() ap = AdvancedProfiler() - with ap.profile("test"): + with ap.profile("context handler"): time.sleep(5) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 - with ap.profile("test"): - time.sleep(2) - - with ap.profile("test"): - time.sleep(4) - - with ap.profile("ok"): + with ap.profile("context handler"): time.sleep(1) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 + + ap.start("manual") + time.sleep(5) + a = np.random.randn(3000, 2) + b = a + 2 + c = b / 3 + ap.stop("manual") ap.describe() From e90c8697d6c82e9e2d18e80fd543478df1c66e88 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 11:54:12 -0500 Subject: [PATCH 03/23] call profiler during training --- pytorch_lightning/trainer/trainer.py | 9 +++++- pytorch_lightning/trainer/training_loop.py | 35 ++++++++++++++-------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0b15100606f6c..b34354fc7438e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.profiler import PassThroughProfiler + try: from apex import amp @@ -87,6 +89,7 @@ def __init__( num_sanity_val_steps=5, truncated_bptt_steps=None, resume_from_checkpoint=None, + profiler=None ): r""" @@ -460,7 +463,8 @@ def __init__( # resume from a specific 
checkpoint trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') - + profiler (utilities.BaseProfiler): # TODO document + .. warning:: Following arguments become deprecated and they will be removed in v0.8.0: - `nb_sanity_val_steps` @@ -564,6 +568,9 @@ def __init__( # configure logger self.configure_logger(logger) + # configure profiler + self.profiler = profiler or PassThroughProfiler() + # configure early stop callback # creates a default one if none passed in self.configure_early_stopping(early_stop_callback) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e1f90308a76e5..601ec61a836ec 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -211,6 +211,7 @@ def __init__(self): self.training_tqdm_dict = None self.get_train_dataloader = None self.reduce_lr_on_plateau_scheduler = None + self.profiler = None @property def max_nb_epochs(self): @@ -369,7 +370,8 @@ def run_training_epoch(self): # before epoch hook if self.is_function_implemented('on_epoch_start'): model = self.get_model() - model.on_epoch_start() + with self.profiler.profile('on_epoch_start'): + model.on_epoch_start() # run epoch for batch_idx, batch in enumerate(self.get_train_dataloader()): @@ -430,7 +432,8 @@ def run_training_epoch(self): # epoch end hook if self.is_function_implemented('on_epoch_end'): model = self.get_model() - model.on_epoch_end() + with self.profiler.profile('on_epoch_end'): + model.on_epoch_end() def run_training_batch(self, batch, batch_idx): # track grad norms @@ -448,7 +451,8 @@ def run_training_batch(self, batch, batch_idx): # hook if self.is_function_implemented('on_batch_start'): model_ref = self.get_model() - response = model_ref.on_batch_start(batch) + with self.profiler.profile('on_batch_start'): + response = model_ref.on_batch_start(batch) if response == -1: return -1, grad_norm_dic, {} @@ -456,7 +460,8 @@ def run_training_batch(self, batch, batch_idx): splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() - splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) + with self.profiler.profile('tbptt_split_batch'): + splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) self.hiddens = None for split_idx, split_batch in enumerate(splits): @@ -476,8 +481,9 @@ def run_training_batch(self, batch, batch_idx): # wrap the forward step in a closure so second order methods work def optimizer_closure(): # forward pass - output = self.training_forward( - split_batch, batch_idx, opt_idx, self.hiddens) + with self.profiler.profile('model_forward'): + output = self.training_forward( + split_batch, batch_idx, opt_idx, self.hiddens) closure_loss = output[0] progress_bar_metrics = output[1] @@ -491,7 +497,8 @@ def optimizer_closure(): # backward pass model_ref = self.get_model() - model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx) + with self.profiler.profile('model_backward'): + model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx) # track metrics for callbacks all_callback_metrics.append(callback_metrics) @@ -503,7 +510,8 @@ def optimizer_closure(): # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() - model_ref.on_after_backward() + with self.profiler.profile('on_after_backward'): + model_ref.on_after_backward() return closure_loss @@ -533,8 +541,9 @@ def optimizer_closure(): # calls .step(), .zero_grad() # override function to 
modify this behavior model = self.get_model() - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, optimizer_closure) + with self.profiler.profile('optimizer_step'): + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, optimizer_closure) # calculate running loss for display self.running_loss.append(self.batch_loss_value) @@ -544,7 +553,8 @@ def optimizer_closure(): # activate batch end hook if self.is_function_implemented('on_batch_end'): model = self.get_model() - model.on_batch_end() + with self.profiler.profile('on_batch_end'): + model.on_batch_end() # update progress bar self.main_progress_bar.update(1) @@ -604,7 +614,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens): # allow any mode to define training_end if self.is_overriden('training_end'): model_ref = self.get_model() - output = model_ref.training_end(output) + with self.profiler.profile('training_end'): + output = model_ref.training_end(output) # format and reduce outputs accordingly output = self.process_output(output, train=True) From 51337320f42b7071f60469408ea79f8c301c8b55 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 16:44:46 -0500 Subject: [PATCH 04/23] add initial tests --- pytorch_lightning/utilities/profiler.py | 50 +------------------------ tests/test_profiler.py | 23 ++++++++++++ 2 files changed, 24 insertions(+), 49 deletions(-) create mode 100644 tests/test_profiler.py diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 5eaa1be3dc73b..4c118ed114da2 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -12,6 +12,7 @@ - optimizer_step - on_batch_end - training_end +- on_training_end If you only wish to profile the standard actions, you can construct a Profiler object and simply pass it into the Trainer. 
@@ -160,52 +161,3 @@ def describe(self, line_count_restriction=1.0): for action, stats in self.recorded_stats.items(): print(f"Profile stats for: {action}") print(stats) - - -if __name__ == "__main__.py": - - p = Profiler() - - with p.profile("context handler"): - time.sleep(5) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - - with p.profile("context handler"): - time.sleep(1) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - - p.start("manual") - time.sleep(5) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - p.stop("manual") - - p.describe() - - ap = AdvancedProfiler() - - with ap.profile("context handler"): - time.sleep(5) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - - with ap.profile("context handler"): - time.sleep(1) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - - ap.start("manual") - time.sleep(5) - a = np.random.randn(3000, 2) - b = a + 2 - c = b / 3 - ap.stop("manual") - - ap.describe() diff --git a/tests/test_profiler.py b/tests/test_profiler.py new file mode 100644 index 0000000000000..b8d5f1fbc9b73 --- /dev/null +++ b/tests/test_profiler.py @@ -0,0 +1,23 @@ +from pytorch_lightning.utilities.profiler import Profiler, AdvancedProfiler +import time +import numpy as np + + +def test_simple_profiler(): + p = Profiler() + + with p.profile("a"): + time.sleep(3) + + with p.profile("a"): + time.sleep(1) + + with p.profile("b"): + time.sleep(2) + + with p.profile("c"): + time.sleep(1) + + np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=2) + np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=2) + np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=2) From e466e01e90281ae2cc1b6e4fd5737462839f2ca9 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 16:45:35 -0500 Subject: [PATCH 05/23] report stats when training is done --- pytorch_lightning/trainer/trainer.py | 3 +++ pytorch_lightning/trainer/training_loop.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b34354fc7438e..b054506b0548e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -890,6 +890,9 @@ def run_pretrain_routine(self, model): # CORE TRAINING LOOP self.train() + + # summarize profile results + self.profiler.describe() def test(self, model=None): r""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 601ec61a836ec..dc79048a86bae 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -356,12 +356,14 @@ def train(self): stop = should_stop and met_min_epochs if stop: self.main_progress_bar.close() - model.on_train_end() + with self.profiler.profile('on_train_end'): + model.on_train_end() return self.main_progress_bar.close() - model.on_train_end() + with self.profiler.profile('on_train_end'): + model.on_train_end() if self.logger is not None: self.logger.finalize("success") From f229d189da74a5d712f94141c3721ade594633fc Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 17:58:56 -0500 Subject: [PATCH 06/23] fix formatting --- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b054506b0548e..789432e3b9f61 100644 --- 
a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -890,7 +890,7 @@ def run_pretrain_routine(self, model): # CORE TRAINING LOOP self.train() - + # summarize profile results self.profiler.describe() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index dc79048a86bae..b598ec89a1934 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -545,7 +545,7 @@ def optimizer_closure(): model = self.get_model() with self.profiler.profile('optimizer_step'): model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, optimizer_closure) + optimizer, opt_idx, optimizer_closure) # calculate running loss for display self.running_loss.append(self.batch_loss_value) From d57275c49fab9053aadcc2219ad1c16f561cff9a Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 17:59:43 -0500 Subject: [PATCH 07/23] error handling, bugfix in passthroughprofiler --- pytorch_lightning/utilities/profiler.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 4c118ed114da2..5b7ebf8c62706 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -65,9 +65,11 @@ def profile(self, action_name): the profiler will start once you've entered the context and automatically stop once you exit the code block """ - self.start(action_name) - yield action_name - self.stop(action_name) + try: + self.start(action_name) + yield action_name + finally: + self.stop(action_name) def describe(self): """ @@ -84,10 +86,10 @@ class PassThroughProfiler(BaseProfiler): def __init__(self): pass - def start(self): + def start(self, action_name): pass - def stop(self): + def stop(self, action_name): pass From 9ec8618ff07a1f3251af9d88cbd1d2c379d4ddfd Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 18:22:04 -0500 Subject: [PATCH 08/23] finish documenting profiler arg in Trainer --- pytorch_lightning/trainer/trainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 789432e3b9f61..7026dd2aeefcf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -463,8 +463,17 @@ def __init__( # resume from a specific checkpoint trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') - profiler (utilities.BaseProfiler): # TODO document - + profiler (utilities.BaseProfiler): To profile individual steps during training and assist in + identifying bottlenecks. + Example:: + + # default used by the Trainer + trainer = Trainer(profiler=None) + + # profile a training run and get a report on completion of the training job + profiler = utilities.Profiler() + trainer = Trainer(profiler=profiler) + .. 
warning:: Following arguments become deprecated and they will be removed in v0.8.0: - `nb_sanity_val_steps` From a2bedb001b169609309889928ac4d5b5d3ca4924 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 18:56:29 -0500 Subject: [PATCH 09/23] relax required precision for profiling tests --- tests/test_profiler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index b8d5f1fbc9b73..76f843d781085 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -18,6 +18,7 @@ def test_simple_profiler(): with p.profile("c"): time.sleep(1) - np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=2) - np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=2) - np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=2) + # different environments have different precision when it comes to time.sleep() + np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1) + np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1) + np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1) From 343ab83fe93c8ef62ec823e92ab34695924a9847 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Sun, 2 Feb 2020 19:05:03 -0500 Subject: [PATCH 10/23] option to dump cProfiler results to text file --- pytorch_lightning/utilities/profiler.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 5b7ebf8c62706..418b1dc8f7ec2 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -136,8 +136,13 @@ class AdvancedProfiler(BaseProfiler): time spent in each function call recorded during a given action """ - def __init__(self): + def __init__(self, output_filename=None): + ''' + :param output_filename (str): optionally save profile results to file instead of printing + to std out when training is finished. 
+ ''' self.profiled_actions = {} + self.output_filename = output_filename def start(self, action_name): if action_name not in self.profiled_actions: @@ -160,6 +165,14 @@ def describe(self, line_count_restriction=1.0): ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby) ps.print_stats(line_count_restriction) self.recorded_stats[action_name] = s.getvalue() - for action, stats in self.recorded_stats.items(): - print(f"Profile stats for: {action}") - print(stats) + if self.output_filename is not None: + # save to file + with open(self.output_filename, 'w') as f: + for action, stats in self.recorded_stats.items(): + f.write(f"Profile stats for: {action}") + f.write(stats) + else: + # print to standard out + for action, stats in self.recorded_stats.items(): + print(f"Profile stats for: {action}") + print(stats) From 4e1129d9b80430168f8ef79906bdadde9f0beb0a Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 3 Feb 2020 18:23:23 -0500 Subject: [PATCH 11/23] use logging, format with black --- pytorch_lightning/utilities/profiler.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 418b1dc8f7ec2..5c1dd3c6d9a48 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -37,6 +37,9 @@ import pstats import io from abc import ABC, abstractmethod +import logging + +logger = logging.getLogger(__name__) class BaseProfiler(ABC): @@ -121,13 +124,13 @@ def stop(self, action_name): self.recorded_durations[action_name].append(duration) def describe(self): - def print_row(action, mean, std_dev): - print(f"{action:<20s}\t| {mean:<15}\t| {std_dev:<15}") + def log_row(action, mean, std_dev): + logger.info(f"{action:<20s}\t| {mean:<15}\t| {std_dev:<15}") - print_row("Action", "Mean duration (s)", "Std dev.") - print("-" * 60) + log_row("Action", "Mean duration (s)", "Std dev.") + logger.info("-" * 60) for action, durations in self.recorded_durations.items(): - print_row(action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}") + log_row(action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}") class AdvancedProfiler(BaseProfiler): @@ -137,10 +140,10 @@ class AdvancedProfiler(BaseProfiler): """ def __init__(self, output_filename=None): - ''' + """ :param output_filename (str): optionally save profile results to file instead of printing to std out when training is finished. 
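A minimal sketch of how the new `output_filename` argument is meant to be used; the filename and the `model` object below are illustrative rather than taken from the patch::

    from pytorch_lightning import Trainer
    # module location at this point in the series; a later patch moves it to pytorch_lightning.profiler
    from pytorch_lightning.utilities.profiler import AdvancedProfiler

    profiler = AdvancedProfiler(output_filename="profile_report.txt")
    trainer = Trainer(profiler=profiler)
    trainer.fit(model)  # `model` is assumed to be an existing LightningModule

    # once fit() completes, describe() writes the per-action cProfile stats
    # to profile_report.txt instead of printing them to standard out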
- ''' + """ self.profiled_actions = {} self.output_filename = output_filename @@ -167,12 +170,11 @@ def describe(self, line_count_restriction=1.0): self.recorded_stats[action_name] = s.getvalue() if self.output_filename is not None: # save to file - with open(self.output_filename, 'w') as f: + with open(self.output_filename, "w") as f: for action, stats in self.recorded_stats.items(): f.write(f"Profile stats for: {action}") f.write(stats) else: # print to standard out for action, stats in self.recorded_stats.items(): - print(f"Profile stats for: {action}") - print(stats) + logger.info(f"Profile stats for: {action}\n{stats}") From 30fc655f371b6a0647db2c013f3f7745b5db16f0 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 3 Feb 2020 21:39:13 -0500 Subject: [PATCH 12/23] include profiler in docs --- docs/source/common-cases.rst | 12 +++++++++--- docs/source/profiler.rst | 9 +++++++++ pytorch_lightning/utilities/profiler.py | 4 ++++ 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 docs/source/profiler.rst diff --git a/docs/source/common-cases.rst b/docs/source/common-cases.rst index 7b96a93d84660..cc4ca362fc551 100644 --- a/docs/source/common-cases.rst +++ b/docs/source/common-cases.rst @@ -13,9 +13,15 @@ gradient clipping modifying training via hooks ============================= - - .. toctree:: :maxdepth: 3 - pl_examples \ No newline at end of file + pl_examples + + +profiling a training run +======================== +.. toctree:: + :maxdepth: 1 + + profiler \ No newline at end of file diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst new file mode 100644 index 0000000000000..7654e562ca34e --- /dev/null +++ b/docs/source/profiler.rst @@ -0,0 +1,9 @@ +.. role:: hidden + :class: hidden-section + +Profiler +=========== +.. automodule:: pytorch_lightning.utilities.profiler + :exclude-members: + _abc_impl, + summarize, diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 5c1dd3c6d9a48..19461e90ffdc0 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -2,6 +2,7 @@ Profiling your training run can help you understand if there are any bottlenecks in your code. PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: + - on_epoch_start - on_epoch_end - on_batch_start @@ -18,14 +19,17 @@ pass it into the Trainer. .. code-block:: python + profiler = Profiler() trainer = Trainer(..., profiler=profiler) You can also reference this profiler to profiler any arbitrary code. .. code-block:: python + with profiler.profile('my_custom_action'): my_custom_action() + """ From bd0f7c4fac4a6d1aa17493a9e86b0ac9d32f6771 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 3 Feb 2020 22:59:03 -0500 Subject: [PATCH 13/23] improved logging and better docs --- pytorch_lightning/utilities/profiler.py | 75 +++++++++++++++++-------- 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 19461e90ffdc0..5678f50b7ff51 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -23,12 +23,27 @@ profiler = Profiler() trainer = Trainer(..., profiler=profiler) -You can also reference this profiler to profiler any arbitrary code. +The profiler's results will be printed at the completion of a training `fit()`. + +You can also reference this profiler in your LightningModule to profile specific actions of interest. 
.. code-block:: python - with profiler.profile('my_custom_action'): - my_custom_action() + from pytorch_lightning.utilities.profiler import Profiler, PassThroughProfiler + + class MyModel(LightningModule): + def __init__(self, hparams, profiler=None): + self.hparams = hparams + self.profiler = profiler or PassThroughProfiler() + + def custom_processing_step(self, data): + with profiler.profile('my_custom_action'): + # custom processing step + return data + + profiler = Profiler() + model = MyModel(hparams, profiler) + trainer = Trainer(profiler=profiler, max_epochs=1) """ @@ -47,30 +62,36 @@ class BaseProfiler(ABC): + """ + If you wish to write a custom profiler, you should inhereit from this class. + """ + @abstractmethod def start(self, action_name): """ - defines how to start recording an action + Defines how to start recording an action. """ pass @abstractmethod def stop(self, action_name): """ - defines how to record the duration once an action is complete + Defines how to record the duration once an action is complete. """ pass @contextmanager def profile(self, action_name): """ - yields a context manager to encapsulate the scope of a profiled action + Yields a context manager to encapsulate the scope of a profiled action. + + Example:: - with self.profile('load training data'): - # load training data code + with self.profile('load training data'): + # load training data code - the profiler will start once you've entered the context and automatically stop - once you exit the code block + The profiler will start once you've entered the context and will automatically + stop once you exit the code block. """ try: self.start(action_name) @@ -80,14 +101,15 @@ def profile(self, action_name): def describe(self): """ - prints a report after the conclusion of the profiled training run + Logs a profile report after the conclusion of the training run. """ pass class PassThroughProfiler(BaseProfiler): """ - this can be used when you don't want to profile your runs + This class should be used when you don't want the (small) overhead of profiling. + The Trainer uses this class by default. """ def __init__(self): @@ -102,8 +124,8 @@ def stop(self, action_name): class Profiler(BaseProfiler): """ - this profiler simply records the duration of actions (in seconds) and reports - the mean and standard deviation of each action duration over the entire training run + This profiler simply records the duration of actions (in seconds) and reports + the mean and standard deviation of each action duration over the entire training run. 
""" def __init__(self): @@ -128,19 +150,26 @@ def stop(self, action_name): self.recorded_durations[action_name].append(duration) def describe(self): + output_string = "\nProfiler Report\n" + def log_row(action, mean, std_dev): - logger.info(f"{action:<20s}\t| {mean:<15}\t| {std_dev:<15}") + return f"\n{action:<20s}\t| {mean:<15}\t| {std_dev:<15}" - log_row("Action", "Mean duration (s)", "Std dev.") - logger.info("-" * 60) + output_string += log_row("Action", "Mean duration (s)", "Std dev.") + output_string += f"\n{'-' * 60}" for action, durations in self.recorded_durations.items(): - log_row(action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}") + output_string += log_row( + action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}" + ) + + logger.info(output_string) class AdvancedProfiler(BaseProfiler): """ - this profiler uses Python's cProfiler to record more detailed information about - time spent in each function call recorded during a given action + This profiler uses Python's cProfiler to record more detailed information about + time spent in each function call recorded during a given action. The output is quite + verbose and you should only use this if you want very detailed reports. """ def __init__(self, output_filename=None): @@ -179,6 +208,8 @@ def describe(self, line_count_restriction=1.0): f.write(f"Profile stats for: {action}") f.write(stats) else: - # print to standard out + # log to standard out + output_string = "\nProfiler Report\n" for action, stats in self.recorded_stats.items(): - logger.info(f"Profile stats for: {action}\n{stats}") + output_string += f"\nProfile stats for: {action}\n{stats}" + logger.info(output_string) From 266bd3bf394f6815b113ac872109df43dea0d284 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Mon, 3 Feb 2020 23:20:05 -0500 Subject: [PATCH 14/23] appease the linter --- pytorch_lightning/utilities/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 5678f50b7ff51..8d21b2d4b5691 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -90,7 +90,7 @@ def profile(self, action_name): with self.profile('load training data'): # load training data code - The profiler will start once you've entered the context and will automatically + The profiler will start once you've entered the context and will automatically stop once you exit the code block. 
""" try: From da6950c52a7774e61087c0c7b3d1ef6be2366534 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Tue, 4 Feb 2020 23:03:17 -0500 Subject: [PATCH 15/23] better summaries, wrapper for iterables --- pytorch_lightning/trainer/training_loop.py | 4 ++- pytorch_lightning/utilities/profiler.py | 32 ++++++++++++++++------ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b598ec89a1934..b95eff21b1b0d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -376,7 +376,9 @@ def run_training_epoch(self): model.on_epoch_start() # run epoch - for batch_idx, batch in enumerate(self.get_train_dataloader()): + for batch_idx, batch in self.profiler.profile_iterable( + enumerate(self.get_train_dataloader()), "get_train_batch" + ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 8d21b2d4b5691..724e49bc44808 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -99,6 +99,18 @@ def profile(self, action_name): finally: self.stop(action_name) + def profile_iterable(self, iterable, action_name): + iterator = iter(iterable) + while True: + try: + self.start(action_name) + value = next(iterator) + self.stop(action_name) + yield value + except StopIteration: + self.stop(action_name) + break + def describe(self): """ Logs a profile report after the conclusion of the training run. @@ -125,7 +137,7 @@ def stop(self, action_name): class Profiler(BaseProfiler): """ This profiler simply records the duration of actions (in seconds) and reports - the mean and standard deviation of each action duration over the entire training run. + the mean duration of each action and the total time spent over the entire training run. 
""" def __init__(self): @@ -150,18 +162,22 @@ def stop(self, action_name): self.recorded_durations[action_name].append(duration) def describe(self): - output_string = "\nProfiler Report\n" + output_string = "\n\nProfiler Report\n" - def log_row(action, mean, std_dev): - return f"\n{action:<20s}\t| {mean:<15}\t| {std_dev:<15}" + def log_row(action, mean, total): + return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}" - output_string += log_row("Action", "Mean duration (s)", "Std dev.") - output_string += f"\n{'-' * 60}" + output_string += log_row( + "Action", "Mean duration (s)", "Total time (s)" + ) + output_string += f"\n{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row( - action, f"{np.mean(durations):.5}", f"{np.std(durations):.5}" + action, + f"{np.mean(durations):.5}", + f"{np.sum(durations):.5}", ) - + output_string += "\n" logger.info(output_string) From 45c30c816b11e5866432438141190112275bdadd Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Wed, 5 Feb 2020 20:39:01 -0500 Subject: [PATCH 16/23] fix typo --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index b5e2fe9554b73..911ae78ad29b2 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -212,7 +212,7 @@ def evaluate(self, model, dataloaders, max_batches, test=False): # bookkeeping outputs = [] - # run training + # run validation for dataloader_idx, dataloader in enumerate(dataloaders): dl_outputs = [] for batch_idx, batch in enumerate(dataloader): From 474785f6e488c9ca40e841c18553357413ea9bda Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Wed, 5 Feb 2020 22:08:07 -0500 Subject: [PATCH 17/23] allow profiler=True creation --- pytorch_lightning/trainer/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7026dd2aeefcf..d1b3be01f87a9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -26,7 +26,7 @@ from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.debugging import MisconfigurationException -from pytorch_lightning.utilities.profiler import PassThroughProfiler +from pytorch_lightning.utilities.profiler import Profiler, PassThroughProfiler try: @@ -578,6 +578,8 @@ def __init__( self.configure_logger(logger) # configure profiler + if profiler is True: + profiler = Profiler() self.profiler = profiler or PassThroughProfiler() # configure early stop callback From 1c23cefe7be729b428667bb3db50aeb94b29bac7 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Wed, 5 Feb 2020 22:09:36 -0500 Subject: [PATCH 18/23] more documentation --- docs/source/profiler.rst | 3 +- pytorch_lightning/utilities/profiler.py | 82 +++++++++++++++++++++---- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst index 7654e562ca34e..09c6633771167 100644 --- a/docs/source/profiler.rst +++ b/docs/source/profiler.rst @@ -1,7 +1,8 @@ .. role:: hidden :class: hidden-section -Profiler + +Profiling performance during training =========== .. 
automodule:: pytorch_lightning.utilities.profiler :exclude-members: diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/utilities/profiler.py index 724e49bc44808..9c3611eb47f3b 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/utilities/profiler.py @@ -15,17 +15,73 @@ - training_end - on_training_end -If you only wish to profile the standard actions, you can construct a Profiler object and simply -pass it into the Trainer. +If you only wish to profile the standard actions, you can set `profiler=True` when constructing +your `Trainer` object. .. code-block:: python - profiler = Profiler() - trainer = Trainer(..., profiler=profiler) + trainer = Trainer(..., profiler=True) The profiler's results will be printed at the completion of a training `fit()`. +.. code-block:: python + + Profiler Report + + Action | Mean duration (s) | Total time (s) + ----------------------------------------------------------------- + on_epoch_start | 5.993e-06 | 5.993e-06 + get_train_batch | 0.0087412 | 16.398 + on_batch_start | 5.0865e-06 | 0.0095372 + model_forward | 0.0017818 | 3.3408 + model_backward | 0.0018283 | 3.4282 + on_after_backward | 4.2862e-06 | 0.0080366 + optimizer_step | 0.0011072 | 2.0759 + on_batch_end | 4.5202e-06 | 0.0084753 + on_epoch_end | 3.919e-06 | 3.919e-06 + on_train_end | 5.449e-06 | 5.449e-06 + + +If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. +This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. + +.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile + +.. code-block:: python + + profiler = AdvancedProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`get_train_batch`. + +.. code-block:: python + + Profiler Report + + Profile stats for: get_train_batch + 4869394 function calls (4863767 primitive calls) in 18.893 seconds + Ordered by: cumulative time + List reduced from 76 to 10 due to restriction <10> + ncalls tottime percall cumtime percall filename:lineno(function) + 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} + 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) + 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) + 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) + 1875 0.084 0.000 18.290 0.010 fetch.py:44() + 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) + 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) + 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) + 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) + 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + You can also reference this profiler in your LightningModule to profile specific actions of interest. +If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` +which will allow you to skip profiling without having to make any code changes. Each profiler has a +method `profile()` which returns a context handler. Simply pass in the name of your action that you want +to track and the profiler will record performance for code executed within this context. .. 
code-block:: python @@ -167,15 +223,11 @@ def describe(self): def log_row(action, mean, total): return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}" - output_string += log_row( - "Action", "Mean duration (s)", "Total time (s)" - ) + output_string += log_row("Action", "Mean duration (s)", "Total time (s)") output_string += f"\n{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row( - action, - f"{np.mean(durations):.5}", - f"{np.sum(durations):.5}", + action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}", ) output_string += "\n" logger.info(output_string) @@ -188,13 +240,17 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename=None): + def __init__(self, output_filename=None, line_count_restriction=1.0): """ :param output_filename (str): optionally save profile results to file instead of printing to std out when training is finished. + :param line_count_restriction (int|float): this can be used to limit the number of functions + reported for each action. either an integer (to select a count of lines), + or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) """ self.profiled_actions = {} self.output_filename = output_filename + self.line_count_restriction = line_count_restriction def start(self, action_name): if action_name not in self.profiled_actions: @@ -209,13 +265,13 @@ def stop(self, action_name): ) pr.disable() - def describe(self, line_count_restriction=1.0): + def describe(self): self.recorded_stats = {} for action_name, pr in self.profiled_actions.items(): s = io.StringIO() sortby = pstats.SortKey.CUMULATIVE ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby) - ps.print_stats(line_count_restriction) + ps.print_stats(self.line_count_restriction) self.recorded_stats[action_name] = s.getvalue() if self.output_filename is not None: # save to file From 93b917dc10ac8eb871d0fa793640d4a1b40d72e7 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Wed, 5 Feb 2020 22:34:20 -0500 Subject: [PATCH 19/23] add tests for advanced profiler --- tests/test_profiler.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 76f843d781085..aec2920a7501f 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -22,3 +22,29 @@ def test_simple_profiler(): np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1) np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1) np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1) + + +def test_advanced_profiler(): + def get_duration(profile): + return sum([x.totaltime for x in profile.getstats()]) + + p = AdvancedProfiler() + + with p.profile("a"): + time.sleep(3) + + with p.profile("a"): + time.sleep(1) + + with p.profile("b"): + time.sleep(2) + + with p.profile("c"): + time.sleep(1) + + a_duration = get_duration(p.profiled_actions["a"]) + np.testing.assert_almost_equal(a_duration, [4], decimal=1) + b_duration = get_duration(p.profiled_actions["b"]) + np.testing.assert_almost_equal(b_duration, [2], decimal=1) + c_duration = get_duration(p.profiled_actions["c"]) + np.testing.assert_almost_equal(c_duration, [1], decimal=1) From c117ae17c743faa95986fd1e70f13a5a61b1cf9f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 6 Feb 2020 08:26:35 -0500 Subject: [PATCH 20/23] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 9 
++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d1b3be01f87a9..154bf726b5d28 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -469,11 +469,18 @@ def __init__( # default used by the Trainer trainer = Trainer(profiler=None) + + # to profile using the defaults + trainer = Trainer(profiler=True) - # profile a training run and get a report on completion of the training job + # profile main parts of a training run profiler = utilities.Profiler() trainer = Trainer(profiler=profiler) + # advanced profiler for function-level stats + profiler = utilities.AdvancedProfiler() + trainer = Trainer(profiler=profiler) + .. warning:: Following arguments become deprecated and they will be removed in v0.8.0: - `nb_sanity_val_steps` From b9e842f3f997e925f18ce184d1cbd220d47641b5 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Thu, 6 Feb 2020 08:35:28 -0500 Subject: [PATCH 21/23] make profilers accessible in pl.utilities --- pytorch_lightning/utilities/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index e69de29bb2d1d..cb920cc5f3040 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -0,0 +1 @@ +from .profiler import Profiler, AdvancedProfiler From f118466a808e6cd187b8bc8a85ed59f12449af84 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Thu, 6 Feb 2020 18:47:50 -0500 Subject: [PATCH 22/23] reorg profiler files --- docs/source/profiler.rst | 2 +- pytorch_lightning/profiler/__init__.py | 112 ++++++++++++++++++ .../{utilities => profiler}/profiler.py | 106 ----------------- pytorch_lightning/trainer/trainer.py | 16 +-- pytorch_lightning/utilities/__init__.py | 1 - 5 files changed, 122 insertions(+), 115 deletions(-) create mode 100644 pytorch_lightning/profiler/__init__.py rename pytorch_lightning/{utilities => profiler}/profiler.py (57%) diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst index 09c6633771167..6443e7ddbc62f 100644 --- a/docs/source/profiler.rst +++ b/docs/source/profiler.rst @@ -4,7 +4,7 @@ Profiling performance during training =========== -.. automodule:: pytorch_lightning.utilities.profiler +.. automodule:: pytorch_lightning.profiler :exclude-members: _abc_impl, summarize, diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py new file mode 100644 index 0000000000000..a69e3ccf9c0d4 --- /dev/null +++ b/pytorch_lightning/profiler/__init__.py @@ -0,0 +1,112 @@ +""" +Profiling your training run can help you understand if there are any bottlenecks in your code. + +PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: + +- on_epoch_start +- on_epoch_end +- on_batch_start +- tbptt_split_batch +- model_forward +- model_backward +- on_after_backward +- optimizer_step +- on_batch_end +- training_end +- on_training_end + +If you only wish to profile the standard actions, you can set `profiler=True` when constructing +your `Trainer` object. + +.. code-block:: python + + trainer = Trainer(..., profiler=True) + +The profiler's results will be printed at the completion of a training `fit()`. + +.. 
code-block:: python + + Profiler Report + + Action | Mean duration (s) | Total time (s) + ----------------------------------------------------------------- + on_epoch_start | 5.993e-06 | 5.993e-06 + get_train_batch | 0.0087412 | 16.398 + on_batch_start | 5.0865e-06 | 0.0095372 + model_forward | 0.0017818 | 3.3408 + model_backward | 0.0018283 | 3.4282 + on_after_backward | 4.2862e-06 | 0.0080366 + optimizer_step | 0.0011072 | 2.0759 + on_batch_end | 4.5202e-06 | 0.0084753 + on_epoch_end | 3.919e-06 | 3.919e-06 + on_train_end | 5.449e-06 | 5.449e-06 + + +If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. +This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. + +.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile + +.. code-block:: python + + profiler = AdvancedProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`get_train_batch`. + +.. code-block:: python + + Profiler Report + + Profile stats for: get_train_batch + 4869394 function calls (4863767 primitive calls) in 18.893 seconds + Ordered by: cumulative time + List reduced from 76 to 10 due to restriction <10> + ncalls tottime percall cumtime percall filename:lineno(function) + 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} + 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) + 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) + 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) + 1875 0.084 0.000 18.290 0.010 fetch.py:44() + 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) + 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) + 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) + 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) + 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + +You can also reference this profiler in your LightningModule to profile specific actions of interest. +If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` +which will allow you to skip profiling without having to make any code changes. Each profiler has a +method `profile()` which returns a context handler. Simply pass in the name of your action that you want +to track and the profiler will record performance for code executed within this context. + +.. 
code-block:: python + + from pytorch_lightning.profiler import Profiler, PassThroughProfiler + + class MyModel(LightningModule): + def __init__(self, hparams, profiler=None): + self.hparams = hparams + self.profiler = profiler or PassThroughProfiler() + + def custom_processing_step(self, data): + with profiler.profile('my_custom_action'): + # custom processing step + return data + + profiler = Profiler() + model = MyModel(hparams, profiler) + trainer = Trainer(profiler=profiler, max_epochs=1) + +""" + +from .profiler import Profiler, AdvancedProfiler, PassThroughProfiler + +__all__ = [ + 'Profiler', + 'AdvancedProfiler', + 'PassThroughProfiler', +] diff --git a/pytorch_lightning/utilities/profiler.py b/pytorch_lightning/profiler/profiler.py similarity index 57% rename from pytorch_lightning/utilities/profiler.py rename to pytorch_lightning/profiler/profiler.py index 9c3611eb47f3b..32f220897a9dc 100644 --- a/pytorch_lightning/utilities/profiler.py +++ b/pytorch_lightning/profiler/profiler.py @@ -1,109 +1,3 @@ -""" -Profiling your training run can help you understand if there are any bottlenecks in your code. - -PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: - -- on_epoch_start -- on_epoch_end -- on_batch_start -- tbptt_split_batch -- model_forward -- model_backward -- on_after_backward -- optimizer_step -- on_batch_end -- training_end -- on_training_end - -If you only wish to profile the standard actions, you can set `profiler=True` when constructing -your `Trainer` object. - -.. code-block:: python - - trainer = Trainer(..., profiler=True) - -The profiler's results will be printed at the completion of a training `fit()`. - -.. code-block:: python - - Profiler Report - - Action | Mean duration (s) | Total time (s) - ----------------------------------------------------------------- - on_epoch_start | 5.993e-06 | 5.993e-06 - get_train_batch | 0.0087412 | 16.398 - on_batch_start | 5.0865e-06 | 0.0095372 - model_forward | 0.0017818 | 3.3408 - model_backward | 0.0018283 | 3.4282 - on_after_backward | 4.2862e-06 | 0.0080366 - optimizer_step | 0.0011072 | 2.0759 - on_batch_end | 4.5202e-06 | 0.0084753 - on_epoch_end | 3.919e-06 | 3.919e-06 - on_train_end | 5.449e-06 | 5.449e-06 - - -If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. -This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. - -.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile - -.. code-block:: python - - profiler = AdvancedProfiler() - trainer = Trainer(..., profiler=profiler) - -The profiler's results will be printed at the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. The output below shows the profiling for the action -`get_train_batch`. - -.. 
code-block:: python - - Profiler Report - - Profile stats for: get_train_batch - 4869394 function calls (4863767 primitive calls) in 18.893 seconds - Ordered by: cumulative time - List reduced from 76 to 10 due to restriction <10> - ncalls tottime percall cumtime percall filename:lineno(function) - 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} - 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) - 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) - 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) - 1875 0.084 0.000 18.290 0.010 fetch.py:44() - 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) - 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) - 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) - 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) - 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) - -You can also reference this profiler in your LightningModule to profile specific actions of interest. -If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` -which will allow you to skip profiling without having to make any code changes. Each profiler has a -method `profile()` which returns a context handler. Simply pass in the name of your action that you want -to track and the profiler will record performance for code executed within this context. - -.. code-block:: python - - from pytorch_lightning.utilities.profiler import Profiler, PassThroughProfiler - - class MyModel(LightningModule): - def __init__(self, hparams, profiler=None): - self.hparams = hparams - self.profiler = profiler or PassThroughProfiler() - - def custom_processing_step(self, data): - with profiler.profile('my_custom_action'): - # custom processing step - return data - - profiler = Profiler() - model = MyModel(hparams, profiler) - trainer = Trainer(profiler=profiler, max_epochs=1) - -""" - - from contextlib import contextmanager from collections import defaultdict import time diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 154bf726b5d28..73c7e3fc1aef5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -26,7 +26,7 @@ from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.debugging import MisconfigurationException -from pytorch_lightning.utilities.profiler import Profiler, PassThroughProfiler +from pytorch_lightning.profiler import Profiler, PassThroughProfiler try: @@ -463,22 +463,24 @@ def __init__( # resume from a specific checkpoint trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') - profiler (utilities.BaseProfiler): To profile individual steps during training and assist in + profiler (BaseProfiler): To profile individual steps during training and assist in identifying bottlenecks. 
Example:: + from pytorch_lightning.profiler import Profiler, AdvancedProfiler + # default used by the Trainer trainer = Trainer(profiler=None) - - # to profile using the defaults + + # to profile standard training events trainer = Trainer(profiler=True) - # profile main parts of a training run - profiler = utilities.Profiler() + # equivalent to profiler=True + profiler = Profiler() trainer = Trainer(profiler=profiler) # advanced profiler for function-level stats - profiler = utilities.AdvancedProfiler() + profiler = AdvancedProfiler() trainer = Trainer(profiler=profiler) .. warning:: Following arguments become deprecated and they will be removed in v0.8.0: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index cb920cc5f3040..e69de29bb2d1d 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -1 +0,0 @@ -from .profiler import Profiler, AdvancedProfiler From 5ad693f20a256f03d664c891ceae7c69da70bbca Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Thu, 6 Feb 2020 19:09:30 -0500 Subject: [PATCH 23/23] change import for profiler tests --- tests/test_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index aec2920a7501f..d6e085a55e85f 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,4 +1,4 @@ -from pytorch_lightning.utilities.profiler import Profiler, AdvancedProfiler +from pytorch_lightning.profiler import Profiler, AdvancedProfiler import time import numpy as np
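For completeness, a short usage sketch of the feature as it stands after the final patch; `LitModel` is a placeholder for any existing `LightningModule` and is not defined in this series::

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import Profiler, AdvancedProfiler

    model = LitModel()  # hypothetical LightningModule

    # profile the standard training-loop actions with the defaults
    trainer = Trainer(profiler=True, max_epochs=1)
    trainer.fit(model)  # the summary table is logged when fit() completes

    # equivalent, but keeps a handle that can also profile custom actions
    profiler = Profiler()
    trainer = Trainer(profiler=profiler, max_epochs=1)

    # function-level cProfile stats, limited to 10 lines per action and written to a file
    adv_profiler = AdvancedProfiler(output_filename="fit_profile.txt",
                                    line_count_restriction=10)
    trainer = Trainer(profiler=adv_profiler, max_epochs=1)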