Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,65 @@ def _maybe_log_save_evaluate(
timer.step()
self.log_evaluate_save_time += timer.last_duration

def _save_checkpoint(self, model, trial):
    """Write a full training checkpoint for the current global step.

    The checkpoint goes to ``<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``,
    where ``run_dir`` comes from ``self._get_output_dir(trial=trial)``. In order,
    this method saves the model, synchronizes ranks in distributed mode,
    refreshes ``self.state.best_model_checkpoint``, saves optimizer/scheduler/
    scaler/RNG state (unless ``save_only_model`` is set), dumps the trainer
    state to JSON, optionally pushes to the Hub and finally rotates old
    checkpoints.

    Adapted from
    https://github.com/huggingface/transformers/blob/v4.51-release/src/transformers/trainer.py#L3187

    Args:
        model: Not referenced in the body; kept for signature compatibility
            with the upstream method.
        trial: Hyper-parameter search trial (or ``None``), forwarded to
            ``self._get_output_dir``.
    """
    # Name of the folder holding this step's checkpoint.
    step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    # FLOs are only stored here when no hyper-parameter search is in progress.
    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, step_dir_name)
    self.save_model(output_dir, _internal_call=True)

    # NOTE(pbielak): with several cards only the main process (rank zero) writes the
    # model while the other ranks keep going. Without this barrier they could reach the
    # `os.path.exists` check on the best checkpoint below before the files are on disk.
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        torch.distributed.barrier()

    # Point the state at the best checkpoint folder, but only if it really exists.
    tracks_best = (
        self.args.save_strategy in (SaveStrategy.STEPS, SaveStrategy.EPOCH) and self.state.best_global_step
    )
    if tracks_best:
        best_dir = os.path.join(run_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}")
        if os.path.exists(best_dir):
            self.state.best_model_checkpoint = best_dir

    if not self.args.save_only_model:
        # Everything besides the weights: optimizer, LR scheduler, grad scaler, RNG.
        self._save_optimizer_and_scheduler(output_dir)
        self._save_scaler(output_dir)
        self._save_rng_state(output_dir)

    if self.args.should_save:
        # Snapshot every `ExportableState` callback (plus the `TrainerControl`) so the
        # serialized trainer state reflects where training currently stands.
        for candidate in self.callback_handler.callbacks + [self.control]:
            if not isinstance(candidate, ExportableState):
                continue
            key = candidate.__class__.__name__
            snapshot = candidate.state()
            if isinstance(self.state.stateful_callbacks[key], list):
                # A list entry accumulates one snapshot per callback of that class.
                self.state.stateful_callbacks[key].append(snapshot)
            else:
                self.state.stateful_callbacks[key] = snapshot
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    if self.args.should_save:
        # Rotate old checkpoints by their numeric step id only: mtime is not
        # reliable, especially on some fuse filesystems in cloud environments.
        self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
Expand Down
Loading