[flashcheckpoint] fix save under dp_degree > 1 with use-expert-parallel
Meiyim committed Feb 6, 2025
1 parent 4c15ff2 commit c5e8642
Showing 4 changed files with 84 additions and 11 deletions.
2 changes: 2 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -705,6 +705,7 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin
worker_num=self.args.flash_workers_num,
pipeline_hooks_capacity=pipeline_hooks_capacity,
capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
use_expert_parallel=self.args.use_expert_parallel,
ema_coef=self.args.flash_save_ema_coef,
)
for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
@@ -721,6 +722,7 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin
worker_num=self.args.flash_workers_num,
pipeline_hooks_capacity=pipeline_hooks_capacity,
capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
use_expert_parallel=self.args.use_expert_parallel,
ema_coef=self.args.flash_save_ema_coef,
)
_callback = FlashCheckpointCallback(
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -877,7 +877,7 @@ class TrainingArguments:
},
)
flash_save_ema_coef: Optional[float] = field(
default=0,
default=None,
metadata={"help": "The coefficient of EMA parameters in flash save mode. if set to 0, skip EMA process"},
)
flash_ema_interval: Optional[int] = field(
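The new default pairs with the `is not None` checks introduced in flash_checkpoint.py below: EMA handling is now disabled only when the coefficient is left unset, instead of relying on truthiness. A minimal sketch of the distinction, using a hypothetical stand-in for TrainingArguments (not part of the commit):

# Minimal sketch (hypothetical stand-in; not the real TrainingArguments).
class Args:
    flash_save_ema_coef = None  # new default

args = Args()
print(bool(args.flash_save_ema_coef))        # False -- old-style truthiness check
print(args.flash_save_ema_coef is not None)  # False -- new-style check used below
args.flash_save_ema_coef = 0.9
print(args.flash_save_ema_coef is not None)  # True  -- EMA path taken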
88 changes: 79 additions & 9 deletions paddlenlp/trainer/utils/flash_checkpoint.py
@@ -26,6 +26,7 @@
import paddle.autograd as imperative_base
import paddle.distributed as dist
from paddle.base import core
from paddle.distributed.fleet import fleet
from paddle.incubate.tensor.manipulation import (
async_offload_with_offset,
create_async_load,
@@ -374,7 +375,7 @@ def __init__(self, args, flash_checkpoint_manager, timer, sharding_io):
args.save_steps % args.flash_ema_interval == 0
), f"save_steps:{args.save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}"
self.flash_ema_interval = args.flash_ema_interval
if args.flash_save_ema_coef:
if args.flash_save_ema_coef is not None:
assert args.flash_workers_num == 1, "[FC EMA] not support #worker > 1"

def on_substep_end(self, args, state, control, **kwargs):
@@ -392,7 +393,7 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
f"check coef: {args.flash_save_ema_coef} {control.should_save}, {state.global_step}, {self.flash_ema_interval}"
)
if not control.should_save:
if args.flash_save_ema_coef and state.global_step % self.flash_ema_interval == 0:
if args.flash_save_ema_coef is not None and state.global_step % self.flash_ema_interval == 0:
self.maybe_update_flash_checkpoint_worker(args, model, optimizer)
self.manager.get_idle_worker_for_saving() # prepare for dumping
else:
@@ -467,7 +468,7 @@ def _cache_meta_for_sharded_save(self, model):


class FlashCheckpointManager:
def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None):
def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, use_expert_parallel, ema_coef=None):
assert worker_num > 0, "worker_num must be greater than 0"
assert capacity_usage <= 1.0, "capacity_usage must be less than or equal to 1.0"
self.cache_version = 0
@@ -484,6 +485,7 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef
)
self.current_pipeline_hook_step = 0
ctx = multiprocessing.get_context("spawn")
assert hasattr(fleet, "_hcg"), "FlashCheckpoint only supports `use_hybrid_parallel`"
for i in range(worker_num):
worker_task_queue = ctx.Queue()
worker_status = ctx.Value("i", FCWorkerStatus.IDLE.value)
@@ -496,6 +498,11 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef
worker_task_queue,
worker_status,
worker_version,
use_expert_parallel,
fleet.get_hybrid_communicate_group().get_data_parallel_rank(),
fleet.get_hybrid_communicate_group().get_model_parallel_rank(),
fleet.get_hybrid_communicate_group()._get_pipe_parallel_id(),
fleet.get_hybrid_communicate_group().get_sharding_parallel_rank(),
ema_coef,
)
p = ctx.Process(target=worker_loop, args=(worker,))
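The four ranks are resolved here in the trainer process and handed to each worker as plain integers; the worker body runs in a separate spawn-ed process (see ctx.Process above), which presumably cannot query the hybrid communicate group on its own. A condensed sketch of the parent-side lookup, assuming hybrid parallelism is already initialized (as the hasattr(fleet, "_hcg") assertion above requires); the calls are the same ones used in the diff:

# Parent-side rank lookup (assumes an initialized hybrid-parallel fleet).
hcg = fleet.get_hybrid_communicate_group()
dp_rank = hcg.get_data_parallel_rank()       # data-parallel coordinate
mp_rank = hcg.get_model_parallel_rank()      # tensor/model-parallel coordinate
pp_rank = hcg._get_pipe_parallel_id()        # pipeline stage
sd_rank = hcg.get_sharding_parallel_rank()   # sharding-group coordinate
# ...then passed positionally into FlashCheckpointWorker(..., use_expert_parallel,
# dp_rank, mp_rank, pp_rank, sd_rank, ema_coef)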
@@ -613,7 +620,22 @@ def worker_loop(worker):


class FlashCheckpointWorker:
def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue, status, version, ema_coef=None):
def __init__(
self,
worker_id,
device_id,
global_rank,
offload_chunks,
task_queue,
status,
version,
use_expert_parallel,
dp_rank,
mp_rank,
pp_rank,
sd_rank,
ema_coef=None,
):
super().__init__()
self.worker_id = worker_id
self.device_id = device_id
@@ -623,6 +645,11 @@ def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue
self.status = status
self.version = version
self.ema_coef = ema_coef
self.use_expert_parallel = use_expert_parallel
self.dp_rank = dp_rank
self.mp_rank = mp_rank
self.pp_rank = pp_rank
self.sd_rank = sd_rank

# for dynamic objects saving
self.optimizer_fusion_storage_helper = None
@@ -706,7 +733,7 @@ def process_offload_task(self, dump):
if self.offloaded_numels == self.all_numel:
self.optimizer_fusion_storage_helper.wait_all()
self.param_fusion_storage_helper.wait_all()
if self.ema_coef:
if self.ema_coef is not None:
self.flash_ema_processor.ema_accumulate()
self.status.value = FCWorkerStatus.DUMPING.value

@@ -744,6 +771,35 @@ def process_dump_task(self):
need_report_error = True
return need_report_error

def _filter_moe_no_sync_optimizer_params(self, model_meta, optimzier_state_dict):
"""
Filter out optimizer entries for params that should not be synced across data-parallel ranks; copied from paddlenlp.Trainer.
"""
filter_optimzier_state_dict = OrderedDict()
assert "master_weights" in optimzier_state_dict, optimzier_state_dict.keys()
param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys())
filter_optimzier_state_dict["master_weights"] = OrderedDict()
suffix = f"tp{self.mp_rank:0>2d}_pp{self.pp_rank:0>2d}"
dyname_to_pname = model_meta["sharding_metas"][suffix]["structure_name_mapping"]
dyname_to_meta = model_meta["sharding_metas"][suffix]["param_meta"]
for k, pname in dyname_to_pname.items():
shape, dtype, is_dist, is_no_sync = dyname_to_meta[k]
if is_no_sync:
if pname in param_names_in_master_weights:
filter_optimzier_state_dict["master_weights"][pname] = optimzier_state_dict["master_weights"][
pname
]
else:
pass
# logger.info(f"filter out master weight:{pname} -> {k}")
for op_k, op_v in optimzier_state_dict.items():
if op_k.startswith(pname):
filter_optimzier_state_dict[op_k] = op_v
else:
# logger.info(f"filter out key={k}, when dp!=0")
pass
return filter_optimzier_state_dict

def process_dump_task_impl(self, output_dir):
os.makedirs(output_dir, exist_ok=True)
# Step1: save static objects
Expand Down Expand Up @@ -771,14 +827,28 @@ def process_dump_task_impl(self, output_dir):
# Step2: save dynamic objects
# Step2.1: save model states
model_states_name_path = os.path.join(output_dir, self.model_states_name_path)
paddle.save(self.param_fusion_storage_helper.state_dict(), model_states_name_path)
state_dict = self.param_fusion_storage_helper.state_dict()

# Step2.2: save optimizer states
optimizer_state_name_path = os.path.join(output_dir, self.optimizer_states_name_path)
paddle.save(self.optimizer_fusion_storage_helper.state_dict(), optimizer_state_name_path)
if self.ema_coef:
opt_state_dict = self.optimizer_fusion_storage_helper.state_dict()

if self.ema_coef is not None:
ema_name_path = os.path.join(output_dir, self.optimizer_states_name_path).replace("optimizer", "ema")
paddle.save(self.flash_ema_processor.ema_state_dict(), ema_name_path)
ema_state_dict = self.flash_ema_processor.ema_state_dict()

if self.dp_rank <= 0 or self.use_expert_parallel:
if self.dp_rank > 0: # ep
opt_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, opt_state_dict)
if self.ema_coef is not None:
# non-master weights in `ema_state_dict` when dp > 1 will be filtered, which is acceptable
ema_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, ema_state_dict)

paddle.save(state_dict, model_states_name_path)
paddle.save(opt_state_dict, optimizer_state_name_path)
if self.ema_coef is not None:
paddle.save(ema_state_dict, ema_name_path)

# Step2.3: save LR Scheduler (To be removed)
lr_state_name_path = os.path.join(output_dir, SCHEDULER_NAME)
if self.device_id == 0:
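Taken together, the new dump path keeps the old behavior on dp_rank == 0 workers (save everything) and additionally lets dp_rank > 0 workers save when expert parallelism is on, but only the no-sync (expert-local) optimizer/EMA entries. A self-contained toy illustration of that rule with made-up names and values (not the actual FlashCheckpointWorker code):

from collections import OrderedDict

# Toy metadata: dyname -> (shape, dtype, is_dist, is_no_sync), mirroring the
# 4-tuple now recorded by sharding_io (last file in this diff).
param_meta = {
    "expert_fc.w": ([8, 8], 10, False, True),     # expert-local, not synced over dp
    "shared_attn.w": ([8, 8], 10, False, False),  # replicated across dp ranks
}
name_mapping = {"expert_fc.w": "param_0", "shared_attn.w": "param_1"}
opt_state = {
    "master_weights": {"param_0": [1.0], "param_1": [2.0]},
    "param_0_moment1_0": [0.0],
    "param_1_moment1_0": [0.0],
}

def filter_no_sync(meta, mapping, state):
    """Keep only optimizer entries owned by no_sync (expert) parameters."""
    kept = OrderedDict(master_weights=OrderedDict())
    for dyname, pname in mapping.items():
        _, _, _, is_no_sync = meta[dyname]
        if not is_no_sync:
            continue
        if pname in state["master_weights"]:
            kept["master_weights"][pname] = state["master_weights"][pname]
        for key, value in state.items():
            if key.startswith(pname):
                kept[key] = value
    return kept

def state_to_dump(dp_rank, use_expert_parallel, state):
    """Full state on dp_rank 0; expert-only slice on dp_rank > 0 under EP; else skip."""
    if dp_rank <= 0:
        return state
    if use_expert_parallel:
        return filter_no_sync(param_meta, name_mapping, state)
    return None

print(list(state_to_dump(0, True, opt_state)))  # all keys survive on dp_rank 0
print(list(state_to_dump(1, True, opt_state)))  # ['master_weights', 'param_0_moment1_0']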
3 changes: 2 additions & 1 deletion paddlenlp/trainer/utils/sharding_io.py
@@ -597,7 +597,8 @@ def _gather_sharding_metas(self):
for k, v in model.state_dict().items():
structure_name_mapping[k] = v.name
is_distributed = getattr(v, "is_distributed", False)
param_meta[k] = (v.shape, int(v.dtype), is_distributed)
no_sync = getattr(v, "no_sync", False)
param_meta[k] = (v.shape, int(v.dtype), is_distributed, no_sync)

sharding_metas = {}
sharding_meta = {}
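The extra no_sync flag recorded here is exactly what _filter_moe_no_sync_optimizer_params (added above in flash_checkpoint.py) unpacks as is_no_sync. A minimal stand-alone sketch of the recorded tuple, using a hypothetical parameter object in place of a real Paddle tensor:

# Minimal sketch (hypothetical stand-in for a Paddle parameter).
class FakeParam:
    shape = [8, 8]
    dtype = 22            # placeholder dtype code
    name = "param_0"
    is_distributed = False
    no_sync = True        # marks an expert-parallel-local parameter

structure_name_mapping, param_meta = {}, {}
for k, v in {"expert_fc.w": FakeParam()}.items():
    structure_name_mapping[k] = v.name
    is_distributed = getattr(v, "is_distributed", False)
    no_sync = getattr(v, "no_sync", False)
    param_meta[k] = (v.shape, int(v.dtype), is_distributed, no_sync)

print(param_meta["expert_fc.w"])  # ([8, 8], 22, False, True)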
