refactor: move [flash checkpoint manager] to callback
Meiyim committed Jan 25, 2025
1 parent 94246f6 commit 0297880
Showing 3 changed files with 157 additions and 104 deletions.
paddlenlp/trainer/trainer.py: 90 changes (4 additions, 86 deletions)
@@ -18,7 +18,6 @@
 
 import collections
 import contextlib
-import copy
 import inspect
 import json
 import math
@@ -92,7 +91,6 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
-    get_parameter_dtype,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -405,12 +403,7 @@ def __init__(
             assert not self.args.save_rng_states, "save_rng_states is not supported when using flash save mode"
 
             # init attributes for flash save mode
-            self.manipulated_state_dict = None
-            self.manipulated_config_to_save = None
-            self.manipulated_weight_suffix = None
-            self.model_meta = None
             self.flash_checkpoint_manager = None
-            self.user_file_list = []
 
         if self.args.ordered_save_group_size > 0:
             logger.info(f"using save in order, its group size is {self.args.ordered_save_group_size}")
@@ -730,67 +723,17 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoint):
             capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
             ema_coef=self.args.flash_save_ema_coef,
         )
-        _callback = FlashCheckpointCallback(self.flash_checkpoint_manager)
-        self.add_callback(_callback)
+        _callback = FlashCheckpointCallback(
+            self.args, self.flash_checkpoint_manager, self.runtime_timer, self.sharding_io
+        )
+        self.add_callback(_callback)
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
             logger.info(f"FC EMA load from {path}")
             self.flash_checkpoint_manager.set_ema_state_dict(path)
         logger.info("Create flash checkpoint manager done.")
 
-    def maybe_update_flash_checkpoint_worker(self):
-        if self.optimizer.fused_buffer_version == self.flash_checkpoint_manager.cache_version:
-            return
-
-        logger.info("Flash checkpoint workers need upgrade.")
-        self._cache_meta_for_sharded_save()
-        param_mappings, ipc_meta_mappings = get_fused_param_mappings(self.optimizer, self.manipulated_state_dict)
-        optimizer_states_meta = (
-            self.optimizer.fused_states_accumulators_meta,
-            self.optimizer.fused_states_master_weights_meta,
-            None,
-            self.optimizer.fused_states_buffer_ipc_meta,
-        )
-        model_states_meta = (param_mappings, ipc_meta_mappings)
-        optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
-        model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix)
-
-        dynamic_objecs = {}
-        dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta
-        dynamic_objecs["model_states_meta"] = model_states_meta
-        dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path
-        dynamic_objecs["model_states_name_path"] = model_states_name_path
-
-        static_objects = {}
-        static_objects["model_config"] = self.manipulated_config_to_save
-        static_objects["training_args"] = self.args
-        static_objects["model_meta"] = self.model_meta
-        static_objects["user_file"] = self.user_file_list
-
-        self.flash_checkpoint_manager.update_flash_workers(
-            self.optimizer.fused_buffer_version, dynamic_objecs, static_objects
-        )
-
-    def _cache_meta_for_sharded_save(self):
-        logger.info("Start caching metas for sharded save...")
-        (
-            self.manipulated_state_dict,
-            self.manipulated_config_to_save,
-            self.manipulated_weight_suffix,
-        ) = self.sharding_io.manipulate_state_dict_and_config(self.model, merge_tensor_parallel=False)
-        logger.info("Cache manipulated static dict done.")
-        if self.manipulated_config_to_save is None:
-            model_to_save = unwrap_model(self.model)
-            dtype = get_parameter_dtype(model_to_save)
-            model_to_save.config.dtype = str(dtype).split(".")[1]
-            self.manipulated_config_to_save = copy.deepcopy(model_to_save.config)
-            self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__]
-            self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True)
-            logger.info("Cache manipulated model config done")
-        self.model_meta = self.sharding_io.gather_distributed_model_meta()
-        logger.info("Cache distributed model meta done.")
-
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -1230,10 +1173,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 self.callback_handler.on_optimizer_begin(
                     args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                 )
-                if self.args.enable_flash_save_mode and self.flash_checkpoint_manager.current_worker is not None:
-                    logger.info("Start syncing flash checkpoints")
-                    self.flash_checkpoint_manager.sync_offload_status()
-                    logger.info("Synced flash checkpoints.")
                 optimizer_was_run = True
 
                 if self.args.offload_optim:
@@ -2462,28 +2401,7 @@ def _ordered_save(self, state_dict, save_path):
                 paddle.save(state_dict, save_path)
             dist.barrier(mp_group)
 
-    def _get_save_infos_based_on_steps(self, checkpoint_folder):
-        flash_checkpoint_dir = None
-        persistent_checkpoint_dir = None
-        if self.args.flash_save_steps > 0 and self.state.global_step % self.args.flash_save_steps == 0:
-            flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder)
-        if self.args.save_steps > 0 and self.state.global_step % self.args.save_steps == 0:
-            persistent_checkpoint_dir = os.path.join(self.args.output_dir, checkpoint_folder)
-        return (flash_checkpoint_dir, persistent_checkpoint_dir)
-
-    def _save_checkpoint_flash(self):
-        self.runtime_timer.start("checkpoint saving time")
-        self.maybe_update_flash_checkpoint_worker()
-        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
-        save_infos = self._get_save_infos_based_on_steps(checkpoint_folder)
-        non_cached_objects = (self.lr_scheduler.state_dict(), self.state)
-        self.flash_checkpoint_manager.get_idle_worker_for_saving(save_infos, non_cached_objects)
-        self.runtime_timer.stop()
-
     def _save_checkpoint(self, model, metrics=None):
-        if self.args.enable_flash_save_mode:
-            self._save_checkpoint_flash()
-            return
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
 
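For reference, below is a minimal sketch of how the responsibilities removed from Trainer above map onto callback hooks. This is not the FlashCheckpointCallback added by this commit (that file is not loaded on this page). The constructor arguments, the hook choice (on_optimizer_begin and on_save), and the manager calls are taken from the new registration in create_flash_checkpoint_manager and from the deleted Trainer code; the class name, the "/dev/shm" stand-in for FLASH_DEVICE, and the lr_scheduler plumbing through kwargs are assumptions for illustration only.

import os

from paddlenlp.trainer.trainer_callback import TrainerCallback
from paddlenlp.trainer.trainer_utils import PREFIX_CHECKPOINT_DIR
from paddlenlp.utils.log import logger


class FlashCheckpointCallbackSketch(TrainerCallback):
    """Illustrative only; the real FlashCheckpointCallback ships with this commit."""

    def __init__(self, args, flash_checkpoint_manager, runtime_timer, sharding_io):
        self.args = args
        self.manager = flash_checkpoint_manager
        self.timer = runtime_timer
        self.sharding_io = sharding_io

    def on_optimizer_begin(self, args, state, control, **kwargs):
        # Plays the role of the block removed from the training loop: wait until the
        # previous flash worker has finished offloading before the optimizer step runs.
        if self.manager.current_worker is not None:
            logger.info("Start syncing flash checkpoints")
            self.manager.sync_offload_status()
            logger.info("Synced flash checkpoints.")

    def on_save(self, args, state, control, **kwargs):
        # Plays the role of the removed _save_checkpoint_flash / _get_save_infos_based_on_steps:
        # pick the target directories for this step and hand the save to an idle flash worker.
        folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        flash_dir = None
        persistent_dir = None
        if args.flash_save_steps > 0 and state.global_step % args.flash_save_steps == 0:
            flash_dir = os.path.join("/dev/shm", folder)  # stand-in for FLASH_DEVICE
        if args.save_steps > 0 and state.global_step % args.save_steps == 0:
            persistent_dir = os.path.join(args.output_dir, folder)
        lr_scheduler = kwargs.get("lr_scheduler")  # assumed to be forwarded by the callback handler
        non_cached_objects = (lr_scheduler.state_dict() if lr_scheduler else None, state)
        self.manager.get_idle_worker_for_saving((flash_dir, persistent_dir), non_cached_objects)
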
paddlenlp/trainer/training_args.py: 4 changes (4 additions, 0 deletions)
@@ -880,6 +880,10 @@ class TrainingArguments:
         default=0,
         metadata={"help": "The coefficient of EMA parameters in flash save mode. if set to 0, skip EMA process"},
     )
+    flash_ema_interval: Optional[int] = field(
+        default=1,
+        metadata={"help": "Interval between updating EMA parameters."},
+    )
     save_tokenizer: Optional[bool] = field(
         default=True,
         metadata={"help": "Save tokenizer to output_dir."},
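For context, the new flash_ema_interval works alongside the existing flash_save_ema_coef: the help texts suggest an exponential moving average of the parameters that is refreshed only every flash_ema_interval optimizer steps. The sketch below shows that interval-gated EMA rule; the actual update lives in the flash checkpoint manager, which is not part of the diff shown on this page, and the function and argument names here are illustrative only.

def maybe_update_ema(ema_state, params, global_step, ema_coef, ema_interval):
    """Standard EMA update, applied only every `ema_interval` optimizer steps.

    ema_coef == 0 skips EMA entirely (matching the flash_save_ema_coef help text);
    ema_interval == 1, the new default, updates on every step.
    """
    if ema_coef == 0 or global_step % ema_interval != 0:
        return ema_state
    for name, param in params.items():
        # ema <- coef * ema + (1 - coef) * param
        ema_state[name] = ema_coef * ema_state[name] + (1.0 - ema_coef) * param
    return ema_state


# Example with plain floats standing in for parameter tensors:
# maybe_update_ema({"w": 1.0}, {"w": 0.0}, global_step=4, ema_coef=0.9, ema_interval=2)
# -> {"w": 0.9}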
