Feat/heterogeneous x pu training (#355)
Co-authored-by: Wenwen Qu <[email protected]>
KkHu-Kistch and blankde authored Nov 1, 2024
1 parent 058fb9b commit d6bfaca
Showing 19 changed files with 1,170 additions and 179 deletions.
2 changes: 1 addition & 1 deletion configs/1.8B_MoE16_sft.py
@@ -160,7 +160,7 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=16,
moe_use_residual=False,
moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D"
moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
)
"""
zero1 parallel (dict):
2 changes: 1 addition & 1 deletion configs/7B_MoE4_sft.py
@@ -156,7 +156,7 @@
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless"
moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
num_experts=4,
top_k=2,
)
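For reference, the MoE-related fields touched in the two config hunks above fit into the model config roughly as in the sketch below. Only the fields visible in the diff are taken from this commit; the enclosing dict and any other values are illustrative assumptions.

# Hypothetical excerpt of a model config; only the MoE fields shown above come
# from this commit, everything else is illustrative.
model = dict(
    num_chunks=1,            # if num_chunks > 1, interleaved pipeline scheduler is used
    num_experts=4,           # total number of experts (16 in the 1.8B_MoE16 config)
    top_k=2,                 # experts activated per token
    moe_use_residual=False,
    # MoE implementations supported after this change:
    # "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
    moe_type="GShard",
)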
67 changes: 40 additions & 27 deletions internlm/checkpoint/components.py
@@ -10,7 +10,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.model.moe.moe import MoE
from internlm.model.moe import MoE
from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
@@ -28,17 +28,23 @@
internlm_accelerator = get_accelerator()


def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
moe_layer_id = pp_rank * pipeline_stage_size
# only support auto resume
def try_load_moe_checkpoint(folder, model, state_dict, expert_mp_rank, pp_rank):
"""Load MoE layer parameters from separate files if the model has MoE layers."""
# Calculate the stage size and rank within the pipeline parallelism
pp_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
moe_layer_id = pp_rank * pp_stage_size
mode = "wp" if is_using_isp() else "tp"

# Iterate over all modules in the model to find MoE layers
for _, module in model.named_modules():
if isinstance(module, MoE):
num_local_experts = module.moe_layer.num_local_experts
num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts)
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
for local_expert_id in range(num_local_wrapped_experts):
global_expert_id = expp_rank * num_local_wrapped_experts + local_expert_id
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt"
fp = os.path.join(folder, fn)
expert_state_dict = llm_load(fp, map_location=get_current_device())
# Updating global -> local expert ids
@@ -50,13 +56,14 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
moe_layer_id += 1


def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
def try_save_moe_checkpoint(folder, model, expert_mp_rank, pp_rank):
# Using layer_#_expert_# to save the model's expert state_dict, a hack.
pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
moe_layer_id = pp_rank * pipeline_stage_size
mode = "wp" if is_using_isp() else "tp"
for n_module, module in model.named_modules():
if isinstance(module, MoE):
num_local_experts = module.moe_layer.num_local_experts
num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts)
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)

# get all moe parameters
@@ -76,7 +83,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
else:
local_expert_id = m.group(1)

global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
global_expert_id = expp_rank * num_local_wrapped_experts + int(local_expert_id)
expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")

# truncating extra tensor (shared) storage
@@ -86,7 +93,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
# let save the moe parameters
for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt"
fp = os.path.join(folder, fn)
llm_save(fp, saved_obj=expert_state_dict)
moe_layer_id += 1
@@ -179,10 +186,12 @@ def load_model_checkpoint(folder, model):
states[key] = states[key].float()
print("load: ", states[key].float(),flush=True)
"""

# try to load expert parameters from separate files if the model has MoE layers
expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank)
if is_using_isp():
expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)
try_load_moe_checkpoint(folder, model, states, expert_wp_rank, pp_rank)
else:
expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank)

if gpc.config.parallel.zero1.fsdp:
missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
@@ -252,6 +261,10 @@ def save_model_checkpoint(folder, model):
topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json"
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)
expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)
expert_wdp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
if expert_wdp_rank == 0:
try_save_moe_checkpoint(folder, model, expert_wp_rank, pp_rank)
else:
# for tensor parallel mode with mtp/msp/fsp
for i in range(tp_size):
@@ -271,17 +284,17 @@
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)

# try to save expert parameters to separate files if the model has MoE layers
expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size
expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
should_save_rank_pair.clear()
for i in range(expert_tp_size):
should_save_rank_pair.add((i, i % expert_dp_size))

if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair:
try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank)
# try to save expert parameters to separate files if the model has MoE layers
expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size
expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
should_save_rank_pair.clear()
for i in range(expert_tp_size):
should_save_rank_pair.add((i, i % expert_dp_size))

if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair:
try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank)

torch.distributed.barrier()

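As a reading aid for the checkpoint changes above: expert weights are still written one file per (MoE layer, global expert, model-parallel rank), but the rank tag now depends on whether ISP weight parallelism is active. Below is a minimal standalone sketch of the naming scheme; the helper name and explicit arguments are illustrative, while the real code derives the mode, ranks, and ids from gpc as shown in the diff.

import os

def expert_checkpoint_path(folder, moe_layer_id, global_expert_id, expert_mp_rank, use_isp):
    # Mirrors `mode = "wp" if is_using_isp() else "tp"` from the diff: ISP runs tag files
    # with the expert weight-parallel rank, non-ISP runs with the expert tensor-parallel rank.
    mode = "wp" if use_isp else "tp"
    fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt"
    return os.path.join(folder, fn)

# global_expert_id = expp_rank * num_local_wrapped_experts + local_expert_id, as in the diff.
# e.g. MoE layer 4, global expert 7, expert weight-parallel rank 2 under ISP:
print(expert_checkpoint_path("ckpt", 4, 7, 2, use_isp=True))  # ckpt/model_moe_layer4_expert7_wp2.pt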
6 changes: 6 additions & 0 deletions internlm/core/parallel/shard.py
@@ -168,6 +168,12 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
return "column"
elif linear_name in ("wo", "out_proj", "w2"):
return "row"
elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp":
return "grouped_wp"
elif linear_name in ("grouped_w1", "grouped_w3"):
return "grouped_column"
elif linear_name in ("grouped_w2"):
return "grouped_row"
else:
return "unknown"

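The new branches in get_parallel_strategies_split_mode route the grouped expert weights. Here is a standalone sketch of just that dispatch; the helper name is illustrative, and the real function also handles the non-grouped names and reads tp_mode from the config.

def grouped_split_mode(linear_name: str, tp_mode: str) -> str:
    # Under ISP, all grouped expert weights use weight parallelism;
    # otherwise w1/w3 are column-split and w2 is row-split.
    if linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp":
        return "grouped_wp"
    if linear_name in ("grouped_w1", "grouped_w3"):
        return "grouped_column"
    if linear_name in ("grouped_w2",):
        return "grouped_row"
    return "unknown"

assert grouped_split_mode("grouped_w2", "isp") == "grouped_wp"
assert grouped_split_mode("grouped_w2", "mtp") == "grouped_row"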
7 changes: 2 additions & 5 deletions internlm/data/build_dataloader.py
@@ -9,6 +9,7 @@
from internlm.core.context import global_context as gpc
from internlm.data.lumina_pickle.dataset import LuminaPickleDataset
from internlm.data.lumina_pickle.sampler import LuminaPickleBatchSampler
from internlm.data.lumina_pickle.collater import lumina_collate_fn
from internlm.data.megatron.batch_sampler import MegatronBatchSampler
from internlm.data.megatron.collaters import megatron_collate_fn
from internlm.data.megatron.dataset import build_megatron_dataset
@@ -218,11 +219,7 @@ def get_mock_train_loader_items(data_cfg):
def get_lumina_pickle_loader_items(data_cfg):
train_ds = LuminaPickleDataset(data_cfg.data_yaml, base_path=data_cfg.base_path, micro_batch_size=data_cfg.micro_bsz, seq_len=data_cfg.seq_len)
train_sampler = LuminaPickleBatchSampler(train_ds, micro_batch_size=data_cfg.micro_bsz, acc_grad=data_cfg.micro_num)
# TODO(zhenghuihuang): Can we reuse existing collate function?
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz)
#train_collate_fn = streaming_packed_collate_fn
#train_collate_fn = lambda batch: tuple(zip(*batch))
return train_ds, train_sampler, train_collate_fn
return train_ds, train_sampler, lumina_collate_fn

def build_train_loader_with_data_type():
"""
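For orientation, the dataset, batch sampler, and collate function returned by get_lumina_pickle_loader_items would typically be wired into a standard PyTorch DataLoader along the lines below. This is an assumption about the caller (the actual wiring lives in build_train_loader_with_data_type), and the import path simply follows the file path of this diff.

from torch.utils.data import DataLoader

from internlm.data.build_dataloader import get_lumina_pickle_loader_items

def build_lumina_loader(data_cfg):
    # Illustrative caller: dataset + batch sampler + the new lumina_collate_fn.
    ds, sampler, collate_fn = get_lumina_pickle_loader_items(data_cfg)
    return DataLoader(
        dataset=ds,
        batch_sampler=sampler,   # LuminaPickleBatchSampler yields lists of sample indices
        collate_fn=collate_fn,   # lumina_collate_fn packs each index batch into one sequence
        num_workers=4,           # illustrative value
        pin_memory=True,
    )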
4 changes: 2 additions & 2 deletions internlm/data/lumina_pickle/__init__.py
@@ -1,8 +1,8 @@
from .sampler import LuminaPickleSampler
from .sampler import LuminaPickleBatchSampler
from .dataset import LuminaPickleDataset

__all__ = [
"LuminaPickleSampler",
"LuminaPickleBatchSampler",
"LuminaPickleDataset"
]

34 changes: 34 additions & 0 deletions internlm/data/lumina_pickle/collater.py
@@ -0,0 +1,34 @@
import torch

def lumina_collate_fn(batch):
    """
    Collate function for packed input sequences.
    Args:
        batch (List[Dict]): List of dictionaries representing each sample in the batch.
            Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys.
    Returns:
        Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with
            "input_ids", "cu_seqlens", "indexes", and "type_ids" keys, and the tensor of "labels"
            (non-positive labels are replaced with -100).
    """
    # Initialize lists to store the data from each sample
    tokens, labels, type_ids, indexes = [], [], [], []
    cumulative_seqlens = [0]

    # Accumulate all samples into respective lists
    for sample in batch:
        tokens.extend([abs(w) for w in sample["tokens"]])
        labels.extend([w if w > 0 else -100 for w in sample["labels"]])
        type_ids.extend(sample["type_ids"])
        indexes.extend(sample["indexes"])
        cumulative_seqlens.append(cumulative_seqlens[-1] + sample["cu_seqlens"][-1])

    # Convert lists to tensors and unsqueeze to add batch dimension
    xs = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
    ys = torch.tensor(labels, dtype=torch.long).unsqueeze(0)
    ts = torch.tensor(type_ids, dtype=torch.long).unsqueeze(0)
    indexes = torch.tensor(indexes, dtype=torch.long).unsqueeze(0)
    cu_seqlens = torch.tensor(cumulative_seqlens, dtype=torch.int).unsqueeze(0)

    return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys