274 changes: 139 additions & 135 deletions scripts/model_merger.py
@@ -238,129 +238,156 @@ def process_one_shard(shard_dir):
if args.test:
ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors'))

def merge_across_tp(key, tp_data):
    if "linear_fc1.weight" in key:
        # the tensor fuses the gate and up projections: split each TP shard,
        # then concatenate gate and up separately across shards
        gate_lst = []
        up_lst = []
        for infer_param in tp_data:
            gate, up = infer_param.chunk(2)
            gate_lst.append(gate)
            up_lst.append(up)
        gate = torch.cat(gate_lst, dim=0)
        up = torch.cat(up_lst, dim=0)
        tp_data = [gate, up]
    elif "self_attention.linear_qkv." in key and 'layer_norm' not in key:
        # the tensor fuses q, k, v: split each TP shard into q, k, v,
        # then concatenate q, k, v separately across shards
        q_lst = []
        k_lst = []
        v_lst = []
        assert config.num_attention_heads % config.num_key_value_heads == 0
        num_q_per_kv = config.num_attention_heads // config.num_key_value_heads
        assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0
        kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)
        for infer_param in tp_data:
            num_query_groups_per_partition = config.num_key_value_heads // tp_size
            for chunk in infer_param.chunk(num_query_groups_per_partition):
                split_size = [
                    kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
                    kv_size_per_tp // num_query_groups_per_partition,
                    kv_size_per_tp // num_query_groups_per_partition
                ]
                q, k, v = chunk.split(split_size)
                q_lst.append(q)
                k_lst.append(k)
                v_lst.append(v)
        q = torch.cat(q_lst, dim=0)
        k = torch.cat(k_lst, dim=0)
        v = torch.cat(v_lst, dim=0)

        tp_data = [q, k, v]

    elif "layer_norm" in key or "layernorm" in key or ("output_layer" in key and args.is_value_model):
        # replicated across TP ranks; take the first shard as-is
        tp_data = tp_data[0]
    else:
        dim = 0
        if "linear_fc2.weight" in key or "self_attention.linear_proj" in key:
            dim = 1
        tp_data = torch.cat(tp_data, dim=dim)

    return tp_data

def handle_qkv_proj(key, config, tensor, state_dict):
    nonlocal tp_size

    hidden_size_per_head = config.hidden_size // config.num_attention_heads

    if config.num_key_value_heads >= tp_size:
        q_size_tp = config.hidden_size // tp_size
        kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
        total_size = q_size_tp + 2 * kv_size_tp
        q_part = tensor[:q_size_tp]
        k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
        v_part = tensor[q_size_tp + kv_size_tp:total_size]
    else:
        q_size_tp = config.hidden_size // tp_size
        kv_size_tp = hidden_size_per_head
        total_size = q_size_tp + 2 * kv_size_tp
        q_part = tensor[:q_size_tp]
        k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
        v_part = tensor[q_size_tp + kv_size_tp:total_size]

    prefix = '.'.join(key.split('.')[:4])
    suffix = '.'.join(key.split('.')[5:])
    if state_dict.get(f'{prefix}.q_proj.{suffix}') is None:
        state_dict[f'{prefix}.q_proj.{suffix}'] = q_part
    else:
        state_dict[f'{prefix}.q_proj.{suffix}'] = torch.concat([state_dict[f'{prefix}.q_proj.{suffix}'], q_part], dim=0)
    if state_dict.get(f'{prefix}.k_proj.{suffix}') is None:
        state_dict[f'{prefix}.k_proj.{suffix}'] = k_part
    else:
        state_dict[f'{prefix}.k_proj.{suffix}'] = torch.concat([state_dict[f'{prefix}.k_proj.{suffix}'], k_part], dim=0)
    if state_dict.get(f'{prefix}.v_proj.{suffix}') is None:
        state_dict[f'{prefix}.v_proj.{suffix}'] = v_part
    else:
        state_dict[f'{prefix}.v_proj.{suffix}'] = torch.concat([state_dict[f'{prefix}.v_proj.{suffix}'], v_part], dim=0)

    return state_dict
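For intuition, the fused-QKV layout that merge_across_tp and handle_qkv_proj unpack can be reproduced in isolation. The sketch below is illustrative only — the dimensions (8 query heads, 4 KV heads, head dim 4, TP size 2) are made up, not taken from the PR — but it shows how each TP shard stores [q…q, k, v] per query group and how the split sizes fall out:

import torch

# Hypothetical dimensions, not from the script.
num_attention_heads, num_key_value_heads, head_dim, tp_size = 8, 4, 4, 2
num_q_per_kv = num_attention_heads // num_key_value_heads           # 2
num_query_groups_per_partition = num_key_value_heads // tp_size     # 2

# One TP shard of a fused QKV weight: each query group contributes
# num_q_per_kv query rows plus one k row and one v row (in head_dim units).
rows_per_group = (num_q_per_kv + 2) * head_dim
shard = torch.arange(num_query_groups_per_partition * rows_per_group)

q_lst, k_lst, v_lst = [], [], []
for chunk in shard.chunk(num_query_groups_per_partition):
    q, k, v = chunk.split([num_q_per_kv * head_dim, head_dim, head_dim])
    q_lst.append(q)
    k_lst.append(k)
    v_lst.append(v)

# The real script repeats this per TP shard and concatenates across shards.
print(torch.cat(q_lst).shape, torch.cat(k_lst).shape, torch.cat(v_lst).shape)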
def handle_gate_up_proj(key, config, tensor, state_dict):
    nonlocal tp_size

    intermediate_size_tp = config.intermediate_size // tp_size
    gate_weight_tp = tensor[:intermediate_size_tp]
    up_weight_tp = tensor[intermediate_size_tp:]
    prefix = '.'.join(key.split('.')[:4])
    suffix = '.'.join(key.split('.')[5:])
    if state_dict.get(f'{prefix}.gate_proj.{suffix}') is None:
        state_dict[f'{prefix}.gate_proj.{suffix}'] = gate_weight_tp
    else:
        state_dict[f'{prefix}.gate_proj.{suffix}'] = torch.concat([state_dict[f'{prefix}.gate_proj.{suffix}'], gate_weight_tp], dim=0)
    if state_dict.get(f'{prefix}.up_proj.{suffix}') is None:
        state_dict[f'{prefix}.up_proj.{suffix}'] = up_weight_tp
    else:
        state_dict[f'{prefix}.up_proj.{suffix}'] = torch.concat([state_dict[f'{prefix}.up_proj.{suffix}'], up_weight_tp], dim=0)

    return state_dict
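As a sanity check on the gate/up handling, the following standalone sketch (hypothetical sizes, not from the PR) shows that splitting each rank's fused [gate; up] slice and concatenating across ranks recovers the full tensors — exactly the accumulation handle_gate_up_proj performs one rank at a time:

import torch

# Hypothetical sizes, not from the script.
intermediate_size, hidden_size, tp_size = 8, 4, 2
intermediate_size_tp = intermediate_size // tp_size

full_gate = torch.randn(intermediate_size, hidden_size)
full_up = torch.randn(intermediate_size, hidden_size)

# Each TP rank stores a fused [gate_tp; up_tp] slice of the MLP input projection.
shards = [
    torch.cat([
        full_gate[r * intermediate_size_tp:(r + 1) * intermediate_size_tp],
        full_up[r * intermediate_size_tp:(r + 1) * intermediate_size_tp],
    ])
    for r in range(tp_size)
]

gate = torch.cat([s[:intermediate_size_tp] for s in shards], dim=0)
up = torch.cat([s[intermediate_size_tp:] for s in shards], dim=0)
assert torch.equal(gate, full_gate) and torch.equal(up, full_up)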

def merge_between_tp_rank(key, model_state_dict):
    nonlocal state_dict

    try:
        tensor = model_state_dict.pop(key)
    except KeyError:
        raise RuntimeError(f"failed to pop key: {key}")
    # Embedding layer
    if "model.embed_tokens.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        else:
            state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
        return state_dict
    # Transformer layers
    if "input_layernorm.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        return state_dict
    if re.search(r"self_attn\.qkv_proj", key):
        state_dict = handle_qkv_proj(key, config, tensor, state_dict)
        return state_dict
    if "self_attn.o_proj.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        else:
            state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
        return state_dict
    if "post_attention_layernorm.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        return state_dict
    if re.search(r"mlp\.gate_up_proj\.weight", key):
        state_dict = handle_gate_up_proj(key, config, tensor, state_dict)
        return state_dict
    if "mlp.down_proj.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        else:
            state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
        return state_dict
    # Final LayerNorm
    if "model.norm.weight" in key:
        if state_dict[key] is None:
            state_dict[key] = tensor
        return state_dict
    if not args.tie_word_embedding:
        if args.is_value_model:
            if "lm_head.weight" in key:
                if state_dict[key] is None:
                    state_dict[key] = tensor
            if "reward_head.weight" in key:
                if state_dict[key] is None:
                    state_dict[key] = tensor
        else:
            if "lm_head.weight" in key:
                if state_dict[key] is None:
                    state_dict[key] = tensor
                else:
                    state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
        return state_dict
    return state_dict
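The concat dimensions in merge_between_tp_rank follow the usual Megatron sharding convention: column-parallel weights (embed_tokens, lm_head, q/k/v, gate/up) are sharded along dim 0, row-parallel weights (o_proj, down_proj) along dim 1, and layernorm weights are replicated, so the first rank's copy is taken as-is. A minimal illustration with hypothetical shapes:

import torch

tp_size = 2

# Column-parallel: each rank holds an [out/tp, in] slice; merge along dim 0.
col_shards = [torch.randn(4, 8) for _ in range(tp_size)]
merged_col = torch.cat(col_shards, dim=0)   # -> [8, 8]

# Row-parallel: each rank holds an [out, in/tp] slice; merge along dim 1.
row_shards = [torch.randn(8, 4) for _ in range(tp_size)]
merged_row = torch.cat(row_shards, dim=1)   # -> [8, 8]

print(merged_col.shape, merged_row.shape)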

vpp_size = len(model_state_dict_lst[0][0])
layers_cum = 0
for vpp_rank in range(vpp_size):
    for pp_rank in range(pp_size):
        layers_handled = 0
        keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()
        for key in keys:
            if "extra_state" in key:
                continue
            if args.tie_word_embedding and ("output_layer" in key):
                continue
            new_key = key
            if "decoder.layers." in key:
                local_layer_no = int(key.split('.')[2])
                layers_handled = max(local_layer_no, layers_handled)
                global_layer_no = local_layer_no + layers_cum
                new_key_list = key.split('.')
                new_key_list[2] = str(global_layer_no)
                new_key = '.'.join(new_key_list)

            tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
            merged = merge_across_tp(new_key, tp_data)
            if not isinstance(merged, list):
                state_dict[new_key] = merged
            elif len(merged) == 3:
                # split qkv
                for n, d in zip(['q', 'k', 'v'], merged):
                    state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
            elif len(merged) == 2:
                # split gate up
                state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0]
                state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1]
        layers_cum += layers_handled + 1  # zero based

del model_state_dict_lst

params_mapping = [
    # (megatron core gpt model name, vllm model name)
    ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
    ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
    ("embedding.word_embeddings", "model.embed_tokens"),
    ("self_attention.linear_qkv", "self_attn.qkv_proj"),
    ("self_attention.linear_proj", "self_attn.o_proj"),
    ("pre_mlp_layernorm", "post_attention_layernorm"),
    ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
    ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
    ("mlp.linear_fc1", "mlp.gate_up_proj"),
    ("mlp.linear_fc2", "mlp.down_proj"),
    ("decoder.final_layernorm", "model.norm"),
    ("output_layer", "lm_head"),
    ("self_attention.linear_q", "self_attn.q_proj"),
    ("self_attention.linear_k", "self_attn.k_proj"),
    ("self_attention.linear_v", "self_attn.v_proj"),
]

for pp_rank in range(pp_size):
    print(f'pp_rank: {pp_rank}')
    for vpp_rank, state_dict_single_layer in enumerate(model_state_dict_lst[pp_rank][0]):
        state_dict_single_layer_iter = state_dict_single_layer.copy()
        keys = state_dict_single_layer_iter.keys()
        for key in keys:
            if "extra_state" in key:
                continue
            if args.tie_word_embedding and ("lm_head" in key or "reward_head" in key):
                print('skip lm_head and reward_head loading because of tie_word_embeddings')
                continue
            if re.search(r"self_attn\.qkv_proj", key) is None and re.search(r"gate_up_proj", key) is None:
                state_dict[key] = None
            for tp_rank in range(tp_size):
                model_state_dict = model_state_dict_lst[pp_rank][tp_rank][vpp_rank]
                state_dict = merge_between_tp_rank(key, model_state_dict)

del model_state_dict_lst
if args.test:

for original_name, loaded_weight in state_dict.items():
name = _replace_name(original_name, params_mapping)
if not name or name.endswith(".bias") and name not in ref_state_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
if args.tie_word_embedding and "lm_head.weight" in name:
continue
if name not in ref_state_dict:
raise RuntimeError(f'key: {name} does not exist in ref_state_dict')
param = ref_state_dict[name]
assert loaded_weight.dtype == param.dtype
torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)
for key, value in state_dict.items():
print(key)
if key not in ref_state_dict:
raise RuntimeError(f'key: {key} does not exist in ref_state_dict {value}')
if value.shape != ref_state_dict[key].shape:
raise RuntimeError(f'key: {key} shape mismatch {value.shape}, {ref_state_dict[key].shape}')
assert value.dtype == ref_state_dict[key].dtype, f'{key} state_dict[key].dtype: {value.dtype} != ref_state_dict[key].dtype: {ref_state_dict[key].dtype}'
torch.testing.assert_close(value, ref_state_dict[key], atol=1e-4, rtol=1e-4)
for key in ref_state_dict:
if key not in state_dict:
raise RuntimeError(f'key: {key} does not exist in state_dict {ref_state_dict[key]}')


print('Writing to local disk')
if args.target_dir is None:
@@ -388,29 +415,6 @@ def merge_across_tp(key, tp_data):
if args.hf_upload_path:
upload_model_to_huggingface(hf_path)


def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if "layers" in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace("decoder", "model")
megatron_name_list = megatron_name.split(".")
if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = ".".join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = ".".join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
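A quick usage sketch of the removed helper, assuming the params_mapping list from the old test path is still in scope (both names below appear in that list):

assert _replace_name("decoder.layers.0.self_attention.linear_proj.weight",
                     params_mapping) == "model.layers.0.self_attn.o_proj.weight"
assert _replace_name("embedding.word_embeddings.weight",
                     params_mapping) == "model.embed_tokens.weight"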

if __name__ == '__main__':
if args.backend == "fsdp":
convert_fsdp_checkpoints_to_hfmodels()
31 changes: 0 additions & 31 deletions verl/models/llama/megatron/layers/parallel_linear.py
@@ -72,34 +72,3 @@ def __init__(self,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)


import torch


class LinearForLastLayer(torch.nn.Linear):

def __init__(
self,
input_size,
output_size,
*,
config,
bias=True,
):
super().__init__(in_features=input_size, out_features=output_size, bias=bias)
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel:
setattr(self.weight, 'sequence_parallel', True)

def forward(
self,
input_,
weight=None,
runtime_gather_output=None,
):
logits = super().forward(input_)
logits = logits.float()
if self.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits, None
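Outside of sequence parallelism, the removed LinearForLastLayer reduces to a plain linear head with fp32 logits. A minimal sketch of that path, with hypothetical sizes and without the Megatron gather:

import torch

class PlainLastLayer(torch.nn.Linear):
    # Sketch of the removed class with sequence_parallel=False: an ordinary
    # Linear whose logits are upcast to float32.
    def forward(self, input_, weight=None, runtime_gather_output=None):
        logits = super().forward(input_).float()
        return logits, None  # (output, bias) pair mirrors the Megatron API

head = PlainLastLayer(16, 32)            # hypothetical hidden and vocab sizes
logits, _ = head(torch.randn(2, 4, 16))
print(logits.dtype, logits.shape)        # torch.float32, torch.Size([2, 4, 32])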
16 changes: 0 additions & 16 deletions verl/models/mcore/__init__.py

This file was deleted.
