[feat] Add PyTorch Profiler. #5560

Merged: 43 commits, Jan 26, 2021
Changes from 1 commit
Commits (43):
ad00b97 add profiler (tchaton, Jan 18, 2021)
cfae67b add profiler (tchaton, Jan 18, 2021)
2f1020c Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 18, 2021)
5931c18 update (tchaton, Jan 19, 2021)
11d8c61 Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 19, 2021)
c85661a resolve flake8 (tchaton, Jan 19, 2021)
9a62eb8 update doc (tchaton, Jan 19, 2021)
6f54b69 update changelog (tchaton, Jan 19, 2021)
1bbe314 clean doc (tchaton, Jan 19, 2021)
bd035da delete prof file (tchaton, Jan 19, 2021)
b0cfe7a Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 19, 2021)
e689cda merge pr codebase (tchaton, Jan 21, 2021)
803aaa2 update (tchaton, Jan 21, 2021)
2e91d9e Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 21, 2021)
698b43a update doc (tchaton, Jan 21, 2021)
991958f Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 21, 2021)
da9a56d update doc (tchaton, Jan 21, 2021)
3b119fd update doc (tchaton, Jan 21, 2021)
75c966f update on comments (tchaton, Jan 22, 2021)
c10ab8c Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 22, 2021)
f6ae283 update docstring (tchaton, Jan 22, 2021)
29b9a58 Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 22, 2021)
f0aed96 update docstring (tchaton, Jan 22, 2021)
5dd2b4d try (Borda, Jan 22, 2021)
03b3ea5 update test (tchaton, Jan 22, 2021)
5663b5f Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 22, 2021)
1e6a953 Update pytorch_lightning/profiler/__init__.py (tchaton, Jan 22, 2021)
21ae2da Update pytorch_lightning/profiler/__init__.py (tchaton, Jan 22, 2021)
f6f0d89 update on comments (tchaton, Jan 22, 2021)
7ca9b7c Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 22, 2021)
2ea05de remove old code (tchaton, Jan 22, 2021)
c8d24b8 Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 24, 2021)
c397603 add support for ddp (tchaton, Jan 25, 2021)
1db6e67 resolve flake8 (tchaton, Jan 25, 2021)
4e9a86c Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 25, 2021)
e9866bb Update pytorch_lightning/profiler/__init__.py (tchaton, Jan 25, 2021)
d65beee resolve tests (tchaton, Jan 25, 2021)
bb642ba Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight… (tchaton, Jan 25, 2021)
8338c5e resolve flake8 (tchaton, Jan 25, 2021)
85f9aa2 Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 25, 2021)
e6263e6 Merge branch 'release/1.2-dev' into feat/torch_profiler (tchaton, Jan 26, 2021)
9ae56cc resolve flake8 (tchaton, Jan 26, 2021)
8d62f41 Merge branch 'release/1.2-dev' into feat/torch_profiler (mergify[bot], Jan 26, 2021)
Commit ad00b975781e71570e94ef47347db875951e90d6: add profiler (tchaton, Jan 18, 2021)
9 changes: 8 additions & 1 deletion pytorch_lightning/profiler/__init__.py
@@ -116,11 +116,18 @@ def custom_processing_step(self, data):

"""

from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.profilers import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PytorchProfiler,
SimpleProfiler,
)

__all__ = [
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
"PytorchProfiler",
]
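
Once exported here, the profiler can also be imported and handed to the Trainer as a configured instance; a minimal sketch, assuming the constructor arguments introduced later in this diff:

# Minimal sketch, assuming the PytorchProfiler API added in this PR.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PytorchProfiler

# Write the per-operator report to a file instead of stdout.
profiler = PytorchProfiler(output_filename="profile.txt", sort_by_key="cpu_time_total")
trainer = Trainer(profiler=profiler)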
104 changes: 104 additions & 0 deletions pytorch_lightning/profiler/profilers.py
@@ -25,9 +25,11 @@
from typing import Optional, Union

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class BaseProfiler(ABC):
@@ -282,3 +284,105 @@ def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class PytorchProfiler(BaseProfiler):
"""
This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
different operators inside your model, both on CPU and GPU.
"""

PROFILED_FUNCTIONS = ["training_step", "validation_step", "test_step"]

def __init__(self,
output_filename: Optional[str] = None,
enabled: bool = True,
use_cuda: bool = False,
record_shapes: bool = True,
profile_memory: bool = True,
with_stack: bool = True,
sort_by_key: str = "self_cuda_memory_usage"):
"""
Args:
output_filename: optionally save profile results to file instead of printing
to std out when training is finished.
line_count_restriction: this can be used to limit the number of functions
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
"""
self.profiled_actions = {}
self.enabled = enabled
self.use_cuda = use_cuda
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.sort_by_key = sort_by_key
if self.sort_by_key not in self.available_sort_by_keys:
raise MisconfigurationException(
f"Found sort_by_key: {sort_by_key}. Should be within {self.available_sort_by_keys}. ")

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions and action_name in self.PROFILED_FUNCTIONS:
self.profiled_actions[action_name] = torch.autograd.profiler.profile(
enabled=self.enabled,
use_cuda=self.use_cuda,
record_shapes=self.record_shapes,
profile_memory=self.profile_memory,
with_stack=self.with_stack).__enter__()

def stop(self, action_name: str) -> None:
if action_name in self.PROFILED_FUNCTIONS:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
# todo: Find a better solution
try:
pr.__exit__(None, None, None)
except RuntimeError as e:
# only swallow this specific known error; re-raise everything else
if "Expected debug info of type 2" not in str(e):
raise

def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
table = pr.key_averages().table(sort_by=self.sort_by_key)
recorded_stats[action_name] = table

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += (
f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
)

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

@property
def available_sort_by_keys(self):
return [
"cpu_time", "cuda_time", "cpu_time_total",
"cuda_time_total", "cpu_memory_usage", "cuda_memory_usage",
"self_cpu_memory_usage", "self_cuda_memory_usage", "count"
]
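
Taken together, the class can also be exercised standalone; a minimal sketch of the start/stop/summary lifecycle, assuming the API exactly as shown above (model and batch are illustrative):

# Illustrative sketch of driving PytorchProfiler by hand.
import torch

from pytorch_lightning.profiler import PytorchProfiler

profiler = PytorchProfiler(use_cuda=False, sort_by_key="cpu_time_total")
model = torch.nn.Linear(32, 2)
batch = torch.randn(4, 32)

# "training_step" is listed in PROFILED_FUNCTIONS, so start()/stop() record it.
profiler.start("training_step")
model(batch).sum().backward()
profiler.stop("training_step")

print(profiler.summary())  # one per-operator table per profiled action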
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -14,13 +14,20 @@

from typing import Union

from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PytorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PytorchProfiler,
}


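With the registry above, a plain string now selects the new profiler; a minimal usage sketch:

# The "pytorch" key above maps the string argument to PytorchProfiler.
from pytorch_lightning import Trainer

trainer = Trainer(profiler="pytorch")
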
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -171,10 +171,12 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
# run actual test step
if self.testing:
model_ref._current_fx_name = "test_step"
output = self.trainer.accelerator_backend.test_step(args)
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator_backend.test_step(args)
else:
model_ref._current_fx_name = "validation_step"
output = self.trainer.accelerator_backend.validation_step(args)
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator_backend.validation_step(args)

# capture any logged information
self.trainer.logger_connector.cache_logged_metrics()
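
Both loops rely on BaseProfiler.profile acting as a context manager that brackets the wrapped call with start/stop. A simplified sketch of that pattern (illustrative; the real method lives on the existing BaseProfiler):

from contextlib import contextmanager

@contextmanager
def profile(self, action_name: str):
    """Sketch: bracket the wrapped block with start()/stop()."""
    try:
        self.start(action_name)
        yield action_name
    finally:
        self.stop(action_name)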
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing, DeviceType
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -339,7 +339,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
# manually capture logged metrics
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
training_step_output = self.trainer.accelerator_backend.training_step(args)
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.logger_connector.cache_logged_metrics()

self._check_training_step_output(training_step_output)
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -34,6 +34,7 @@
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PYTORCH_GREATER_EQUAL_1_7_0,
_RPC_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_XLA_AVAILABLE,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -53,3 +53,4 @@ def _module_available(module_path: str) -> bool:
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_PYTORCH_GREATER_EQUAL_1_7_0 = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
25 changes: 24 additions & 1 deletion tests/trainer/test_trainer.py
@@ -33,7 +33,7 @@
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, PytorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _PYTORCH_GREATER_EQUAL_1_7_0
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate
@@ -1421,6 +1421,7 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
('simple', SimpleProfiler),
('Simple', SimpleProfiler),
('advanced', AdvancedProfiler),
('pytorch', PytorchProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {'profiler': profiler} if profiler is not None else {}
@@ -1441,3 +1442,25 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)


@pytest.mark.skipif(not _PYTORCH_GREATER_EQUAL_1_7_0, reason='test needs PyTorch 1.7+')
def test_pytorch_profiler(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

model = TestModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
profiler='pytorch'
)

trainer.fit(model)
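
A hypothetical follow-up assertion (not part of this diff) could verify that the profiled hooks actually show up in the report:

# Hypothetical extension: summary() embeds one table per profiled action.
report = trainer.profiler.summary()
assert "training_step" in report
assert "validation_step" in report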