From 55e1ddc6e43908eec76ea8eb1b110cca8b72a2cb Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Mon, 9 Oct 2023 07:08:35 +0000 Subject: [PATCH] Remove `sharded_ddp` --- optimum/habana/transformers/trainer.py | 20 ++----------------- .../habana/transformers/trainer_seq2seq.py | 1 - optimum/habana/transformers/training_args.py | 3 --- 3 files changed, 2 insertions(+), 22 deletions(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index fb153a3ea5..739f53c6f4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -634,7 +634,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # as the model is wrapped, don't use `accelerator.prepare` # this is for unhandled cases such as - # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False # prepare using `accelerator` prepare @@ -1179,22 +1179,16 @@ def _save_checkpoint(self, model, trial, metrics=None): # This block is exectuted by the main process only optim_dict = self.optimizer.state_dict() scheduler_dict = self.lr_scheduler.state_dict() - if self.do_grad_scaling: - scaler_dict = self.scaler.state_dict() if self.args.use_habana: # Move the state dict from HPU to CPU before saving optim_dict = to_device_dtype(optim_dict, target_device=torch.device("cpu")) scheduler_dict = to_device_dtype(scheduler_dict, target_device=torch.device("cpu")) - if self.do_grad_scaling: - scaler_dict = to_device_dtype(scaler_dict, target_device=torch.device("cpu")) torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME)) # Save SCHEDULER & SCALER with warnings.catch_warnings(record=True) as caught_warnings: torch.save(scheduler_dict, os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) - if self.do_grad_scaling: - torch.save(scaler_dict, os.path.join(output_dir, SCALER_NAME)) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: @@ -1279,16 +1273,9 @@ def _load_optimizer_and_scheduler(self, checkpoint): ) reissue_pt_warnings(caught_warnings) - if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): - self.scaler.load_state_dict( - torch.load(os.path.join(checkpoint, SCALER_NAME), map_location=map_location) - ) - # Move optimizer state to HPU if self.args.use_habana: to_device_dtype(self.optimizer.state.values(), target_device=torch.device("hpu")) - if self.do_grad_scaling: - to_device_dtype(self.scaler.state.values(), target_device=torch.device("hpu")) def log(self, logs: Dict[str, float]) -> None: """ @@ -1374,10 +1361,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te if self.args.pipelining_fwd_bwd: self.htcore.mark_step() - if self.do_grad_scaling: - self.scaler.scale(loss).backward() - else: - self.accelerator.backward(loss) + self.accelerator.backward(loss) return loss.detach() / self.args.gradient_accumulation_steps diff --git a/optimum/habana/transformers/trainer_seq2seq.py b/optimum/habana/transformers/trainer_seq2seq.py index 14155ce099..fb6a231d79 100644 --- a/optimum/habana/transformers/trainer_seq2seq.py +++ b/optimum/habana/transformers/trainer_seq2seq.py @@ -256,7 +256,6 @@ def prediction_step( has_labels = "labels" in inputs inputs = self._prepare_inputs(inputs) - # XXX: adapt synced_gpus for fairscale as well # Priority (handled in generate): # non-`None` gen_kwargs > model.generation_config > default GenerationConfig() if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"): diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 1f647f0d74..64f93c86d8 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -60,7 +60,6 @@ "fp16_opt_level", "fsdp", "mp_parameters", - "sharded_ddp", "tf32", "tpu_metrics_debug", "tpu_num_cores", @@ -304,8 +303,6 @@ def __post_init__(self): raise ValueError("TPUs are not supported by optimum-habana.") if self.mp_parameters: raise ValueError("--mp_parameters is not supported by optimum-habana.") - if self.sharded_ddp: - raise ValueError("--sharded_ddp is not supported by optimum-habana.") if self.tf32: raise ValueError("--tf32 is not supported by optimum-habana.")