From 810b4450979b8d0e3256fe385f8e3e5eeb1a24ab Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 15 Sep 2020 06:02:42 -0400 Subject: [PATCH] ref: apex plugin (#3502) * ref: apex plugin * ref: apex plugin * ref: apex plugin --- docs/source/index.rst | 1 + .../accelerators/ddp2_backend.py | 7 +--- .../accelerators/ddp_base_backend.py | 12 ++---- pytorch_lightning/accelerators/dp_backend.py | 6 ++- pytorch_lightning/accelerators/gpu_backend.py | 17 ++------- .../accelerators/horovod_backend.py | 13 ++----- pytorch_lightning/plugins/__init__.py | 0 pytorch_lightning/plugins/apex.py | 38 +++++++++++++++++++ 8 files changed, 57 insertions(+), 37 deletions(-) create mode 100644 pytorch_lightning/plugins/__init__.py create mode 100644 pytorch_lightning/plugins/apex.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 90a69a447c169..caaa43de5354d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -141,3 +141,4 @@ Indices and tables api/pytorch_lightning.trainer api/pytorch_lightning.utilities api/pytorch_lightning.tuner + api/pytorch_lightning.plugins diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 291fbeb2b2e9b..7414535350f64 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -22,6 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result from pytorch_lightning.accelerators.ddp_base_backend import DDPBase +from pytorch_lightning.plugins.apex import ApexPlugin try: from hydra.utils import to_absolute_path, get_original_cwd @@ -31,17 +32,13 @@ else: HYDRA_AVAILABLE = True -try: - from apex import amp -except ImportError: - amp = None - class DDP2Backend(DDPBase): def __init__(self, trainer): super().__init__(trainer) self.task_idx = None + self.precision_backend = None def setup(self, model): self._resolve_task_idx() diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py index ff51fbe786dfb..10e5cab3dcbb9 100644 --- a/pytorch_lightning/accelerators/ddp_base_backend.py +++ b/pytorch_lightning/accelerators/ddp_base_backend.py @@ -22,6 +22,7 @@ from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only from pytorch_lightning import _logger as log +from pytorch_lightning.plugins.apex import ApexPlugin try: from hydra.utils import to_absolute_path, get_original_cwd @@ -31,16 +32,12 @@ else: HYDRA_AVAILABLE = True -try: - from apex import amp -except ImportError: - amp = None - class DDPBase(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.precision_backend = None def training_step(self, args): if self.trainer.amp_backend == AMPType.NATIVE: @@ -155,9 +152,8 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs # AMP - # run through amp wrapper before going to distributed DP if self.trainer.amp_backend == AMPType.APEX: - model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) - self.trainer.optimizers = optimizers - self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) + self.precision_backend = ApexPlugin(self.trainer) + model, optimizers = self.precision_backend._init(model) # device ids change depending on the DDP setup device_ids = self.get_device_ids() diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py index ed7573e418a92..09a373c3d4ef0 100644 --- a/pytorch_lightning/accelerators/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -20,6 +20,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result from pytorch_lightning.accelerators.base_backend import Accelerator +from pytorch_lightning.plugins.apex import ApexPlugin try: from apex import amp @@ -32,6 +33,7 @@ class DataParallelBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) self.model_autocast_original_forward = None + self.precision_backend = None def setup(self, model): # call setup after the ddp process has connected @@ -89,8 +91,8 @@ def __init_nvidia_apex(self, model): f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.' f' We recommend you switch to ddp if you want to use amp') else: - model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) - self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers) + self.precision_backend = ApexPlugin(self.trainer) + model, optimizers = self.precision_backend._init(model) return model diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index eeff8529c1b77..ec3e84e840b02 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -13,14 +13,9 @@ # limitations under the License. import torch -from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.accelerators.base_backend import Accelerator - -try: - from apex import amp -except ImportError: - amp = None +from pytorch_lightning.plugins.apex import ApexPlugin class GPUBackend(Accelerator): @@ -28,6 +23,7 @@ class GPUBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.precision_backend = None def setup(self, model): @@ -45,7 +41,8 @@ def setup(self, model): self.trainer.optimizer_frequencies = optimizer_frequencies if self.trainer.amp_backend == AMPType.APEX: - model = self._setup_nvidia_apex(model) + self.precision_backend = ApexPlugin(self.trainer) + model, optimizers = self.precision_backend._init(model) self.trainer.model = model @@ -117,9 +114,3 @@ def to_device(self, batch): # be referenced from and if there are multiple optimizers the batch will # wind up copying it to the same device repeatedly. return self.batch_to_device(batch, gpu_id) - - def _setup_nvidia_apex(self, model: LightningModule): - model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) - self.trainer.optimizers = optimizers - self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) - return model diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py index 6428ff86dc222..76887a6f87214 100644 --- a/pytorch_lightning/accelerators/horovod_backend.py +++ b/pytorch_lightning/accelerators/horovod_backend.py @@ -18,12 +18,7 @@ from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.utilities.distributed import rank_zero_only from torch.optim.lr_scheduler import _LRScheduler - -try: - from apex import amp -except ImportError: - amp = None - +from pytorch_lightning.plugins.apex import ApexPlugin try: import horovod.torch as hvd @@ -38,6 +33,7 @@ class HorovodBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.precision_backend = None def setup(self, model): # call setup after the ddp process has connected @@ -88,9 +84,8 @@ def filter_named_parameters(model, optimizer): ] if self.trainer.amp_backend == AMPType.APEX: - model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) - self.trainer.optimizers = optimizers - self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) + self.precision_backend = ApexPlugin(self.trainer) + model, optimizers = self.precision_backend._init(model) # Update logger rank info from Horovod to avoid race conditions from different ranks # creating directories / writing files in the same locations. diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py new file mode 100644 index 0000000000000..27ccff5592520 --- /dev/null +++ b/pytorch_lightning/plugins/apex.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from apex import amp +except ImportError: + amp = None + + +class ApexPlugin: + + def __init__(self, trainer): + self.trainer = trainer + + def _init(self, model): + model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level) + self.trainer.optimizers = optimizers + self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) + return model, optimizers + + def configure_apex(self, model, optimizers, amp_level): + model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) + return model, optimizers + + def training_step(self, fx, args): + output = fx(args) + return output