optimum/habana/transformers/trainer.py (20 changes: 2 additions & 18 deletions)
@@ -634,7 +634,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):

         # as the model is wrapped, don't use `accelerator.prepare`
         # this is for unhandled cases such as
-        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False

         # prepare using `accelerator` prepare
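The flag above only gates whether Accelerate wraps the model at all. A minimal sketch of that gating, with a hypothetical helper that is not code from the PR (the real method also prepares the optimizer and scheduler):

from accelerate import Accelerator

def maybe_prepare(accelerator: Accelerator, model, original_model):
    # Hypothetical helper, not code from the PR: prepare with Accelerate only when
    # `model` is still the original, unwrapped module (the `model is self.model`
    # test above); already-wrapped cases (SageMaker MP/DP, DataParallel, IPEX, ...)
    # skip the call.
    use_accelerator_prepare = model is original_model
    return accelerator.prepare(model) if use_accelerator_prepare else model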
@@ -1179,22 +1179,16 @@ def _save_checkpoint(self, model, trial, metrics=None):
             # This block is exectuted by the main process only
             optim_dict = self.optimizer.state_dict()
             scheduler_dict = self.lr_scheduler.state_dict()
-            if self.do_grad_scaling:
-                scaler_dict = self.scaler.state_dict()
             if self.args.use_habana:
                 # Move the state dict from HPU to CPU before saving
                 optim_dict = to_device_dtype(optim_dict, target_device=torch.device("cpu"))
                 scheduler_dict = to_device_dtype(scheduler_dict, target_device=torch.device("cpu"))
-                if self.do_grad_scaling:
-                    scaler_dict = to_device_dtype(scaler_dict, target_device=torch.device("cpu"))
             torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME))

             # Save SCHEDULER & SCALER
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(scheduler_dict, os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
-            if self.do_grad_scaling:
-                torch.save(scaler_dict, os.path.join(output_dir, SCALER_NAME))

             # Determine the new best metric / best model checkpoint
             if metrics is not None and self.args.metric_for_best_model is not None:
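To make the HPU-to-CPU round-trip in this hunk concrete, here is a simplified sketch of what moving a state dict to the CPU before torch.save amounts to; it is a stand-in for the optimum-habana to_device_dtype helper, not its actual implementation:

import torch

def to_cpu(obj):
    # Simplified stand-in for to_device_dtype(..., target_device=torch.device("cpu")):
    # walk nested dicts/lists/tuples and copy every tensor to the CPU so the
    # checkpoint can be written without keeping the tensors resident on the HPU.
    if torch.is_tensor(obj):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_cpu(v) for v in obj)
    return obj

# Usage sketch: torch.save(to_cpu(optimizer.state_dict()), "optimizer.pt")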
@@ -1279,16 +1273,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 )
             reissue_pt_warnings(caught_warnings)

-            if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    torch.load(os.path.join(checkpoint, SCALER_NAME), map_location=map_location)
-                )
-
             # Move optimizer state to HPU
             if self.args.use_habana:
                 to_device_dtype(self.optimizer.state.values(), target_device=torch.device("hpu"))
-                if self.do_grad_scaling:
-                    to_device_dtype(self.scaler.state.values(), target_device=torch.device("hpu"))

     def log(self, logs: Dict[str, float]) -> None:
         """
@@ -1374,10 +1361,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
         if self.args.pipelining_fwd_bwd:
             self.htcore.mark_step()

-        if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
-        else:
-            self.accelerator.backward(loss)
+        self.accelerator.backward(loss)

         return loss.detach() / self.args.gradient_accumulation_steps

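The replaced branch is the core simplification in this file: the explicit GradScaler path can go because Accelerator.backward itself decides whether to scale the loss (fp16), hand it to DeepSpeed's engine, or call a plain .backward(). A minimal sketch under that assumption; the mixed_precision="bf16" choice and the backward_step helper are illustrative, not taken from the diff:

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")  # illustrative configuration

def backward_step(loss, gradient_accumulation_steps: int = 1):
    # Mirrors the new tail of training_step: Accelerate picks the right backward
    # path (scaled fp16, DeepSpeed, or plain autograd) so the trainer no longer
    # needs its own do_grad_scaling branch.
    accelerator.backward(loss)
    return loss.detach() / gradient_accumulation_steps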
optimum/habana/transformers/trainer_seq2seq.py (1 change: 0 additions & 1 deletion)
@@ -256,7 +256,6 @@ def prediction_step(
         has_labels = "labels" in inputs
         inputs = self._prepare_inputs(inputs)

-        # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
         # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
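The priority spelled out in the comment above can be pictured as a small resolver. The helper below is purely illustrative and not part of the diff; the real resolution happens inside generate:

from transformers import GenerationConfig

def resolve_generation_config(gen_kwargs: dict, model):
    # Hypothetical illustration of the stated priority:
    # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
    if gen_kwargs:
        return gen_kwargs
    if getattr(model, "generation_config", None) is not None:
        return model.generation_config
    return GenerationConfig()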
optimum/habana/transformers/training_args.py (3 changes: 0 additions & 3 deletions)
@@ -60,7 +60,6 @@
     "fp16_opt_level",
     "fsdp",
     "mp_parameters",
-    "sharded_ddp",
     "tf32",
     "tpu_metrics_debug",
     "tpu_num_cores",
@@ -304,8 +303,6 @@ def __post_init__(self):
             raise ValueError("TPUs are not supported by optimum-habana.")
         if self.mp_parameters:
             raise ValueError("--mp_parameters is not supported by optimum-habana.")
-        if self.sharded_ddp:
-            raise ValueError("--sharded_ddp is not supported by optimum-habana.")
         if self.tf32:
             raise ValueError("--tf32 is not supported by optimum-habana.")
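For context, the deleted lines followed the same guard pattern as the checks that remain: __post_init__ validation on the dataclass fields. A stripped-down sketch of that pattern with a hypothetical field name, not the actual GaudiTrainingArguments class:

from dataclasses import dataclass

@dataclass
class ExampleTrainingArguments:
    # Hypothetical flag standing in for the removed `sharded_ddp` field.
    some_unsupported_flag: bool = False

    def __post_init__(self):
        # Same guard pattern as the remaining checks above: reject options that
        # have no meaning on this backend as early as possible.
        if self.some_unsupported_flag:
            raise ValueError("--some_unsupported_flag is not supported by optimum-habana.")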
