diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index f42ea241a266..b80343339585 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -18,7 +18,6 @@
 import collections
 import contextlib
-import copy
 import inspect
 import json
 import math
@@ -92,7 +91,6 @@ from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
-    get_parameter_dtype,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -405,12 +403,7 @@ def __init__(
             assert not self.args.save_rng_states, "save_rng_states is not supported when using flash save mode"
             # init attributes for flash save mode
-            self.manipulated_state_dict = None
-            self.manipulated_config_to_save = None
-            self.manipulated_weight_suffix = None
-            self.model_meta = None
             self.flash_checkpoint_manager = None
-            self.user_file_list = []

         if self.args.ordered_save_group_size > 0:
             logger.info(f"using save in order, its group size is {self.args.ordered_save_group_size}")
@@ -730,8 +723,10 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin
             capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
             ema_coef=self.args.flash_save_ema_coef,
         )
-        _callback = FlashCheckpointCallback(self.flash_checkpoint_manager)
-        self.add_callback(_callback)
+        _callback = FlashCheckpointCallback(
+            self.args, self.flash_checkpoint_manager, self.runtime_timer, self.sharding_io
+        )
+        self.add_callback(_callback)
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
@@ -739,58 +734,6 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin
             self.flash_checkpoint_manager.set_ema_state_dict(path)
         logger.info("Create flash checkpoint manager done.")

-    def maybe_update_flash_checkpoint_worker(self):
-        if self.optimizer.fused_buffer_version == self.flash_checkpoint_manager.cache_version:
-            return
-
-        logger.info("Flash checkpoint workers need upgrade.")
-        self._cache_meta_for_sharded_save()
-        param_mappings, ipc_meta_mappings = get_fused_param_mappings(self.optimizer, self.manipulated_state_dict)
-        optimizer_states_meta = (
-            self.optimizer.fused_states_accumulators_meta,
-            self.optimizer.fused_states_master_weights_meta,
-            None,
-            self.optimizer.fused_states_buffer_ipc_meta,
-        )
-        model_states_meta = (param_mappings, ipc_meta_mappings)
-        optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
-        model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix)
-
-        dynamic_objecs = {}
-        dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta
-        dynamic_objecs["model_states_meta"] = model_states_meta
-        dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path
-        dynamic_objecs["model_states_name_path"] = model_states_name_path
-
-        static_objects = {}
-        static_objects["model_config"] = self.manipulated_config_to_save
-        static_objects["training_args"] = self.args
-        static_objects["model_meta"] = self.model_meta
-        static_objects["user_file"] = self.user_file_list
-
-        self.flash_checkpoint_manager.update_flash_workers(
-            self.optimizer.fused_buffer_version, dynamic_objecs, static_objects
-        )
-
-    def _cache_meta_for_sharded_save(self):
-        logger.info("Start caching metas for sharded save...")
-        (
-            self.manipulated_state_dict,
-            self.manipulated_config_to_save,
-            self.manipulated_weight_suffix,
-        ) = self.sharding_io.manipulate_state_dict_and_config(self.model, merge_tensor_parallel=False)
-        logger.info("Cache manipulated static dict done.")
-        if self.manipulated_config_to_save is None:
-            model_to_save = unwrap_model(self.model)
-            dtype = get_parameter_dtype(model_to_save)
-            model_to_save.config.dtype = str(dtype).split(".")[1]
-            self.manipulated_config_to_save = copy.deepcopy(model_to_save.config)
-            self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__]
-            self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True)
-            logger.info("Cache manipulated model config done")
-        self.model_meta = self.sharding_io.gather_distributed_model_meta()
-        logger.info("Cache distributed model meta done.")
-
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -1230,10 +1173,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 self.callback_handler.on_optimizer_begin(
                     args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                 )
-                if self.args.enable_flash_save_mode and self.flash_checkpoint_manager.current_worker is not None:
-                    logger.info("Start syncing flash checkpoints")
-                    self.flash_checkpoint_manager.sync_offload_status()
-                    logger.info("Synced flash checkpoints.")
                 optimizer_was_run = True
                 if self.args.offload_optim:
@@ -2462,28 +2401,7 @@ def _ordered_save(self, state_dict, save_path):
             paddle.save(state_dict, save_path)
             dist.barrier(mp_group)

-    def _get_save_infos_based_on_steps(self, checkpoint_folder):
-        flash_checkpoint_dir = None
-        persistent_checkpoint_dir = None
-        if self.args.flash_save_steps > 0 and self.state.global_step % self.args.flash_save_steps == 0:
-            flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder)
-        if self.args.save_steps > 0 and self.state.global_step % self.args.save_steps == 0:
-            persistent_checkpoint_dir = os.path.join(self.args.output_dir, checkpoint_folder)
-        return (flash_checkpoint_dir, persistent_checkpoint_dir)
-
-    def _save_checkpoint_flash(self):
-        self.runtime_timer.start("checkpoint saving time")
-        self.maybe_update_flash_checkpoint_worker()
-        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
-        save_infos = self._get_save_infos_based_on_steps(checkpoint_folder)
-        non_cached_objects = (self.lr_scheduler.state_dict(), self.state)
-        self.flash_checkpoint_manager.get_idle_worker_for_saving(save_infos, non_cached_objects)
-        self.runtime_timer.stop()
-
     def _save_checkpoint(self, model, metrics=None):
-        if self.args.enable_flash_save_mode:
-            self._save_checkpoint_flash()
-            return
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 1e8cf7da2836..611b2b9540e4 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -880,6 +880,10 @@ class TrainingArguments:
         default=0,
         metadata={"help": "The coefficient of EMA parameters in flash save mode. if set to 0, skip EMA process"},
     )
+    flash_ema_interval: Optional[int] = field(
+        default=1,
+        metadata={"help": "Interval, in global steps, between EMA parameter updates."},
+    )
     save_tokenizer: Optional[bool] = field(
         default=True,
         metadata={"help": "Save tokenizer to output_dir."},
diff --git a/paddlenlp/trainer/utils/flash_checkpoint.py b/paddlenlp/trainer/utils/flash_checkpoint.py
index adb538f9c4af..89a57a22617c 100644
--- a/paddlenlp/trainer/utils/flash_checkpoint.py
+++ b/paddlenlp/trainer/utils/flash_checkpoint.py
@@ -13,10 +13,9 @@
 # limitations under the License.

 import atexit
+import copy
 import hashlib
 import json
-
-# import copy
 import multiprocessing
 import os
 import time
@@ -34,16 +33,25 @@
 from paddle.optimizer.fusion_utils import FusionStorageHelper

 from paddlenlp.trainer.trainer_callback import TrainerCallback
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    get_parameter_dtype,
+    unwrap_model,
+)
 from paddlenlp.transformers.utils import device_guard
 from paddlenlp.utils.env import (
     CONFIG_NAME,
     MODEL_META_NAME,
+    PADDLE_OPTIMIZER_NAME,
+    PADDLE_WEIGHTS_NAME,
+    PREFIX_CHECKPOINT_DIR,
     SCHEDULER_NAME,
     TRAINER_STATE_NAME,
     TRAINING_ARGS_NAME,
 )
 from paddlenlp.utils.fault_tolerance import FC_DUMP_ERROR, PC_DUMP_ERROR
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.pdc_sdk import FLASH_DEVICE


 def md5(tensor):
@@ -62,7 +70,7 @@ class FCTaskType(Enum):
     PREPARE = 1
     OFFLOAD = 2
     FINISH = 3
-    SET_EMA_STATE_DICT = 4
+    SET_EMA_STATE_DICT = 5


 class FCWorkerStatus(Enum):
@@ -156,6 +164,7 @@ def ema_accumulate(self):
         """
         # logger.info(f'[FC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
         # do update: ema = alpha * ema + (1-alpha) * model
+        logger.info("[FC EMA] start")
         cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
             self.master_min_offset, self.master_max_offset
         ).cpu()
@@ -337,12 +346,125 @@ def restore_tensor_from_meta(self, tensor_meta):


 class FlashCheckpointCallback(TrainerCallback):
-    def __init__(self, flash_checkpoint_manager):
+    """
+    Calls into FlashCheckpointManager during training in the following order:
+
+    on_step_end:
+        * call get_idle_worker_for_saving, which sets manager.current_worker
+        * call maybe_update_flash_checkpoint_worker
+
+    * on_substep_end (called `gradient_accumulate` times): call flash_checkpoint_pipeline_hook (in non-pp models)
+    * (when offloading is done, dump the model)
+    on_optimizer_begin: call sync_offload_status, which unsets manager.current_worker
+    """
+
+    def __init__(self, args, flash_checkpoint_manager, timer, sharding_io):
         self.manager = flash_checkpoint_manager
+        self.runtime_timer = timer
+        self.user_file_list = []
+        self.manipulated_state_dict = None
+        self.manipulated_config_to_save = None
+        self.manipulated_weight_suffix = None
+        self.model_meta = None
+        self.sharding_io = sharding_io
+        assert (
+            args.flash_save_steps % args.flash_ema_interval == 0
+        ), f"flash_save_steps:{args.flash_save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}"
+        assert (
+            args.save_steps % args.flash_ema_interval == 0
+        ), f"save_steps:{args.save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}"
+        self.flash_ema_interval = args.flash_ema_interval
+        if args.flash_save_ema_coef:
+            assert args.flash_workers_num == 1, "[FC EMA] does not support #workers > 1"

     def on_substep_end(self, args, state, control, **kwargs):
         self.manager.flash_checkpoint_pipeline_hook(0)

+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        if args.enable_flash_save_mode and self.manager.current_worker is not None:
+            logger.info("Start syncing flash checkpoints")
+            self.manager.sync_offload_status()
+            logger.info("Synced flash checkpoints.")
+
+    def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kwargs):
+        self.manager.flash_checkpoint_pipeline_hook(0)
+        logger.info(
+            f"flash save check: ema_coef={args.flash_save_ema_coef}, should_save={control.should_save}, global_step={state.global_step}, ema_interval={self.flash_ema_interval}"
+        )
+        if not control.should_save:
+            if args.flash_save_ema_coef and state.global_step % self.flash_ema_interval == 0:
+                self.maybe_update_flash_checkpoint_worker(args, model, optimizer)
+                self.manager.get_idle_worker_for_saving()  # offload only for EMA accumulation, no dump
+        else:
+            self.runtime_timer.start("checkpoint saving time")
+            self.maybe_update_flash_checkpoint_worker(args, model, optimizer)
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            save_infos = self._get_save_infos_based_on_steps(state, args, checkpoint_folder)
+            non_cached_objects = (lr_scheduler.state_dict(), state)
+            self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects))
+            self.runtime_timer.stop()
+            control.should_save = False  # avoid regular saving
+
+    def _get_save_infos_based_on_steps(self, state, args, checkpoint_folder):
+        flash_checkpoint_dir = None
+        persistent_checkpoint_dir = None
+        if args.flash_save_steps > 0 and state.global_step % args.flash_save_steps == 0:
+            flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder)
+        if args.save_steps > 0 and state.global_step % args.save_steps == 0:
+            persistent_checkpoint_dir = os.path.join(args.output_dir, checkpoint_folder)
+        return (flash_checkpoint_dir, persistent_checkpoint_dir)
+
+    def maybe_update_flash_checkpoint_worker(self, args, model, optimizer):
+        # logger.info(f'check should update :{optimizer.fused_buffer_version} vs {self.manager.cache_version}')
+        if optimizer.fused_buffer_version == self.manager.cache_version:
+            return
+
+        logger.info("Flash checkpoint workers need upgrade.")
+        self._cache_meta_for_sharded_save(model)
+        param_mappings, ipc_meta_mappings = get_fused_param_mappings(optimizer, self.manipulated_state_dict)
+        optimizer_states_meta = (
+            optimizer.fused_states_accumulators_meta,
+            optimizer.fused_states_master_weights_meta,
+            None,
+            optimizer.fused_states_buffer_ipc_meta,
+        )
+        model_states_meta = (param_mappings, ipc_meta_mappings)
+        optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+        model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix)
+
+        dynamic_objecs = {}
+        dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta
+        dynamic_objecs["model_states_meta"] = model_states_meta
+        dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path
+        dynamic_objecs["model_states_name_path"] = model_states_name_path
+
+        static_objects = {}
+        static_objects["model_config"] = self.manipulated_config_to_save
+        static_objects["training_args"] = args
+        static_objects["model_meta"] = self.model_meta
+        static_objects["user_file"] = self.user_file_list
+
+        self.manager.update_flash_workers(optimizer.fused_buffer_version, dynamic_objecs, static_objects)
+
+    def _cache_meta_for_sharded_save(self, model):
+        logger.info("Start caching metas for sharded save...")
+        (
+            self.manipulated_state_dict,
+            self.manipulated_config_to_save,
+            self.manipulated_weight_suffix,
+        ) = self.sharding_io.manipulate_state_dict_and_config(model, merge_tensor_parallel=False)
+        logger.info("Cache manipulated static dict done.")
+        if self.manipulated_config_to_save is None:
+            model_to_save = unwrap_model(model)
+            dtype = get_parameter_dtype(model_to_save)
+            model_to_save.config.dtype = str(dtype).split(".")[1]
+            self.manipulated_config_to_save = copy.deepcopy(model_to_save.config)
+            self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__]
+            self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True)
+            logger.info("Cache manipulated model config done")
+        self.model_meta = self.sharding_io.gather_distributed_model_meta()
+        logger.info("Cache distributed model meta done.")
+

 class FlashCheckpointManager:
     def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None):
@@ -355,7 +477,6 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef
         self.current_worker = None
         self.device_id = int(os.getenv("FLAGS_selected_gpus"))
         self.pipeline_hooks_steps = max(int(pipeline_hooks_capacity * capacity_usage), 1)
-        self.ema_coef = ema_coef
         logger.info(
             f"[FC manager] pipeline hooks capacity: {pipeline_hooks_capacity}; "
             f"pipeline hooks steps for offloading: {self.pipeline_hooks_steps} "
@@ -412,7 +533,10 @@ def update_flash_workers(self, new_version, dynamic_objecs, static_objects):
         logger.info("[FC manager] update all flash workers done")
         self.ready_to_save = True

-    def get_idle_worker_for_saving(self, save_infos, non_cached_objects):
+    def get_idle_worker_for_saving(self, save_infos_and_non_cached_objects=None):
+        """
+        If `save_infos_and_non_cached_objects` is None, only offload without dumping.
+        """
         self.report_error_worker()
         assert self.current_worker is None, "[FC manager] current_worker must be None"
         found_worker = False
@@ -424,12 +548,16 @@
                     break
             if found_worker:
                 break
-            logger.info("[FC manager] Waiting for idle worker...")
+            logger.info("[FC manager] Waiting for idle worker... consider increasing `save-step` or `global-batch-size`")
             time.sleep(1)
-        task = (FCTaskType.PREPARE, (save_infos, non_cached_objects))
-        logger.info("[FC manager] before putting task for prepare")
+        task = (FCTaskType.PREPARE, save_infos_and_non_cached_objects)
+        logger.info(
+            f"[FC manager] before putting task for prepare, dumping={save_infos_and_non_cached_objects is not None}"
+        )
         self.current_worker.task_queue.put(task)
-        logger.info("[FC manager] after putting task for prepare")
+        logger.info(
+            f"[FC manager] after putting task for prepare, dumping={save_infos_and_non_cached_objects is not None}"
+        )

     def sync_offload_status(self):
         self.report_error_worker()
@@ -542,13 +670,15 @@ def process_update_task(self, updates):
         self.version.value = version

     def process_prepare_task(self, prepares):
-        save_infos, non_cached_objects = prepares
         self.offloaded_numels = 0
         self.status.value = FCWorkerStatus.OFFLOADING.value
+        if prepares is None:  # None means offload only, no dump
+            return
+        save_infos, non_cached_objects = prepares
         self.flash_save_dir, self.persistent_save_dir = save_infos
         self.lr_scheduler, self.trainer_state = non_cached_objects

-    def process_offload_task(self):
+    def process_offload_task(self, dump):
         actual_offload_size = (
             min(self.offloaded_numels + self.chunk_size_in_numel, self.all_numel) - self.offloaded_numels
         )
@@ -582,12 +712,12 @@
         # continue to process dumping task at the last chunk
         if self.offloaded_numels == self.all_numel:
-            need_report_error = self.process_dump_task()
-            self.offloaded_numels = 0
-            if need_report_error:
-                self.status.value = FCWorkerStatus.ERROR.value
+            if dump:
+                need_report_error = self.process_dump_task()
             else:
-                self.status.value = FCWorkerStatus.IDLE.value
+                need_report_error = False
+            self.offloaded_numels = 0
+            self.status.value = FCWorkerStatus.ERROR.value if need_report_error else FCWorkerStatus.IDLE.value

     def process_dump_task(self):
         """
@@ -669,6 +799,7 @@ def run(self):
         paddle.set_device(f"gpu:{self.device_id}")
         logger.info(f"[FC worker{self.worker_id}] Worker{self.worker_id} started.")
         ema_ckpt_path = None
+        save_info_tuple = None  # (save_infos, non_cached_objects) from the last PREPARE; None means offload only
         try:
             while True:
                 task = self.task_queue.get()
@@ -686,9 +817,10 @@
                     self.flash_ema_processor.load_ema_state_dict(ema_ckpt_path)
                     ema_ckpt_path = None
                 elif task_type == FCTaskType.PREPARE:
+                    save_info_tuple = task_body
                     self.process_prepare_task(task_body)
                 elif task_type == FCTaskType.OFFLOAD:
-                    self.process_offload_task()
+                    self.process_offload_task(dump=save_info_tuple is not None)
                 elif task_type == FCTaskType.SET_EMA_STATE_DICT:
                     ema_ckpt_path = task_body  # mark ema state dict path
                 else:
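Taken together, the relocated on_step_end / _get_save_infos_based_on_steps logic reduces to a small per-step decision: when no save is due, acquire a worker for an offload-only (EMA) pass; when a save is due, offload and dump to the flash and/or persistent checkpoint directory. The standalone sketch below mirrors that decision; the FLASH_DEVICE path, output_dir, and the step values are illustrative stand-ins, not values taken from this change.

# Sketch only: mirrors FlashCheckpointCallback.on_step_end's branching with plain ints.
import os

FLASH_DEVICE = "/dev/shm/flash_device"  # illustrative stand-in for paddlenlp.utils.pdc_sdk.FLASH_DEVICE


def plan_step(global_step, flash_save_steps, save_steps, flash_ema_interval,
              flash_save_ema_coef, should_save, output_dir="./output"):
    """Return the action the callback would take at the end of this step."""
    checkpoint_folder = f"checkpoint-{global_step}"
    if not should_save:
        # EMA-only path: grab an idle worker just to offload; nothing is dumped.
        if flash_save_ema_coef and global_step % flash_ema_interval == 0:
            return ("offload_only", None, None)
        return ("noop", None, None)
    # Saving path: pick the flash and/or persistent target directories.
    flash_dir = (
        os.path.join(FLASH_DEVICE, checkpoint_folder)
        if flash_save_steps > 0 and global_step % flash_save_steps == 0
        else None
    )
    persistent_dir = (
        os.path.join(output_dir, checkpoint_folder)
        if save_steps > 0 and global_step % save_steps == 0
        else None
    )
    return ("offload_and_dump", flash_dir, persistent_dir)


if __name__ == "__main__":
    # flash_save_steps=10, save_steps=50, flash_ema_interval=5 (both divisible, as the new asserts require)
    for step, should_save in [(5, False), (10, True), (50, True)]:
        print(step, plan_step(step, 10, 50, 5, 0.9, should_save))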