[bugfix] Resolve PyTorch Profiling for Manual Optimization (#9316)
Co-authored-by: Carlos Mocholí <[email protected]>
2 people authored and awaelchli committed Sep 7, 2021
1 parent 185c4fd commit 130bc06
Showing 6 changed files with 117 additions and 64 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))


- Fixed PyTorch Profiler not enabled for manual optimization ([#9316](https://github.com/PyTorchLightning/pytorch-lightning/pull/9316))


- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))

## [1.4.5] - 2021-08-31
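As context for the changelog entry above, here is a minimal sketch of the setup this fix targets: a manually optimized LightningModule profiled through `Trainer(profiler=...)`. The class name, layer sizes, and batch counts below are illustrative and not taken from the commit.

```python
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.profiler import PyTorchProfiler


class ManuallyOptimizedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # manual optimization: the code path the profiler previously did not record
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()
        return loss

    def train_dataloader(self):
        # 10 random samples -> 5 batches, enough for the profiler's default schedule
        return DataLoader(torch.randn(10, 32), batch_size=2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, limit_train_batches=5, profiler=PyTorchProfiler(filename="profile"))
    trainer.fit(ManuallyOptimizedModel())
```

Per the changelog entry, profiling such a model previously did not record its training step events; with this fix they are captured just as in automatic optimization.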
pl_examples/basic_examples/profiler_example.py (24 additions, 5 deletions)
@@ -31,6 +31,7 @@

from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.profiler.pytorch import PyTorchProfiler
from pytorch_lightning.utilities.cli import LightningCLI

DEFAULT_CMD_LINE = (
@@ -43,18 +44,34 @@


class ModelToProfile(LightningModule):
def __init__(self, name: str = "resnet50"):
def __init__(self, name: str = "resnet18", automatic_optimization: bool = True):
super().__init__()
self.model = getattr(models, name)(pretrained=True)
self.criterion = torch.nn.CrossEntropyLoss()
self.automatic_optimization = automatic_optimization
self.training_step = (
self.automatic_optimization_training_step
if automatic_optimization
else self.manual_optimization_training_step
)

def training_step(self, batch, batch_idx):
def automatic_optimization_training_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
self.log("train_loss", loss)
return loss

def manual_optimization_training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
inputs, labels = batch
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
self.log("train_loss", loss)
self.manual_backward(loss)
opt.step()

def validation_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self.model(inputs)
@@ -77,18 +94,20 @@ def train_dataloader(self, *args, **kwargs):
trainset = torchvision.datasets.CIFAR10(
root=_DATASETS_PATH, train=True, download=True, transform=self.transform
)
return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)
return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

def val_dataloader(self, *args, **kwargs):
valset = torchvision.datasets.CIFAR10(root=_DATASETS_PATH, train=False, download=True, transform=self.transform)
return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0)
return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=True, num_workers=0)


def cli_main():
if len(sys.argv) == 1:
sys.argv += DEFAULT_CMD_LINE

LightningCLI(ModelToProfile, CIFAR10DataModule)
LightningCLI(
ModelToProfile, CIFAR10DataModule, save_config_overwrite=True, trainer_defaults={"profiler": PyTorchProfiler()}
)


if __name__ == "__main__":
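For completeness, a hedged sketch (not part of the commit) of exercising the example's new manual-optimization path without going through the CLI. It assumes the repository is on the Python path so that `ModelToProfile` and `CIFAR10DataModule` above are importable, and that downloading CIFAR10 locally is acceptable.

```python
from pl_examples.basic_examples.profiler_example import CIFAR10DataModule, ModelToProfile
from pytorch_lightning import Trainer
from pytorch_lightning.profiler.pytorch import PyTorchProfiler

model = ModelToProfile(automatic_optimization=False)
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=5,  # the default profiler schedule needs at least 5 training steps
    limit_val_batches=2,
    profiler=PyTorchProfiler(dirpath=".", filename="profile"),
)
trainer.fit(model, datamodule=CIFAR10DataModule())
```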
pytorch_lightning/profiler/pytorch.py (41 additions, 28 deletions)
@@ -15,7 +15,7 @@
import inspect
import logging
import os
from functools import partial
from functools import lru_cache, partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union

@@ -24,9 +24,10 @@
from torch.autograd.profiler import record_function

from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from torch.autograd.profiler import EventList
@@ -38,6 +39,7 @@
from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler

log = logging.getLogger(__name__)
warning_cache = WarningCache()

_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]

@@ -116,6 +118,7 @@ def pre_step(self, current_action: str) -> None:
self._current_action = current_action

def reset(self):
# handle `fast_dev_run` correctly; the PyTorch Profiler will fail otherwise.
self._num_optimizer_step_and_closure = 0
self._num_validation_step = 0
self._num_test_step = 0
@@ -128,9 +131,15 @@ def reset(self):
self._current_action: Optional[str] = None
self._start_action_name: Optional[str] = None

@property
def is_training(self) -> bool:
return self._current_action is not None and (
self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
)

@property
def num_step(self) -> int:
if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
if self.is_training:
return self._num_optimizer_step_and_closure
if self._current_action == "validation_step":
return self._num_validation_step
@@ -141,7 +150,7 @@ def num_step(self) -> int:
return 0

def _step(self) -> None:
if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
if self.is_training:
self._num_optimizer_step_and_closure += 1
elif self._current_action == "validation_step":
if self._start_action_name == "on_fit_start":
@@ -156,7 +165,7 @@

@property
def has_finished(self) -> bool:
if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
if self.is_training:
return self._optimizer_step_and_closure_reached_end
if self._current_action == "validation_step":
return self._validation_step_reached_end
@@ -172,9 +181,9 @@ def __call__(self, num_step: int) -> "ProfilerAction":
return ProfilerAction.NONE

self._step()
action = self._schedule(self.num_step)
action = self._schedule(max(self.num_step, 0))
if action == ProfilerAction.RECORD_AND_SAVE:
if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
if self.is_training:
self._optimizer_step_and_closure_reached_end = True
elif self._current_action == "validation_step":
self._validation_step_reached_end = True
@@ -196,7 +205,7 @@ class PyTorchProfiler(BaseProfiler):
"predict_step",
}
RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
AVAILABLE_SORT_KEYS = {
"cpu_time",
@@ -320,6 +329,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
raise MisconfigurationException(
f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
)
self._default_schedule()
schedule = schedule if has_schedule else self._default_schedule()
self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
self._profiler_kwargs["schedule"] = self._schedule
@@ -331,28 +341,13 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
self._profiler_kwargs["with_stack"] = with_stack

def __deprecation_check(
self, profiled_functions: Optional[List[str]], record_functions: Optional[Set[str]]
) -> Set[str]:
if record_functions is None:
record_functions = set()

if profiled_functions is not None:
rank_zero_deprecation(
"`PyTorchProfiler.profiled_functions` has been renamed to"
" `record_functions` in v1.3 and will be removed in v1.5"
)
if not record_functions:
record_functions |= set(profiled_functions)
else:
raise MisconfigurationException(
"You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
" Please use only the later."
)

return record_functions
def _should_override_schedule(self) -> bool:
return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and (
self._schedule is not None and self._schedule._schedule == self._default_schedule()
)

@staticmethod
@lru_cache(1)
def _default_schedule() -> Optional[callable]:
if _KINETO_AVAILABLE:
# Those schedule defaults allow the profiling overhead to be negligible over training time.
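A hedged aside on the new `lru_cache(1)` decorator: caching appears to be what makes the `self._schedule._schedule == self._default_schedule()` check in `_should_override_schedule` work, because every call then returns the very same callable and the comparison reduces to object identity. A tiny standalone illustration of that property (plain Python, not the profiler code):

```python
from functools import lru_cache


@lru_cache(1)
def default_schedule():
    # stands in for torch.profiler.schedule(wait=1, warmup=1, active=3)
    return object()


# without the cache, each call would build a fresh object and `==` would be False
assert default_schedule() is default_schedule()
```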
@@ -393,11 +388,18 @@ def start(self, action_name: str) -> None:
if self._register is not None:
self._register.__enter__()

if self._lightning_module is not None:
# when the model is used in automatic optimization,
# we use `optimizer_step_and_closure` to step the model.
if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS:
self.STEP_FUNCTIONS.remove("training_step")

if (
self.profiler is not None
and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
and action_name not in self._recording_map
):

recording = record_function(action_name)
recording.__enter__()
self._recording_map[action_name] = recording
@@ -413,6 +415,17 @@ def stop(self, action_name: str) -> None:
if self.profiler is not None and (
action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
):

# the default schedule requires a minimum of 5 steps to work properly: `wait=1, warmup=1, active=3`.
# otherwise, profiling can crash with a segmentation fault.
if self._should_override_schedule():
warning_cache.warn(
"The PyTorch Profiler default schedule will be overridden as there is not enough "
"steps to properly record traces."
)
self._schedule = None
self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

if self._schedule is not None:
self._schedule.pre_step(action_name)

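To make the `wait=1, warmup=1, active=3` comment in `stop()` concrete, here is a hedged sketch (assuming a PyTorch build where `torch.profiler` with kineto is available) of how such a schedule maps step counts to profiler actions. With these arguments the first `RECORD_AND_SAVE` should only fire on the fifth step, which is why runs shorter than 5 batches trigger the override above.

```python
import torch

# the same arguments the comment above cites for the default schedule
schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)

for step in range(5):
    print(step, schedule(step))
# expected pattern (roughly): NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE,
# i.e. a trace is only saved once five steps have been taken
```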
tests/helpers/__init__.py (6 additions, 1 deletion)
@@ -1,2 +1,7 @@
from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset # noqa: F401
from tests.helpers.boring_model import ( # noqa: F401
BoringDataModule,
BoringModel,
ManualOptimBoringModel,
RandomDataset,
)
from tests.helpers.datasets import TrialMNIST # noqa: F401
tests/helpers/boring_model.py (15 additions, 0 deletions)
@@ -184,3 +184,18 @@ def test_dataloader(self):

def predict_dataloader(self):
return DataLoader(self.random_predict)


class ManualOptimBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()
output = self(batch)
loss = self.loss(batch, output)
opt.zero_grad()
self.manual_backward(loss)
opt.step()
return loss
tests/profiler/test_profiler.py (28 additions, 30 deletions)
@@ -29,7 +29,7 @@
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers import BoringModel, ManualOptimBoringModel
from tests.helpers.runif import RunIf

PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
@@ -309,50 +309,48 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
assert any(f"{local_rank}-validation_step" in f for f in files)


def test_pytorch_profiler_trainer_test(tmpdir):
@pytest.mark.parametrize("fast_dev_run", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("boring_model_cls", [ManualOptimBoringModel, BoringModel])
def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
"""Ensure that the profiler can be given to the trainer and test step are properly recorded."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
trainer.test(model)
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile")
model = boring_model_cls()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
trainer.fit(model)

assert sum(e.name == "test_step" for e in pytorch_profiler.function_events)
assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)

path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")

if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"test-{pytorch_profiler.filename}" in f for f in files)
path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")


def test_pytorch_profiler_trainer_predict(tmpdir):
"""Ensure that the profiler can be given to the trainer and predict function are properly recorded."""
@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
@pytest.mark.parametrize("boring_model_cls", [BoringModel, ManualOptimBoringModel])
def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
"""Ensure that the profiler can be given to the trainer and test step are properly recorded."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
model = BoringModel()
model = boring_model_cls()
model.predict_dataloader = model.train_dataloader
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_predict_batches=2, profiler=pytorch_profiler)
trainer.predict(model)

assert sum(e.name == "predict_step" for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")


def test_pytorch_profiler_trainer_validate(tmpdir):
"""Ensure that the profiler can be given to the trainer and validate function are properly recorded."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=2, profiler=pytorch_profiler)
trainer.validate(model)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
getattr(trainer, fn)(model)

assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)

path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")

if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")


def test_pytorch_profiler_nested(tmpdir):
"""Ensure that the profiler handles nested context"""
@@ -467,7 +465,7 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None:

profiler = cls(dirpath=tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1, profiler=profiler, callbacks=[TestCallback()])
trainer.fit(model)

assert profiler._output_file is None
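Finally, a hedged sketch (not part of the commit) of inspecting the artifacts the tests above assert on; `output_dir` stands in for whatever `dirpath` the profiler was given, after a completed `trainer.fit(...)` run with `PyTorchProfiler(filename="profile")`.

```python
import glob
import os

output_dir = "."  # wherever PyTorchProfiler(dirpath=...) wrote its files

# plain-text summary table written by the profiler at the end of fit
with open(os.path.join(output_dir, "fit-profile.txt")) as f:
    print(f.read())

# with kineto available, chrome-trace JSON files are also exported; they can be
# opened in chrome://tracing or in TensorBoard's profiler plugin
print(glob.glob(os.path.join(output_dir, "*.json")))
```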
