diff --git a/CHANGELOG.md b/CHANGELOG.md
index 94971f3cde2ff..acb620a32d0b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
 
+- Fixed PyTorch Profiler not enabled for manual optimization ([#9316](https://github.com/PyTorchLightning/pytorch-lightning/pull/9316))
+
+
 - Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))
 
 ## [1.4.5] - 2021-08-31
diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py
index 99346a91c823c..ab12fb623ecc3 100644
--- a/pl_examples/basic_examples/profiler_example.py
+++ b/pl_examples/basic_examples/profiler_example.py
@@ -31,6 +31,7 @@
 
 from pl_examples import _DATASETS_PATH, cli_lightning_logo
 from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.profiler.pytorch import PyTorchProfiler
 from pytorch_lightning.utilities.cli import LightningCLI
 
 DEFAULT_CMD_LINE = (
@@ -43,18 +44,34 @@
 
 
 class ModelToProfile(LightningModule):
-    def __init__(self, name: str = "resnet50"):
+    def __init__(self, name: str = "resnet18", automatic_optimization: bool = True):
         super().__init__()
         self.model = getattr(models, name)(pretrained=True)
         self.criterion = torch.nn.CrossEntropyLoss()
+        self.automatic_optimization = automatic_optimization
+        self.training_step = (
+            self.automatic_optimization_training_step
+            if automatic_optimization
+            else self.manual_optimization_training_step
+        )
 
-    def training_step(self, batch, batch_idx):
+    def automatic_optimization_training_step(self, batch, batch_idx):
         inputs, labels = batch
         outputs = self.model(inputs)
         loss = self.criterion(outputs, labels)
         self.log("train_loss", loss)
         return loss
 
+    def manual_optimization_training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+        inputs, labels = batch
+        outputs = self.model(inputs)
+        loss = self.criterion(outputs, labels)
+        self.log("train_loss", loss)
+        self.manual_backward(loss)
+        opt.step()
+
     def validation_step(self, batch, batch_idx):
         inputs, labels = batch
         outputs = self.model(inputs)
@@ -77,18 +94,20 @@ def train_dataloader(self, *args, **kwargs):
         trainset = torchvision.datasets.CIFAR10(
             root=_DATASETS_PATH, train=True, download=True, transform=self.transform
         )
-        return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)
+        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)
 
     def val_dataloader(self, *args, **kwargs):
         valset = torchvision.datasets.CIFAR10(root=_DATASETS_PATH, train=False, download=True, transform=self.transform)
-        return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0)
+        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=True, num_workers=0)
 
 
 def cli_main():
     if len(sys.argv) == 1:
         sys.argv += DEFAULT_CMD_LINE
 
-    LightningCLI(ModelToProfile, CIFAR10DataModule)
+    LightningCLI(
+        ModelToProfile, CIFAR10DataModule, save_config_overwrite=True, trainer_defaults={"profiler": PyTorchProfiler()}
+    )
 
 
 if __name__ == "__main__":
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index 963b800795236..ddf9ebd1644fb 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -15,7 +15,7 @@
 import inspect
 import logging
 import os
-from functools import partial
+from functools import lru_cache, partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
 
@@ -24,9 +24,10 @@
 from torch.autograd.profiler import record_function
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
+from pytorch_lightning.utilities.warnings import WarningCache
 
 if TYPE_CHECKING:
     from torch.autograd.profiler import EventList
@@ -38,6 +39,7 @@
     from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
 
 log = logging.getLogger(__name__)
+warning_cache = WarningCache()
 
 _PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
 
@@ -116,6 +118,7 @@ def pre_step(self, current_action: str) -> None:
         self._current_action = current_action
 
     def reset(self):
+        # handle `fast_dev_run` properly; the PyTorch Profiler would fail otherwise.
         self._num_optimizer_step_and_closure = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -128,9 +131,15 @@ def reset(self):
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None
 
+    @property
+    def is_training(self) -> bool:
+        return self._current_action is not None and (
+            self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
+        )
+
     @property
     def num_step(self) -> int:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._num_optimizer_step_and_closure
         if self._current_action == "validation_step":
             return self._num_validation_step
@@ -141,7 +150,7 @@ def num_step(self) -> int:
         return 0
 
     def _step(self) -> None:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             self._num_optimizer_step_and_closure += 1
         elif self._current_action == "validation_step":
             if self._start_action_name == "on_fit_start":
@@ -156,7 +165,7 @@ def _step(self) -> None:
 
     @property
     def has_finished(self) -> bool:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._optimizer_step_and_closure_reached_end
         if self._current_action == "validation_step":
             return self._validation_step_reached_end
@@ -172,9 +181,9 @@ def __call__(self, num_step: int) -> "ProfilerAction":
             return ProfilerAction.NONE
 
         self._step()
-        action = self._schedule(self.num_step)
+        action = self._schedule(max(self.num_step, 0))
         if action == ProfilerAction.RECORD_AND_SAVE:
-            if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+            if self.is_training:
                 self._optimizer_step_and_closure_reached_end = True
             elif self._current_action == "validation_step":
                 self._validation_step_reached_end = True
@@ -196,7 +205,7 @@ class PyTorchProfiler(BaseProfiler):
         "predict_step",
     }
     RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
-    STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
+    STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
     STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
"optimizer_step_and_closure_" AVAILABLE_SORT_KEYS = { "cpu_time", @@ -320,6 +329,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None: raise MisconfigurationException( f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}" ) + self._default_schedule() schedule = schedule if has_schedule else self._default_schedule() self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule self._profiler_kwargs["schedule"] = self._schedule @@ -331,28 +341,13 @@ def _init_kineto(self, profiler_kwargs: Any) -> None: with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph self._profiler_kwargs["with_stack"] = with_stack - def __deprecation_check( - self, profiled_functions: Optional[List[str]], record_functions: Optional[Set[str]] - ) -> Set[str]: - if record_functions is None: - record_functions = set() - - if profiled_functions is not None: - rank_zero_deprecation( - "`PyTorchProfiler.profiled_functions` has been renamed to" - " `record_functions` in v1.3 and will be removed in v1.5" - ) - if not record_functions: - record_functions |= set(profiled_functions) - else: - raise MisconfigurationException( - "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." - " Please use only the later." - ) - - return record_functions + def _should_override_schedule(self) -> bool: + return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and ( + self._schedule is not None and self._schedule._schedule == self._default_schedule() + ) @staticmethod + @lru_cache(1) def _default_schedule() -> Optional[callable]: if _KINETO_AVAILABLE: # Those schedule defaults allow the profiling overhead to be negligible over training time. @@ -393,11 +388,18 @@ def start(self, action_name: str) -> None: if self._register is not None: self._register.__enter__() + if self._lightning_module is not None: + # when the model is used in automatic optimization, + # we use `optimizer_step_and_closure` to step the model. + if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS: + self.STEP_FUNCTIONS.remove("training_step") + if ( self.profiler is not None and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX)) and action_name not in self._recording_map ): + recording = record_function(action_name) recording.__enter__() self._recording_map[action_name] = recording @@ -413,6 +415,17 @@ def stop(self, action_name: str) -> None: if self.profiler is not None and ( action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX) ): + + # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`. + # otherwise, this will raise a `segmentation fault`. + if self._should_override_schedule(): + warning_cache.warn( + "The PyTorch Profiler default schedule will be overridden as there is not enough " + "steps to properly record traces." 
+                )
+                self._schedule = None
+                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn
+
         if self._schedule is not None:
             self._schedule.pre_step(action_name)
 
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
index e6fa5cfa70795..13b2e588d6b8d 100644
--- a/tests/helpers/__init__.py
+++ b/tests/helpers/__init__.py
@@ -1,2 +1,7 @@
-from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset  # noqa: F401
+from tests.helpers.boring_model import (  # noqa: F401
+    BoringDataModule,
+    BoringModel,
+    ManualOptimBoringModel,
+    RandomDataset,
+)
 from tests.helpers.datasets import TrialMNIST  # noqa: F401
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
index d20cb1287e326..8b0f6e44d06c7 100644
--- a/tests/helpers/boring_model.py
+++ b/tests/helpers/boring_model.py
@@ -184,3 +184,18 @@ def test_dataloader(self):
 
     def predict_dataloader(self):
         return DataLoader(self.random_predict)
+
+
+class ManualOptimBoringModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        output = self(batch)
+        loss = self.loss(batch, output)
+        opt.zero_grad()
+        self.manual_backward(loss)
+        opt.step()
+        return loss
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 2145ab83e9cdb..17d8e6be4eab7 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -29,7 +29,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, ManualOptimBoringModel
 from tests.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
@@ -309,50 +309,48 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
     assert any(f"{local_rank}-validation_step" in f for f in files)
 
 
-def test_pytorch_profiler_trainer_test(tmpdir):
+@pytest.mark.parametrize("fast_dev_run", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("boring_model_cls", [ManualOptimBoringModel, BoringModel])
+def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
     """Ensure that the profiler can be given to the trainer and test step are properly recorded."""
-    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
-    trainer.test(model)
+    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile")
+    model = boring_model_cls()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
+    trainer.fit(model)
 
-    assert sum(e.name == "test_step" for e in pytorch_profiler.function_events)
+    assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
 
-    path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
+    path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
     assert path.read_text("utf-8")
 
     if _KINETO_AVAILABLE:
         files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
-        assert any(f"test-{pytorch_profiler.filename}" in f for f in files)
-        path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
+        assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
+        path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
         assert path.read_text("utf-8")
 
 
-def test_pytorch_profiler_trainer_predict(tmpdir):
-    """Ensure that the profiler can be given to the trainer and predict function are properly recorded."""
+@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
+@pytest.mark.parametrize("boring_model_cls", [BoringModel, ManualOptimBoringModel])
+def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
+    """Ensure that the profiler can be given to the trainer and the step functions are properly recorded."""
     pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
+    model = boring_model_cls()
     model.predict_dataloader = model.train_dataloader
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_predict_batches=2, profiler=pytorch_profiler)
-    trainer.predict(model)
-
-    assert sum(e.name == "predict_step" for e in pytorch_profiler.function_events)
-    path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt"
-    assert path.read_text("utf-8")
-
-
-def test_pytorch_profiler_trainer_validate(tmpdir):
-    """Ensure that the profiler can be given to the trainer and validate function are properly recorded."""
-    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=2, profiler=pytorch_profiler)
-    trainer.validate(model)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
+    getattr(trainer, fn)(model)
 
-    assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
+    assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)
 
-    path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
+    path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
     assert path.read_text("utf-8")
+
+    if _KINETO_AVAILABLE:
+        files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
+        assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
+        path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
+        assert path.read_text("utf-8")
 
 
 def test_pytorch_profiler_nested(tmpdir):
     """Ensure that the profiler handles nested context"""
@@ -467,7 +465,7 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None:
 
     profiler = cls(dirpath=tmpdir, filename="profiler")
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1, profiler=profiler, callbacks=[TestCallback()])
     trainer.fit(model)
 
     assert profiler._output_file is None
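
For reference, a minimal usage sketch of what this patch enables: profiling a `LightningModule` that runs manual optimization with the `PyTorchProfiler`. The `RandomDataset` and `ManualOptimModel` classes below are illustrative stand-ins (assumptions, not code from this PR); only the `automatic_optimization = False` / `manual_backward` pattern and the `Trainer(profiler=PyTorchProfiler(...))` wiring mirror the changes above.

```python
# Hypothetical stand-alone sketch (not part of this patch): profile a manual-optimization
# LightningModule with the PyTorchProfiler, mirroring ManualOptimBoringModel above.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.profiler.pytorch import PyTorchProfiler


class RandomDataset(Dataset):
    """Tiny random dataset, just enough to drive a few profiled training steps."""

    def __init__(self, size: int = 64, length: int = 64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ManualOptimModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(64, 2)
        # opt out of automatic optimization so `training_step` drives the profiler schedule
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    profiler = PyTorchProfiler(dirpath=".", filename="profile")
    # at least 5 training batches keeps the default schedule (wait=1, warmup=1, active=3) in place
    trainer = Trainer(max_epochs=1, limit_train_batches=8, profiler=profiler)
    trainer.fit(ManualOptimModel(), DataLoader(RandomDataset(), batch_size=8))
```

Keeping `limit_train_batches` at 5 or more means the `_should_override_schedule` check introduced in this patch does not need to replace the default schedule; with fewer batches, the profiler falls back to the unscheduled default and emits the warning added above.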