diff --git a/verl/utils/model.py b/verl/utils/model.py
index 3074ea0c097..11b944c6249 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -16,6 +16,7 @@
 """
 
 import os
+import re
 import warnings
 from typing import Dict, Optional, Type
 
@@ -28,6 +29,7 @@
     GenerationConfig,
     MistralForSequenceClassification,
     PretrainedConfig,
+    PreTrainedModel,
 )
 
 from verl.models.registry import ModelRegistry
@@ -205,6 +207,27 @@ def compute_position_id_with_mask(mask):
     return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
 
 
+def convert_weight_keys(state_dict: Dict[str, torch.Tensor], model: PreTrainedModel):
+    # convert state dict keys: https://github.com/huggingface/transformers/pull/38385
+    if not hasattr(model, "_checkpoint_conversion_mapping"):
+        return state_dict
+
+    reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()}
+    original_weights = {}
+    for key, value in state_dict.items():
+        for pattern, replacement in reverse_key_mapping.items():
+            replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
+            replacement = re.sub(r"\(.*\)", "", replacement)
+            key, n_replace = re.subn(pattern, replacement, key)
+            # Early exit of the loop
+            if n_replace > 0:
+                break
+
+        original_weights[key] = value
+
+    return original_weights
+
+
 def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"):
     """
     Transform the model name in each model_chunk in each pp stage into the name in inference engine
@@ -289,7 +312,7 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
         from verl.utils.fs import copy_to_local
 
         print(f"start download from {config.model.path}")
-        local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get('use_shm', False))
+        local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False))
         print("finish download")
     else:
         local_model_path = config.model.path
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 86ac2e80909..540b7a012c7 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -134,11 +134,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
         trust_remote_code = kwargs.get("trust_remote_code", False)
         load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
 
-        limit_mm_per_prompt = None
-        if config.get("limit_images", None):  # support for multi-image data
-            limit_mm_per_prompt = {"image": config.get("limit_images")}
-
-        lora_kwargs = kwargs.pop('lora_kwargs', {})
+        lora_kwargs = kwargs.pop("lora_kwargs", {})
         self.lora_kwargs = lora_kwargs
         # copy it to avoid secretly modifying the engine config
         engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))
@@ -147,6 +143,9 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
         # (which can vary across different vLLM versions);
         # - Otherwise it's the desired value we want to explicitly set.
         engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
+        if config.get("limit_images", None):  # support for multi-image data
+            engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}
+
         self.inference_engine = LLM(
             model=model_path,
             enable_sleep_mode=True,
@@ -157,7 +156,6 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
             gpu_memory_utilization=config.gpu_memory_utilization,
             disable_custom_all_reduce=True,
             disable_mm_preprocessor_cache=True,
-            limit_mm_per_prompt=limit_mm_per_prompt,
             skip_tokenizer_init=False,
             max_model_len=max_model_len,
             load_format=load_format,
@@ -280,8 +278,8 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
         if self.lora_kwargs:
             lora_int_ids = list(self.inference_engine.llm_engine.list_loras())
             if len(lora_int_ids) > 0:
-                lora_int_id=lora_int_ids[0]
-                lora_requests = [LoRARequest(lora_name=f"{lora_int_id}",lora_int_id=lora_int_id,lora_path="/simon-stub-path")] * batch_size
+                lora_int_id = lora_int_ids[0]
+                lora_requests = [LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/simon-stub-path")] * batch_size
 
         # users can customize different sampling_params at different run
         with self.update_sampling_params(**kwargs):
@@ -342,7 +340,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                 "prompts": idx,
                 "responses": response,
                 "input_ids": seq,  # here input_ids become the whole sentences
-                'rollout_log_probs': rollout_log_probs,  # we will recompute old log prob with actor
+                "rollout_log_probs": rollout_log_probs,  # we will recompute old log prob with actor
                 "attention_mask": attention_mask,
                 "position_ids": position_ids,
             },
diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index 7b256660068..bcafa2abbb3 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -33,6 +33,7 @@
 from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
 from verl.utils.debug.performance import _timer
 from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
+from verl.utils.model import convert_weight_keys
 from verl.utils.torch_functional import check_device_is_available
 
 from .base import BaseShardingManager
@@ -102,6 +103,7 @@ def __enter__(self):
             log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
             device = torch.cuda.current_device()  # used when fsdp2 set cpu_offload_policy
             params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
+            params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
             # Copy, not share memory
             loop = asyncio.get_event_loop()
             loop.run_until_complete(self.update_weights(params))
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index b440597bc25..4b8c49f4b27 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -39,6 +39,7 @@
 from verl.utils.debug.performance import _timer
 from verl.utils.device import get_torch_device
 from verl.utils.fsdp_utils import fsdp_version, layered_summon_lora_params, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
+from verl.utils.model import convert_weight_keys
 from verl.utils.torch_functional import check_device_is_available
 from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader
@@ -167,6 +168,7 @@ def __collect_lora_params() -> OrderedDict:
             params = __collect_lora_params()
         else:
             params = self.module.state_dict()
+        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
 
         log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
         # Copy, not share memory
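For reference, here is a minimal standalone sketch (not part of the patch) of what the new `convert_weight_keys` helper does: it inverts a model's `_checkpoint_conversion_mapping` (introduced by transformers#38385) so that the renamed keys returned by `state_dict()` are mapped back to the original checkpoint names that the rollout engine expects. The `DummyModel` class and its example mapping below are hypothetical stand-ins for a real `PreTrainedModel`.

```python
import torch

from verl.utils.model import convert_weight_keys  # helper added by this patch


class DummyModel:
    # Hypothetical mapping in the style of transformers' `_checkpoint_conversion_mapping`:
    # the checkpoint prefix "visual" was renamed to "model.visual" in the loaded module tree.
    _checkpoint_conversion_mapping = {"^visual": "model.visual"}


state_dict = {
    "model.visual.patch_embed.proj.weight": torch.zeros(1),  # renamed key from state_dict()
    "lm_head.weight": torch.zeros(1),  # key without a mapping passes through unchanged
}

restored = convert_weight_keys(state_dict, DummyModel())
print(list(restored))
# ['visual.patch_embed.proj.weight', 'lm_head.weight']
```

Since the helper returns the state dict untouched when the model has no `_checkpoint_conversion_mapping` attribute, the added calls in the FSDP sharding managers are no-ops for unaffected architectures.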