[Unified Checkpoint] Fix expert parallel (PaddlePaddle#9821)
* fix expert parallel

* fix split_param for expert parallel

* add filter_sync_parameters
DesmonDay committed Feb 13, 2025
1 parent dad92bd commit 90e2e14
Showing 5 changed files with 232 additions and 38 deletions.
@@ -305,7 +305,11 @@ def load_resolved_archive_file(
)
)
if has_master_weights:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
# Parameters already stored in float32 do not need separate fp32 master weights.
key_name = "_".join([static_name, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])

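The dtype check above mirrors the save-side rule that only low-precision parameters get an fp32 master copy, so only their optimizer keys carry the FP32_MASTER infix. A minimal sketch of that key construction (not part of the diff; the FP32_MASTER value and the example names are assumptions for illustration):

import paddle

FP32_MASTER = "fp32_master_0"  # assumed value, shown only for illustration

def build_optimizer_key(static_name, state_suffix, param_dtype, has_master_weights):
    # Parameters already stored in float32 need no separate fp32 master copy,
    # so their optimizer keys omit the FP32_MASTER infix.
    if has_master_weights and param_dtype != paddle.float32:
        return "_".join([static_name, FP32_MASTER, state_suffix])
    return "_".join([static_name, state_suffix])

print(build_optimizer_key("linear_0.w_0", "moment1_0", paddle.bfloat16, True))  # linear_0.w_0_fp32_master_0_moment1_0
print(build_optimizer_key("gate_0.w_0", "moment1_0", paddle.float32, True))     # gate_0.w_0_moment1_0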
30 changes: 10 additions & 20 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -67,6 +67,7 @@
FP32_MASTER,
UnifiedCheckpointOption,
filter_params,
filter_sync_parameters,
gather_sharded_object,
generate_base_static_name,
get_expected_state_dict,
@@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
for key in list(master_weights.keys()):
master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

no_sync_kname = []
model_state_dict = get_expected_state_dict(model)
for k, v in model_state_dict.items():
if getattr(v, "no_sync", False):
no_sync_kname.append(k)

hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
if self.args.use_expert_parallel:
for k in list(optim_state_dict.keys()):
model_k = k.split("/")[0]
if dp_rank > 0 and model_k not in no_sync_kname:
optim_state_dict.pop(k)
if master_weights is not None:
for k in list(master_weights.keys()):
model_k = k.split("/")[0]
if dp_rank > 0 and model_k not in no_sync_kname:
master_weights.pop(k)
model_state_dict = get_expected_state_dict(model)
filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False)

optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
@@ -516,6 +501,10 @@ def unified_checkpoint_into_shards(

config_to_save = copy.deepcopy(model_to_save.config)

if args.use_expert_parallel:
# ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
filter_sync_parameters(state_dict, is_model_weight=True)

if config_to_save.tensor_parallel_degree > 1:
if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
tp_actions = model_to_save._get_tensor_parallel_convert_actions(
@@ -615,6 +604,9 @@ def unified_optimizer_into_shards(
tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
tp_size = tp_group.nranks

if args.use_expert_parallel:
filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False)

if tp_size > 1:
# get tp_actions
model_keys = []
@@ -633,7 +625,6 @@
optim_state_dict,
tp_actions,
filter_optim_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()

@@ -643,7 +634,6 @@
master_weights,
tp_actions,
filter_master_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()

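The three save paths above now delegate the same rule to filter_sync_parameters: under expert parallel, dp_rank 0 saves everything, while other dp ranks save only their local (no_sync) expert tensors, which is also why merge_tensor_parallel_for_optimizer no longer needs the model state dict. A standalone toy illustration of that rule for model weights (not part of the diff; the names and SimpleNamespace stand-ins are assumptions):

from types import SimpleNamespace

weights = {
    "shared.linear.weight": SimpleNamespace(no_sync=False),  # replicated across dp ranks
    "expert.0.ffn.weight": SimpleNamespace(no_sync=True),    # owned by this dp rank only
}

def keys_saved(dp_rank):
    # dp_rank 0 keeps every tensor; other dp ranks keep only no_sync tensors.
    return [k for k, v in weights.items() if dp_rank == 0 or getattr(v, "no_sync", False)]

print(keys_saved(0))  # ['shared.linear.weight', 'expert.0.ffn.weight']
print(keys_saved(1))  # ['expert.0.ffn.weight']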
50 changes: 33 additions & 17 deletions paddlenlp/trainer/unified_checkpoint/utils.py
@@ -363,9 +363,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
"""
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
dp_group = hcg.get_data_parallel_group()
tp_rank = tp_group.rank
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

# filter actions for pipeline mode
if hcg.get_pipe_parallel_group().nranks > 1:
@@ -382,10 +380,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
if i > len(filter_keys) - 1:
continue
key = filter_keys[i]
tensor = state_dict[key]
# When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
if dp_rank > 0 and not getattr(tensor, "no_sync", False):
if key not in state_dict:
continue
tensor = state_dict[key]
if key in tp_actions:
# Get tensor size
tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
@@ -414,21 +411,13 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
return state_dict_to_save


def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
"""
Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight.
"""
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
dp_group = hcg.get_data_parallel_group()
tp_rank = tp_group.rank
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

no_sync_kname = []
if model_state_dict is not None:
for k, v in model_state_dict.items():
if getattr(v, "no_sync", False):
no_sync_kname.append(k)

state_dict_to_save = {}
max_key_len = max([len(_) for _ in all_filter_keys])
@@ -439,10 +428,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys,
continue
# get base model key
model_key = filter_keys[i].split("/")[0]
tensor = state_dict[filter_keys[i]]
# When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
if dp_rank > 0 and model_key not in no_sync_kname:
if filter_keys[i] not in state_dict:
continue
tensor = state_dict[filter_keys[i]]
if model_key in tp_actions:
# for example: beta1, beta2
if tensor.numel().item() == 1:
@@ -779,3 +767,31 @@ def save_config(model_to_save):
# save generation config
if model_to_save.can_generate():
model_to_save.generation_config.save_pretrained(save_directory)


def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
"""Filter sync parameters under expert parallel mode."""

hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

if is_model_weight:
for key in list(model_state_dict.keys()):
if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
model_state_dict.pop(key)
else:
no_sync_kname = []
for k, v in model_state_dict.items():
if getattr(v, "no_sync", False):
no_sync_kname.append(k)

for key in list(optim_state_dict.keys()):
model_key = key.split("/")[0]
if dp_rank > 0 and model_key not in no_sync_kname:
optim_state_dict.pop(key)

if master_weights is not None:
for key in list(master_weights.keys()):
if dp_rank > 0 and key not in no_sync_kname:
master_weights.pop(key)
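A fleet-free sketch of the optimizer branch above, showing how key.split("/")[0] ties optimizer states and master weights back to the no_sync set (standalone illustration with toy dicts; the real helper reads dp_rank from the hybrid communicate group, and the "/moment1_0" suffix is only an example):

from types import SimpleNamespace

model_state = {
    "shared.w": SimpleNamespace(no_sync=False),  # replicated across dp ranks
    "expert.w": SimpleNamespace(no_sync=True),   # local to this dp rank
}
optim_state = {"shared.w/moment1_0": 0.1, "expert.w/moment1_0": 0.2}
master_weights = {"shared.w": 1.0, "expert.w": 2.0}

dp_rank = 1  # any non-zero data-parallel rank
no_sync_kname = [k for k, v in model_state.items() if getattr(v, "no_sync", False)]

for key in list(optim_state.keys()):
    if dp_rank > 0 and key.split("/")[0] not in no_sync_kname:
        optim_state.pop(key)
for key in list(master_weights.keys()):
    if dp_rank > 0 and key not in no_sync_kname:
        master_weights.pop(key)

print(optim_state)     # {'expert.w/moment1_0': 0.2}
print(master_weights)  # {'expert.w': 2.0}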
176 changes: 176 additions & 0 deletions tests/trainer/test_moe_unified_checkpoint.py
@@ -0,0 +1,176 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import pytest

from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from tests.parallel_launch import TestMultipleGpus
from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case
from tests.trainer.test_unified_checkpoint import remove_ckpt, remove_logs
from tests.trainer.trainer_utils import get_pretrain_arguments

environment_variables = {
"NCCL_ALGO": "Tree",
"NVIDIA_TF32_OVERRIDE": "0",
"NCCL_IB_TIMEOUT": "22",
"NCCL_DEBUG": "INFO",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
"Flags_mp_aysnc_allreduce": "1",
"Flags_skip_mp_c_identity": "1",
"FLAGS_shard_norm_align_dp": "0",
"FLAGS_shard_use_reduce": "1",
"test_ci_no_save_model": "1",
}

moe_arguments = {
"model_name_or_path": "__internal_testing__/unified-ckpt-qwen2moe",
"dataset_name_or_path": "./unified_checkpoint/peft_input/data/",
"output_dir": "./unified_checkpoint/checkpoints/qwen2moe_sft_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps": 16,
"learning_rate": 3e-04,
"max_steps": 10,
"save_steps": 6,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "no",
"save_strategy": "steps",
"src_length": 1024,
"max_length": 2048,
"bf16": "true",
"fp16_opt_level": "O2",
"do_train": "true",
"do_eval": "false",
"disable_tqdm": "true",
"eval_with_do_generation": "false",
"recompute": "true",
"recompute_granularity": "full",
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"sharding": "",
"lora": "false",
"zero_padding": "false",
"use_flash_attention": "false",
"unified_checkpoint": 1,
"continue_training": 0,
"sequence_parallel": 0,
}


def check_acc(log_dir="log"):
file_path = os.path.join(log_dir, "workerlog.n0.c0")
cmd = "grep -a 'global_step: 10' " + file_path + " | awk -F ',' '{print $2}' | awk '{print $6}'"
import subprocess

res = subprocess.check_output(cmd, shell=True, text=True)
res = [float(x) for x in res.split()]

return res


seed = 2024

rng = np.random.default_rng(seed=seed)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointBase(TestMultipleGpus):
@classmethod
@property
def __test__(cls):
return cls != TestUnifiedCheckpointBase

def setUp(self):
"""
1. update runfirst and rerun to run the different configs defined here
2. update need_allclose to True if you want to check the result
3. update rtol to the relative tolerance used for the check
"""

self.configs = get_pretrain_arguments(moe_arguments)
os.environ.update(environment_variables)

file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz"
input_dir = "unified_checkpoint/peft_input/"
os.makedirs(input_dir, exist_ok=True)
file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz")
if not os.path.exists(file_path):
get_path_from_url_with_filelock(file_, root_dir=input_dir)

self.need_allclose = True
self.rtol = 1e-7

self.run_file = "llm/run_finetune.py"

def runfirst(self, train_args):
self.run_n1c8(self.run_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_file, **train_args)

@require_paddle_at_least_8_gpu
def testTP4DP2(self):
remove_logs()
remove_ckpt(moe_arguments["output_dir"])

train_args = self.configs["TP4DP2"]
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)

@skip_for_none_ce_case
@require_paddle_at_least_8_gpu
def testTP2Sharding4(self):
remove_logs()
remove_ckpt(moe_arguments["output_dir"])

train_args = self.configs["TP2Sharding4"]
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase):
@skip_for_none_ce_case
@require_paddle_at_least_8_gpu
def testTP2Sharding4V2(self):
remove_logs()
remove_ckpt(moe_arguments["output_dir"])

train_args = self.configs["TP2Sharding4"]
train_args.update({"sharding_parallel_config": "split_param"})
train_args.update({"amp_master_grad": True})
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)
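For reference, the grep/awk pipeline in check_acc reduces to the following pure-Python reading (a sketch only; the field positions simply translate the awk commands and assume the same workerlog line format):

def parse_step10_metric(log_path):
    # grep -a 'global_step: 10' <log> | awk -F ',' '{print $2}' | awk '{print $6}'
    values = []
    with open(log_path, errors="ignore") as f:  # -a: force text mode
        for line in f:
            if "global_step: 10" in line:
                second_field = line.split(",")[1]               # awk -F ',' '{print $2}'
                values.append(float(second_field.split()[5]))   # awk '{print $6}'
    return values

# check_acc is expected to return exactly two values at step 10 (first run and
# rerun), and the tests assert they match within rtol.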
8 changes: 8 additions & 0 deletions tests/trainer/trainer_utils.py
@@ -141,6 +141,14 @@ def get_pretrain_arguments(pretrain_arguments):
train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 8
configs["DP8"] = train_args

train_args = copy.deepcopy(pretrain_arguments)
train_args["tensor_parallel_degree"] = 2
train_args["pipeline_parallel_degree"] = 1
train_args["sharding_parallel_degree"] = 2
train_args["sharding"] = "stage1"
train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 4
configs["TP2DP2Sharding2"] = train_args

return configs


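The new TP2DP2Sharding2 entry follows the same 8-GPU bookkeeping as the existing configs: the parallel degrees multiply to the world size, and gradient_accumulation_steps shrinks by the data-parallel replication factor so the global batch size stays fixed. A quick check of that arithmetic (illustrative only, assuming the 8-GPU layout used by these tests):

# Illustrative arithmetic behind the TP2DP2Sharding2 config (8 GPUs assumed).
world_size = 8
tensor_parallel_degree = 2
pipeline_parallel_degree = 1
sharding_parallel_degree = 2
data_parallel_degree = world_size // (
    tensor_parallel_degree * pipeline_parallel_degree * sharding_parallel_degree
)
assert data_parallel_degree == 2  # TP2 x PP1 x Sharding2 x DP2 = 8 ranks

# Sharding and pure data parallel both replicate the batch dimension, so the base
# gradient_accumulation_steps (8) is divided by dp * sharding = 4, keeping the
# global batch size unchanged.
assert 8 // (data_parallel_degree * sharding_parallel_degree) == 2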
