From 06acca3456a0a936dee258e492be0eb98114dd69 Mon Sep 17 00:00:00 2001
From: Ruotian Luo
Date: Sun, 12 Jul 2020 15:06:27 -0700
Subject: [PATCH] save apex scaler states

---
 pytorch_lightning/trainer/training_io.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 90f36ecab59c3..666dfbb2588b7 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -115,6 +115,13 @@
 else:
     XLA_AVAILABLE = True
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):
@@ -317,6 +324,8 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
@@ -368,6 +377,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         # save native amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE:
+            checkpoint['amp_scaling_state'] = amp.state_dict()
 
         # add the module_arguments and state_dict from the model
         model = self.get_model()
@@ -523,6 +534,8 @@ def hpc_load(self, folderpath, on_gpu):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
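
For context, a minimal sketch of how the apex loss-scaler state round-trips through a checkpoint, in the same spirit as the patch above. It assumes apex is installed and that the model/optimizer pair has already been wrapped with amp.initialize; the function names, variable names, and checkpoint keys other than 'amp_scaling_state' are illustrative, not part of the Lightning API.

    import torch
    from apex import amp

    # assume `model` and `optimizer` were wrapped earlier, e.g.:
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    def save_checkpoint(model, optimizer, path):
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer_states': [optimizer.state_dict()],
            # apex keeps the dynamic loss-scaler state here, analogous to
            # torch.cuda.amp.GradScaler.state_dict() for native AMP
            'amp_scaling_state': amp.state_dict(),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(model, optimizer, path):
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_states'][0])
        if 'amp_scaling_state' in checkpoint:
            # restoring the loss scale lets training resume without
            # re-warming the dynamic scaler from its default value
            amp.load_state_dict(checkpoint['amp_scaling_state'])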