25 changes: 24 additions & 1 deletion verl/utils/model.py
@@ -16,6 +16,7 @@
"""

import os
import re
import warnings
from typing import Dict, Optional, Type

@@ -28,6 +29,7 @@
GenerationConfig,
MistralForSequenceClassification,
PretrainedConfig,
PreTrainedModel,
)

from verl.models.registry import ModelRegistry
@@ -205,6 +207,27 @@ def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


def convert_weight_keys(state_dict: Dict[str, torch.Tensor], model: PreTrainedModel):
    # convert state dict keys: https://github.com/huggingface/transformers/pull/38385
    if not hasattr(model, "_checkpoint_conversion_mapping"):
        return state_dict

    reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()}
    original_weights = {}
    for key, value in state_dict.items():
        for pattern, replacement in reverse_key_mapping.items():
            replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
            replacement = re.sub(r"\(.*\)", "", replacement)
            key, n_replace = re.subn(pattern, replacement, key)
            # Early exit of the loop
            if n_replace > 0:
                break

        original_weights[key] = value

    return original_weights


def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"):
"""
Transform the model name in each model_chunk in each pp stage into the name in inference engine
@@ -289,7 +312,7 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
from verl.utils.fs import copy_to_local

print(f"start download from {config.model.path}")
local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get('use_shm', False))
local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False))
print("finish download")
else:
local_model_path = config.model.path
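The new `convert_weight_keys` helper renames state-dict keys produced by newer transformers releases back to the original checkpoint layout before the weights are handed to the rollout engine. A minimal sketch of the behavior, using a hypothetical `_checkpoint_conversion_mapping` and a stand-in object rather than a real `PreTrainedModel`:

```python
# Minimal sketch, not part of the diff. The mapping below is hypothetical and only
# mimics the shape transformers attaches after huggingface/transformers#38385.
from types import SimpleNamespace

import torch

from verl.utils.model import convert_weight_keys

fake_model = SimpleNamespace(_checkpoint_conversion_mapping={"^model": "model.language_model"})
state_dict = {"model.language_model.layers.0.self_attn.q_proj.weight": torch.zeros(1)}

converted = convert_weight_keys(state_dict, fake_model)
print(list(converted))  # ['model.layers.0.self_attn.q_proj.weight']
```

Models without a conversion mapping are returned unchanged, so the call is a no-op for older transformers versions and for architectures that never needed renaming.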
16 changes: 7 additions & 9 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -134,11 +134,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
trust_remote_code = kwargs.get("trust_remote_code", False)
load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format

limit_mm_per_prompt = None
if config.get("limit_images", None): # support for multi-image data
limit_mm_per_prompt = {"image": config.get("limit_images")}

lora_kwargs = kwargs.pop('lora_kwargs', {})
lora_kwargs = kwargs.pop("lora_kwargs", {})
self.lora_kwargs = lora_kwargs
# copy it to avoid secretly modifying the engine config
engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))
@@ -147,6 +143,9 @@
# (which can vary across different vLLM versions);
# - Otherwise it's the desired value we want to explicitly set.
engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
if config.get("limit_images", None): # support for multi-image data
engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}

self.inference_engine = LLM(
model=model_path,
enable_sleep_mode=True,
@@ -157,7 +156,6 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
gpu_memory_utilization=config.gpu_memory_utilization,
disable_custom_all_reduce=True,
disable_mm_preprocessor_cache=True,
limit_mm_per_prompt=limit_mm_per_prompt,
skip_tokenizer_init=False,
max_model_len=max_model_len,
load_format=load_format,
@@ -280,8 +278,8 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
if self.lora_kwargs:
lora_int_ids = list(self.inference_engine.llm_engine.list_loras())
if len(lora_int_ids) > 0:
lora_int_id=lora_int_ids[0]
lora_requests = [LoRARequest(lora_name=f"{lora_int_id}",lora_int_id=lora_int_id,lora_path="/simon-stub-path")] * batch_size
lora_int_id = lora_int_ids[0]
lora_requests = [LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/simon-stub-path")] * batch_size

# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
@@ -342,7 +340,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
"prompts": idx,
"responses": response,
"input_ids": seq, # here input_ids become the whole sentences
'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor
"rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor
"attention_mask": attention_mask,
"position_ids": position_ids,
},
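With this change the multi-image limit is routed through `engine_kwargs` instead of being passed to `LLM(...)` as a possibly-`None` keyword. A rough sketch of the resulting behavior, with hypothetical config values (`max_num_seqs` and `enable_chunked_prefill` stand in for whatever the user puts under `engine_kwargs.vllm`):

```python
# Hypothetical contents of config.engine_kwargs.vllm.
engine_kwargs = {"max_num_seqs": None, "enable_chunked_prefill": True}

# None means "keep vLLM's own default", so those keys are dropped entirely.
engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}

# The image limit is only added when config.limit_images is actually set.
limit_images = 4  # e.g. config.get("limit_images", None)
if limit_images:
    engine_kwargs["limit_mm_per_prompt"] = {"image": limit_images}

# engine_kwargs == {"enable_chunked_prefill": True, "limit_mm_per_prompt": {"image": 4}}
```

As a result, `limit_mm_per_prompt` only reaches the vLLM constructor when it carries a real value.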
2 changes: 2 additions & 0 deletions verl/workers/sharding_manager/fsdp_sglang.py
@@ -33,6 +33,7 @@
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.debug.performance import _timer
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.model import convert_weight_keys
from verl.utils.torch_functional import check_device_is_available

from .base import BaseShardingManager
@@ -102,6 +103,7 @@ def __enter__(self):
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy
params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
# Copy, not share memory
loop = asyncio.get_event_loop()
loop.run_until_complete(self.update_weights(params))
2 changes: 2 additions & 0 deletions verl/workers/sharding_manager/fsdp_vllm.py
@@ -39,6 +39,7 @@
from verl.utils.debug.performance import _timer
from verl.utils.device import get_torch_device
from verl.utils.fsdp_utils import fsdp_version, layered_summon_lora_params, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.model import convert_weight_keys
from verl.utils.torch_functional import check_device_is_available
from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader

@@ -167,6 +168,7 @@ def __collect_lora_params() -> OrderedDict:
params = __collect_lora_params()
else:
params = self.module.state_dict()
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)

# Copy, not share memory
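Both sharding managers pass the unwrapped module to `convert_weight_keys`, since under FSDP1 the HF model (and therefore `_checkpoint_conversion_mapping`) sits behind `_fsdp_wrapped_module`, while FSDP2 and plain modules expose it directly. A condensed, illustrative version of the shared pattern (the helper name is made up for the sketch):

```python
from verl.utils.model import convert_weight_keys


def gather_params_for_rollout(module):
    # Collect the full state dict from the (possibly FSDP-wrapped) actor module.
    params = module.state_dict()
    # FSDP1 hides the HF model behind _fsdp_wrapped_module; fall back to the
    # module itself for FSDP2 or unwrapped models.
    hf_model = getattr(module, "_fsdp_wrapped_module", module)
    return convert_weight_keys(params, hf_model)
```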