diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 9e12523b2bc..962ebd0c86e 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -238,129 +238,156 @@ def process_one_shard(shard_dir): if args.test: ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors')) - def merge_across_tp(key, tp_data): - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in tp_data: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - tp_data = [gate, up] - elif "self_attention.linear_qkv." in key and 'layer_norm' not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 - kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - for infer_param in tp_data: - num_query_groups_per_partition = config.num_key_value_heads // tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - - tp_data = [q,k,v] - - elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and args.is_value_model: - tp_data = tp_data[0] + def handle_qkv_proj(key, config, tensor, state_dict): + nonlocal tp_size + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + q_part = tensor[:q_size_tp] + k_part = tensor[q_size_tp:q_size_tp + kv_size_tp] + v_part = tensor[q_size_tp + kv_size_tp:total_size] + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + q_part = tensor[:q_size_tp] + k_part = tensor[q_size_tp:q_size_tp + kv_size_tp] + v_part = tensor[q_size_tp + kv_size_tp:total_size] + + preffix = '.'.join(key.split('.')[:4]) + suffix = '.'.join(key.split('.')[5:]) + if state_dict.get(f'{preffix}.q_proj.{suffix}') is None: + state_dict[f'{preffix}.q_proj.{suffix}'] = q_part else: - dim = 0 - if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: - dim = 1 - tp_data = torch.cat(tp_data, dim=dim) + state_dict[f'{preffix}.q_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.q_proj.{suffix}'], q_part], dim=0) + if state_dict.get(f'{preffix}.k_proj.{suffix}') is None: + state_dict[f'{preffix}.k_proj.{suffix}'] = k_part + else: + state_dict[f'{preffix}.k_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.k_proj.{suffix}'], k_part], dim=0) + if state_dict.get(f'{preffix}.v_proj.{suffix}') is None: + state_dict[f'{preffix}.v_proj.{suffix}'] = v_part + else: + state_dict[f'{preffix}.v_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.v_proj.{suffix}'], v_part], dim=0) + return state_dict - return tp_data + def handle_gate_up_proj(key, config, tensor, state_dict): + nonlocal tp_size + + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = tensor[:intermediate_size_tp] + up_weight_tp = tensor[intermediate_size_tp:] + preffix = '.'.join(key.split('.')[:4]) + suffix = '.'.join(key.split('.')[5:]) + if state_dict.get(f'{preffix}.gate_proj.{suffix}') is None: + state_dict[f'{preffix}.gate_proj.{suffix}'] = gate_weight_tp + else: + state_dict[f'{preffix}.gate_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.gate_proj.{suffix}'], gate_weight_tp], dim=0) + if state_dict.get(f'{preffix}.up_proj.{suffix}') is None: + state_dict[f'{preffix}.up_proj.{suffix}'] = up_weight_tp + else: + state_dict[f'{preffix}.up_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.up_proj.{suffix}'], up_weight_tp], dim=0) + + return state_dict + + def merge_between_tp_rank(key, model_state_dict): + nonlocal state_dict + + try: + tensor = model_state_dict.pop(key) + except: + raise RuntimeError(f"error pop: {key}") + # Embedding layer + if "model.embed_tokens.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + else: + state_dict[key] = torch.concat([state_dict[key], tensor], dim=0) + return state_dict + # Tranformer Layers + if "input_layernorm.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + return state_dict + if re.search(r"self_attn\.qkv_proj", key): + state_dict = handle_qkv_proj(key, config, tensor, state_dict) + return state_dict + if "self_attn.o_proj.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + else: + state_dict[key] = torch.concat([state_dict[key], tensor], dim=1) + return state_dict + if "post_attention_layernorm.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + return state_dict + if re.search(r"mlp\.gate_up_proj\.weight", key): + state_dict = handle_gate_up_proj(key, config, tensor, state_dict) + return state_dict + if "mlp.down_proj.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + else: + state_dict[key] = torch.concat([state_dict[key], tensor], dim=1) + return state_dict + # Final LayerNorm + if "model.norm.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + return state_dict + if not args.tie_word_embedding: + if args.is_value_model: + if "lm_head.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + if "reward_head.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + else: + if "lm_head.weight" in key: + if state_dict[key] is None: + state_dict[key] = tensor + else: + state_dict[key] = torch.concat([state_dict[key], tensor], dim=0) + return state_dict + return state_dict - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() + for pp_rank in range(pp_size): + print(f'pp_rank: {pp_rank}') + for vpp_rank, state_dict_single_layer in enumerate(model_state_dict_lst[pp_rank][0]): + state_dict_single_layer_iter = state_dict_single_layer.copy() + keys = state_dict_single_layer_iter.keys() for key in keys: if "extra_state" in key: continue - if args.tie_word_embedding and ("output_layer" in key): + if args.tie_word_embedding and ("lm_head" in key or "reward_head" in key): print(f'skip lm_head and reward_head loading because of tie_word_embeddings') continue - new_key = key - if "decoder.layers." in key: - local_layer_no = int(key.split('.')[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list=key.split('.') - new_key_list[2] = str(global_layer_no) - new_key = '.'.join(new_key_list) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = merge_across_tp(new_key, tp_data) - if not isinstance(merged,list): - state_dict[new_key] = merged - elif len(merged)==3: - # split qkv - for n,d in zip(['q','k','v'], merged): - state_dict[new_key.replace("linear_qkv",f"linear_{n}")] = d - elif len(merged)==2: - # split gate up - state_dict[new_key.replace("linear_fc1","gate_proj")] = merged[0] - state_dict[new_key.replace("linear_fc1","up_proj")] = merged[1] - layers_cum += layers_handled+1 # zero based - - del model_state_dict_lst - - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ("self_attention.linear_q", "self_attn.q_proj"), - ("self_attention.linear_k", "self_attn.k_proj"), - ("self_attention.linear_v", "self_attn.v_proj"), - ] + if re.search(r"self_attn\.qkv_proj", key) is None and re.search(r"gate_up_proj", key) is None: + state_dict[key] = None + for tp_rank in range(tp_size): + model_state_dict = model_state_dict_lst[pp_rank][tp_rank][vpp_rank] + state_dict = merge_between_tp_rank(key, model_state_dict) + del model_state_dict_lst if args.test: - - for original_name, loaded_weight in state_dict.items(): - name = _replace_name(original_name, params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if args.tie_word_embedding and "lm_head.weight" in name: - continue - if name not in ref_state_dict: - raise RuntimeError(f'key: {name} not exist in state_dict') - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4) + for key, value in state_dict.items(): + print(key) + if key not in ref_state_dict: + raise RuntimeError(f'key: {key} not exist in ref_state_dict {value}') + if value.shape != ref_state_dict[key].shape: + raise RuntimeError(f'key: {key} shape mismatch {value.shape}, {ref_state_dict[key].shape}') + assert value.dtype == ref_state_dict[key].dtype, f'{key} state_dict[key].dtype: {value.dtype} != ref_state_dict[key].dtype: {ref_state_dict[key].dtype}' + torch.testing.assert_close(value, ref_state_dict[key], atol=1e-4, rtol=1e-4) + for key in ref_state_dict: + if key not in state_dict: + raise RuntimeError(f'key: {key} not exist in state_dict {ref_state_dict[key]}') + print('Writing to local disk') if args.target_dir is None: @@ -388,29 +415,6 @@ def merge_across_tp(key, tp_data): if args.hf_upload_path: upload_model_to_huggingface(hf_path) - -def _replace_name(megatron_name, name_mapping): - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name - if __name__ == '__main__': if args.backend == "fsdp": convert_fsdp_checkpoints_to_hfmodels() diff --git a/verl/models/llama/megatron/layers/parallel_linear.py b/verl/models/llama/megatron/layers/parallel_linear.py index e3e4e438554..bfe5cf4e65e 100644 --- a/verl/models/llama/megatron/layers/parallel_linear.py +++ b/verl/models/llama/megatron/layers/parallel_linear.py @@ -72,34 +72,3 @@ def __init__(self, gather_output=gather_output, skip_bias_add=skip_bias_add, **kwargs) - - -import torch - - -class LinearForLastLayer(torch.nn.Linear): - - def __init__( - self, - input_size, - output_size, - *, - config, - bias=True, - ): - super().__init__(in_features=input_size, out_features=output_size, bias=bias) - self.sequence_parallel = config.sequence_parallel - if self.sequence_parallel: - setattr(self.weight, 'sequence_parallel', True) - - def forward( - self, - input_, - weight=None, - runtime_gather_output=None, - ): - logits = super().forward(input_) - logits = logits.float() - if self.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits, None diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py deleted file mode 100644 index e4e2a3861a8..00000000000 --- a/verl/models/mcore/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .gpt_model import gptmodel_forward \ No newline at end of file diff --git a/verl/models/mcore/gpt_model.py b/verl/models/mcore/gpt_model.py deleted file mode 100644 index 172857b797b..00000000000 --- a/verl/models/mcore/gpt_model.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils -import torch -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core import parallel_state as mpu -from verl.utils.megatron_utils import unwrap_model - - -def gptmodel_forward(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel, - value_model=False, - pack_seqs=True): - pre_process = unwrap_model(model).pre_process - post_process = unwrap_model(model).post_process - if pack_seqs: - batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) - input_ids_rmpad = input_ids_rmpad.contiguous() - output = model(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids, - packed_seq_params=packed_seq_params) - output = postprocess_packed_seqs(output, - packed_seq_params, - attention_mask, - batch_size, - seq_len, - post_process=post_process) - else: - batch_size, sequence_length = attention_mask.shape - new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, - attention_mask, - position_ids, - sequence_parallel, - pre_process=pre_process) - output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) - output = recover_left_padding(output, - new_attention_mask, - attention_mask, - sequence_length, - post_process=post_process) - if value_model and post_process: - output = output[..., 0] - return output - - -def preprocess_packed_seqs(input_ids: torch.Tensor, - attention_mask: torch.Tensor, - pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]: - """ - Preprocess packed sequences - """ - batch_size = input_ids.shape[0] - - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - tp_size = mpu.get_tensor_model_parallel_world_size() - pad_size = (tp_size - seqlens_in_batch % tp_size) % tp_size - seqlens_in_batch_padded = seqlens_in_batch + pad_size - cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) - cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) - cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) - cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) - max_seqlen_in_batch = seqlens_in_batch_padded.max().item() - - shape = list(input_ids.shape[1:]) - shape[0] = seqlens_in_batch_padded.sum().item() - if pre_process: - input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) - for i in range(batch_size): - seqlen = seqlens_in_batch[i] - input_ids_rmpad[cu_seqlens_padded[i]:cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] - - packed_seq_params = PackedSeqParams(qkv_format='thd', - cu_seqlens_q=cu_seqlens_padded, - max_seqlen_q=max_seqlen_in_batch, - cu_seqlens_kv=cu_seqlens_padded, - max_seqlen_kv=max_seqlen_in_batch, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded) - if pre_process: - return input_ids_rmpad.unsqueeze(0), packed_seq_params - else: - return input_ids, packed_seq_params - - -def postprocess_packed_seqs(output: torch.Tensor, - packed_seq_params: PackedSeqParams, - attention_mask: torch.Tensor, - batch_size: int, - seq_len: int, - post_process: bool = True) -> torch.Tensor: - """ - Postprocess packed sequences - """ - if not post_process: - return output - shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim - output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) - for i in range(batch_size): - s = attention_mask[i].sum().item() - output_new[i, - attention_mask[i]] = output[0][packed_seq_params. - cu_seqlens_q_padded[i]:packed_seq_params.cu_seqlens_q_padded[i] + s] - - return output_new - - -def remove_left_padding(input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - sequence_parallel: bool = False, - pre_process: bool = True): - """ - Remove left padding from input_ids, attention_mask and position_ids - return new_input_ids, new_attention_mask, new_position_ids - """ - assert attention_mask.ndim == 2 - assert position_ids.ndim == 2 - - batch_size = input_ids.shape[0] - shape = list(input_ids.shape) # batch_size, seq_len,... - seq_lens = attention_mask.sum(dim=1) - seq_len = seq_lens.max().item() - if sequence_parallel: - from megatron.core import parallel_state as mpu - sp_world_size = mpu.get_tensor_model_parallel_world_size() - pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size - seq_len = seq_len + pad_size - shape[1] = seq_len - if pre_process: - new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) - new_attention_mask = torch.zeros(dtype=attention_mask.dtype, - device=attention_mask.device, - size=(batch_size, seq_len)) - new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) - for i in range(batch_size): - if pre_process: - new_input_ids[i, :seq_lens[i]] = input_ids[i, attention_mask[i]] - new_attention_mask[i, :seq_lens[i]] = attention_mask[i, attention_mask[i]] - new_position_ids[i, :seq_lens[i]] = position_ids[i, attention_mask[i]] - if pre_process: - return new_input_ids, new_attention_mask, new_position_ids - else: - return input_ids, new_attention_mask, new_position_ids - - -def recover_left_padding(result, - attention_mask: torch.Tensor, - original_attention_mask: torch.Tensor, - origin_seqlen: int, - post_process: bool = True): - """ - Recover left padding from result - return result - """ - if not post_process: - return result - shape = list(result.shape) - batch_size = shape[0] - shape[1] = origin_seqlen - new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) - for i in range(batch_size): - new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] - return new_result diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py deleted file mode 100644 index 1a782aa18fd..00000000000 --- a/verl/models/mcore/loader.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import time -import torch.distributed as dist - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - import megatron - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): - """Load merged state_dict to sharded Megatron module in training. - """ - import megatron - from megatron.core import mpu - from verl.utils.megatron_utils import print_rank_0, unwrap_model - from megatron.core.transformer.module import Float16Module - from megatron.core import DistributedDataParallel as LocalDDP - from torch.nn.parallel import DistributedDataParallel as torchDDP - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast(param.data, - src=mpu.get_data_parallel_src_rank(), - group=mpu.get_data_parallel_group()) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, (list, tuple)): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.decoder.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == 0: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - if torch.distributed.get_rank() == 0: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=0, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, - config.hidden_size, - dtype=params_dtype, - device=torch.cuda.current_device()) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0)) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - else: - assert ( - tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] - num_query_groups_per_partition = models[0].config.num_query_groups // tp_size - new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) - k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) - v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) - total_size_per_head = total_size // num_query_groups_per_partition - for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) - k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) - v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) - total_size_per_head = total_size // config.num_attention_heads - for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - else: - assert (tensor.shape == chunk_shape - ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - if f"{layer_name}.self_attn.q_proj.bias" in state_dict: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") - - _broadcast_tp_shard_tensor( - sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.output_layer.weight - - if is_value_model: - # if torch.distributed.get_rank() == 0: - if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0('load lm_head from value_head weight') - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0('fail to match lm_head in value_model') - # else: - - # _broadcast_tensor(lm_head_weight, "lm_head.weight") - - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - pass - torch.cuda.empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py deleted file mode 100644 index 446bc3120ac..00000000000 --- a/verl/models/mcore/saver.py +++ /dev/null @@ -1,455 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from verl.utils.megatron_utils import print_rank_0, unwrap_model -from megatron.core import mpu -from megatron.core.transformer.module import Float16Module -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from torch.nn.parallel import DistributedDataParallel as torchDDP -import torch -import time - -import torch -import torch.distributed as dist - - -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): - """given TP,DP,PP rank to get the global rank.""" - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size() - ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - # We only support TP-DP-PP grouping, for correctness when resharding - return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + - pp_rank_idx * num_layers_per_model) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. - """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, (list, tuple)): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers - ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format( - len(models[i].decoder.layers), num_layers_per_model) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - chunk_shape = tensor.shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - chunk_shape = tensor.shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - chunk_shape = tensor.shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size:(i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - torch.cuda.empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - if getattr(sync_layer.self_attention.linear_qkv, 'bias', - None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.linear_fc2.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - if tie_word_embeddings: - print_rank_0(f"tie word embedding skip load lm_head...") - else: - print_rank_0("collecting lm_head...") - - if is_value_model: - lm_head_weight = None - if pp_rank == pp_size - 1: - lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) - _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - torch.cuda.empty_cache() - if torch.distributed.get_rank() == 0: - - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py index a700761268b..18249d4c0a0 100644 --- a/verl/models/weight_loader_registry.py +++ b/verl/models/weight_loader_registry.py @@ -16,10 +16,9 @@ def get_weight_loader(arch: str): from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2 - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { - 'LlamaForCausalLM': load_state_dict_to_megatron_gptmodel, - 'Qwen2ForCausalLM': load_state_dict_to_megatron_gptmodel, + 'LlamaForCausalLM': load_state_dict_to_megatron_llama, + 'Qwen2ForCausalLM': load_state_dict_to_megatron_qwen2, } if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: @@ -31,10 +30,9 @@ def get_weight_loader(arch: str): def get_weight_saver(arch: str): from verl.models.llama.megatron.checkpoint_utils.llama_saver import merge_megatron_ckpt_llama from verl.models.qwen2.megatron.checkpoint_utils.qwen2_saver import merge_megatron_ckpt_qwen2 - from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { - 'LlamaForCausalLM': merge_megatron_ckpt_gptmodel, - 'Qwen2ForCausalLM': merge_megatron_ckpt_gptmodel, + 'LlamaForCausalLM': merge_megatron_ckpt_llama, + 'Qwen2ForCausalLM': merge_megatron_ckpt_qwen2, } if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py index 8f6b7133ac8..4a8a5933ba3 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -179,6 +179,67 @@ def _replace_name(megatron_name, name_mapping): return param_name +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def _replace_name(megatron_name, name_mapping): for m_name, v_name in name_mapping: if m_name not in megatron_name: @@ -214,38 +275,6 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - weight_loader(param, loaded_weight) -def megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ] - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for original_name, loaded_weight in actor_weights.items(): - name = _replace_name(original_name, params_mapping) - if not name or name.endswith(".bias") and name not in params_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { ColumnParallelLinear: parallel_weight_loader, MergedColumnParallelLinear: parallel_weight_loader, @@ -263,10 +292,10 @@ def megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { "GPT2LMHeadModel": gpt2_weight_loader, - "LlamaForCausalLM": megatron_core_te_weight_loader, # use te backend for open-source megatron - "LLaMAForCausalLM": megatron_core_te_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, "MistralForCausalLM": mistral_megatron_weight_loader, - 'Qwen2ForCausalLM': megatron_core_te_weight_loader, + 'Qwen2ForCausalLM': qwen2_megatron_weight_loader, } diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index c001e4e2ef3..380b7fcaca5 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -195,6 +195,18 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte assert len(state_dicts) == len( self.model), f'state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}' for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)): + # modify layer numbers + offset = unwrap_model(model, (torchDDP, LocalDDP, Float16Module)).model.layers[0].layer_idx + + state_dict_old = state_dict.copy() + old_keys = state_dict_old.keys() + for k in old_keys: + if k.split('.')[1] == 'layers': + layer_idx = int(k.split('.')[2]) + new_key = '.'.join(k.split('.')[:2] + [str(layer_idx - offset)] + k.split('.')[3:]) + state_dict[new_key] = state_dict[k] + if new_key != k: + state_dict.pop(k) model.load_state_dict(state_dict) print(f'Loaded sharded model checkpoint from {model_path}') @@ -231,6 +243,19 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i for vpp_rank, model in enumerate(self.model): state_dict = model.state_dict() + + # modify layer numbers + offset = unwrap_model(model, (torchDDP, LocalDDP, Float16Module)).model.layers[0].layer_idx + state_dict_old = state_dict.copy() + old_keys = state_dict_old.keys() + for k in old_keys: + if k.split('.')[1] == 'layers': + layer_idx = int(k.split('.')[2]) + new_key = '.'.join(k.split('.')[:2] + [str(layer_idx + offset)] + k.split('.')[3:]) + state_dict[new_key] = state_dict[k] + if new_key != k: + state_dict.pop(k) + state_dicts.append(state_dict) print(f'Saving sharded model checkpoint to {local_path}') diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 9812639b5f2..d6fb58b9f20 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -157,10 +157,6 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC print(f'megatron config {megatron_config}') dt = PrecisionType.to_dtype(megatron_config.params_dtype) print(f'pipeline_dtype=megatron_config {dt}') - if "Qwen2ForCausalLM" in hf_config.architectures: - qkv_bias = True - else: - qkv_bias = getattr(hf_config, 'attention_bias', False) overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size( ) is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 batch_p2p_comm = False @@ -189,9 +185,6 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC variable_seq_lengths=True, masked_softmax_fusion=True, moe_token_dispatcher_type="alltoall", - attention_dropout=hf_config.attention_dropout, - hidden_dropout=getattr(hf_config, 'hidden_dropout', 0.0), - add_qkv_bias=qkv_bias, bf16=dt is torch.bfloat16) return transformer_config diff --git a/verl/utils/model.py b/verl/utils/model.py index 6af6542ba83..f3b96241d0c 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -365,81 +365,3 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size) return unpad_tokens, cu_seqlens, max_seqlen_in_batch - - -def load_megatron_gptmodel_weights(config, - model_config, - parallel_model, - params_dtype, - is_value_model=False, - local_cache_path='~/.cache/verl/rlhf'): - assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" - architectures = getattr(model_config, "architectures", []) - local_cache_path = os.path.expanduser(local_cache_path) - - if config.model.path.startswith("hdfs:"): - from verl.utils.fs import copy_to_local - print(f'start download from {config.model.path}') - local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path) - print('finish download') - else: - print(f"load from local dir {config.model.path}") - local_model_path = config.model.path - - # TODO: to find a better way to load mistral7b-rm lm_head - if 'mistral7b-rm' in config.model.path: - model = MistralForSequenceClassification.from_pretrained(local_model_path) # use score head instead of lm_head - state_dict = model.state_dict() - state_dict['lm_head.weight'] = state_dict['score.weight'] - state_dict['model.embed_tokens.weight'] = state_dict[ - 'model.embed_tokens.weight'][:32000] # workaround, 32001 -> 32000 - is_value_model = True - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - model = AutoModelForCausalLM.from_pretrained(local_model_path) - state_dict = model.state_dict() - - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - load_state_dict_to_megatron_gptmodel(state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model) - del state_dict, model - - -def get_parallel_gptmodel_from_config(tfconfig, - hf_config, - pre_process=None, - post_process=None, - share_embeddings_and_output_weights=False, - value=False): - from megatron.core.models.gpt.gpt_model import GPTModel - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec - from megatron.core import parallel_state as mpu - from megatron.core import tensor_parallel - use_te = True - assert tfconfig.normalization == "RMSNorm", 'only RMSNorm is supported for now' - transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) - rope_scaling_args = {} - if hf_config.rope_scaling is not None: - assert hf_config.rope_scaling['type'] == 'linear', "only linear scaling is supported for now" - rope_scaling_args['seq_len_interpolation_factor'] = hf_config.rope_scaling['factor'] - parallel_model = GPTModel(config=tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=hf_config.vocab_size, - max_sequence_length=hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type='rope', - rotary_base=hf_config.rope_theta, - **rope_scaling_args) - # # for layer in parallel_model.decoder.layers: layer.self_attention.core_attention.flash_attention.softmax_scale = None - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, - output_size=1, - config=tfconfig) - return parallel_model diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 832e2b92630..7c9f3701ad9 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -162,7 +162,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: def compute_logprobs_fn(output, data): response = data['responses'] response_length = response.size(1) - logits = output + logits = output['logits'] logits = logits[:, -response_length - 1:-1].contiguous() log_probs = vocab_parallel_log_probs_from_logits(logits, response) return {'log_probs': log_probs} @@ -264,7 +264,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce def loss_func(output, data, meta_info): if forward_only: if post_process_fn is None: - return 1.0, {'logits': output} + return 1.0, {'logits': output.logits} else: return 1.0, post_process_fn(output, data) @@ -280,7 +280,7 @@ def loss_func(output, data, meta_info): clip_ratio_c = meta_info['clip_ratio_c'] # compute policy loss - logits = output + logits = output.logits logits = logits[:, -response_length - 1:-1].contiguous() logits_back = logits.clone() log_prob = vocab_parallel_log_probs_from_logits(logits, responses) @@ -323,13 +323,7 @@ def forward_step(batch_iter, model): input_ids = batch['input_ids'] attention_mask = batch['attention_mask'] position_ids = batch['position_ids'] - from verl.models.mcore import gptmodel_forward - - output = gptmodel_forward(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.megatron_config.sequence_parallel) + output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) if forward_only: meta_info = None else: diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 654073c2824..a385a4e0069 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -130,7 +130,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False): def loss_func(output, data, meta_info): if forward_only: - return 1.0, {'vpreds': output} + return 1.0, {'vpreds': output.logits} responses = data['responses'] attention_mask = data['attention_mask'] @@ -142,7 +142,7 @@ def loss_func(output, data, meta_info): cliprange_value = self.config.cliprange_value - vpreds = output # (bs, sequence_length) + vpreds = output.logits # (bs, sequence_length) vpreds = vpreds[:, -response_length - 1:-1] vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, @@ -163,15 +163,7 @@ def forward_step(batch_iter, model): input_ids = batch['input_ids'] attention_mask = batch['attention_mask'] position_ids = batch['position_ids'] - from verl.models.mcore import gptmodel_forward - - output = gptmodel_forward(model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.megatron_config.sequence_parallel, - value_model=True) - + output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) return output, partial(loss_func, data=batch, meta_info={}) # batch should be a list of batches inside micro-batches diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 80364279f8b..83ddacb5f50 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -33,7 +33,7 @@ from verl import DataProto from verl.utils.fs import copy_to_local from verl.utils.debug import log_gpu_memory_usage -from verl.utils.model import load_megatron_model_weights, load_megatron_gptmodel_weights +from verl.utils.model import load_megatron_model_weights from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager from verl.utils.megatron_utils import init_model_parallel_config @@ -142,7 +142,7 @@ def _build_model_optimizer(self, from verl.utils.megatron.optimizer import get_megatron_optimizer from megatron.core.models.gpt.gpt_model import ModelType from verl.utils.model import print_model_size, update_model_config, get_generation_config - from verl.utils.megatron_utils import get_model, init_megatron_optim_config, convert_config + from verl.utils.megatron_utils import get_model, init_megatron_optim_config from transformers import AutoConfig # Step 1: initialize the tokenizer @@ -168,18 +168,17 @@ def _build_model_optimizer(self, self.share_embeddings_and_output_weights = getattr(actor_model_config, "tie_word_embeddings", False) self.architectures = getattr(actor_model_config, "architectures", None) - tfconfig = convert_config(actor_model_config, megatron_config) - self.hf_config = actor_model_config - print(f'TF config: {tfconfig}') - self.hf_config = actor_model_config - def megatron_actor_model_provider(pre_process, post_process): - from verl.utils.model import get_parallel_gptmodel_from_config - parallel_model = get_parallel_gptmodel_from_config( - tfconfig, - actor_model_config, - pre_process, - post_process, + from verl.utils.model import get_parallel_model_from_config + # vpp is not supported yet because it will hang for some reason. Need debugging + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model + # this_megatron_config = copy.deepcopy(megatron_config) + # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank + parallel_model = get_parallel_model_from_config( + config=actor_model_config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, value=False) parallel_model.cuda() @@ -200,11 +199,11 @@ def megatron_actor_model_provider(pre_process, post_process): actor_module = actor_modules_list print(f'actor_module: {len(actor_module)}') if self.config.actor.load_weight: - load_megatron_gptmodel_weights(self.config, - actor_model_config, - actor_module, - params_dtype=megatron_config.params_dtype, - is_value_model=False) + self.hf_config = load_megatron_model_weights(self.config, + actor_model_config, + actor_module, + params_dtype=megatron_config.params_dtype, + is_value_model=False) if self.rank == 0: print_model_size(actor_module[0]) @@ -220,11 +219,11 @@ def megatron_actor_model_provider(pre_process, post_process): if self.config.ref.load_weight: # should align with the actor: assert self.config.actor.load_weight == self.config.ref.load_weight print(f'load ref weight start') - load_megatron_gptmodel_weights(self.config, - actor_model_config, - ref_module, - params_dtype=megatron_config.params_dtype, - is_value_model=False) + self.hf_config = load_megatron_model_weights(self.config, + actor_model_config, + ref_module, + params_dtype=megatron_config.params_dtype, + is_value_model=False) log_gpu_memory_usage('After ref module init', logger=logger) return ref_module, actor_model_config @@ -249,8 +248,10 @@ def _build_rollout(self): # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, # we will reorganize their weight format when resharding from actor to rollout. layer_name_mapping = { - "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.weight", + "qkv_layer_name": + self.config.rollout.layer_name_map.get("qkv_layer_name", "qkv"), + "gate_proj_layer_name": + self.config.rollout.layer_name_map.get("gate_proj_layer_name", "linear_fc1.weight"), } # reshard the weight partition from actor to rollout to initialize the rollout class @@ -521,7 +522,7 @@ def _build_critic_model_optimizer(self, from megatron.core.models.gpt.gpt_model import ModelType from verl.utils.model import print_model_size, update_model_config from verl.utils.megatron.optimizer import get_megatron_optimizer - from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config, convert_config + from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig # Step 1: initialize the tokenizer @@ -541,19 +542,20 @@ def _build_critic_model_optimizer(self, update_model_config(critic_model_config, override_config_kwargs=override_config_kwargs) self.architectures = getattr(critic_model_config, "architectures", None) if self.rank == 0: - print(f'Model config after override: {critic_model_config}') - tfconfig = convert_config(critic_model_config, megatron_config) - self.hf_config = critic_model_config - print(f'TF config: {tfconfig}') + print(f'Model config after override: critic_model_config {critic_model_config}') def megatron_critic_model_provider(pre_process, post_process): - from verl.utils.model import get_parallel_gptmodel_from_config - parallel_model = get_parallel_gptmodel_from_config(tfconfig, - critic_model_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True) + from verl.utils.model import get_parallel_model_from_config + # TODO: support vpp here + # vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model + # this_megatron_config = copy.deepcopy(megatron_config) + # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank + parallel_model = get_parallel_model_from_config(config=critic_model_config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=False, + value=True) parallel_model.cuda() return parallel_model @@ -567,11 +569,11 @@ def megatron_critic_model_provider(pre_process, post_process): # critic_module = nn.ModuleList(critic_module) if self.config.load_weight: - load_megatron_gptmodel_weights(self.config, - critic_model_config, - critic_module, - params_dtype=megatron_config.params_dtype, - is_value_model=True) + self.hf_config = load_megatron_model_weights(self.config, + critic_model_config, + critic_module, + params_dtype=megatron_config.params_dtype, + is_value_model=True) if self.rank == 0: print_model_size(critic_module[0]) diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index d7c684d5e14..50259793bea 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -293,7 +293,7 @@ def default_tp_concat_fn(self, name, param, infer_params, model_config): we can throw an error to force user disable TP HybridEngine. """ - if self.layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: + if self.layer_name_mapping.get("qkv_layer_name") in name: # if the tensor is qkv, for each param on tp, split into q, k, v # concat q, k, v separately. q_lst = [] @@ -305,18 +305,10 @@ def default_tp_concat_fn(self, name, param, infer_params, model_config): kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: - num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size( - ) - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) + q, k, v = infer_param.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) q = torch.cat(q_lst, dim=0) k = torch.cat(k_lst, dim=0) v = torch.cat(v_lst, dim=0) @@ -351,16 +343,16 @@ def _post_process_params(self, params): micro_dp_size = get_micro_data_parallel_world_size() micro_dp_group = get_micro_data_parallel_group() + if micro_dp_size <= 1: + return + origin_params = {} for name in params.keys(): param = params[name] if tp_utils.is_tensor_parallel_param(param): # allocate a new tensor with proper size - if micro_dp_size <= 1: - infer_params = [param] - else: - infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)] - torch.distributed.all_gather(infer_params, param, group=micro_dp_group) + infer_params = [torch.empty_like(param) for _ in range(micro_dp_size)] + torch.distributed.all_gather(infer_params, param, group=micro_dp_group) infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config) # replace with original param params[name] = infer_params @@ -389,9 +381,9 @@ def __exit__(self, exc_type, exc_value, traceback): # FIXME(sgm): the best practice is to delete the cuda tensor # rebind the model weights, can be any cpu tensor - if self.origin_params is not None: - for name, param in self.origin_params.items(): - self.params[name] = param + if get_micro_data_parallel_world_size() > 1: + for name in self.params.keys(): + self.params[name] = self.origin_params[name] # self.inference_engine.sync_model_weights(params) self.inference_engine.offload_model_weights()