From 805f10edefe888086c23319b92d3573539eeeb51 Mon Sep 17 00:00:00 2001
From: "hzji210@gmail.com"
Date: Sat, 29 Nov 2025 14:36:22 +0800
Subject: [PATCH 1/2] clean up unused sharding managers

---
 verl/trainer/main_ppo.py                      |   2 +-
 verl/workers/sharding_manager/fsdp_vllm.py    | 355 ------------------
 .../sharding_manager/megatron_sglang.py       | 227 -----------
 .../workers/sharding_manager/megatron_vllm.py | 227 -----------
 4 files changed, 1 insertion(+), 810 deletions(-)
 delete mode 100644 verl/workers/sharding_manager/fsdp_vllm.py
 delete mode 100644 verl/workers/sharding_manager/megatron_sglang.py
 delete mode 100644 verl/workers/sharding_manager/megatron_vllm.py

diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index d0474d9ab01..ff424abc158 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -248,7 +248,7 @@ def add_ref_policy_worker(self, config, ref_policy_cls):
         from verl.trainer.ppo.ray_trainer import Role
 
         # Ref policy has been fused into ActorRolloutRefWorker in new model engine,
-        # we don't need to add a separate ref policy worker goup.
+        # we don't need to add a separate ref policy worker group.
         use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
         if use_legacy_worker_impl == "disable":
             return
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
deleted file mode 100644
index fb6d7469512..00000000000
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ /dev/null
@@ -1,355 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import logging
-import os
-import time
-from collections import OrderedDict
-
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
-from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-
-try:
-    # for torch 2.5+
-    from torch.distributed.tensor import DTensor
-except ImportError:
-    from torch.distributed._tensor import DTensor
-
-from dataclasses import asdict
-
-from verl import DataProto
-from verl.protocol import all_gather_data_proto
-from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL
-from verl.third_party.vllm import parallel_state as vllm_ps
-from verl.utils.device import get_device_id, get_device_name, get_torch_device, set_expandable_segments
-from verl.utils.fsdp_utils import (
-    fsdp_version,
-    layered_summon_lora_params,
-    load_fsdp_model_to_gpu,
-    offload_fsdp_model_to_cpu,
-)
-from verl.utils.import_utils import deprecated
-from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys
-from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
-from verl.utils.torch_functional import check_device_is_available
-from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge
-
-from .base import BaseShardingManager
-
-logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-
-
-@deprecated()
-class FSDPVLLMShardingManager(BaseShardingManager):
-    """Sharding manager for FSDP models with vLLM inference engine integration.
-
-    Manages parameter synchronization between FSDP training models and vLLM
-    inference engines, handling both full parameters and LoRA adapters with
-    efficient memory management and device placement.
-    """
-
-    @check_device_is_available()
-    def __init__(
-        self,
-        module: FSDP,
-        inference_engine: LLM,
-        model_config,
-        rollout_config,
-        full_params: bool = False,
-        device_mesh: DeviceMesh = None,
-        offload_param: bool = False,
-        load_format: str = "dummy_hf",
-        layered_summon: bool = True,
-    ):
-        self.module = module
-        # For AsyncLLM, inference_engine and model_runner are lazily initialized in vLLMAsyncRollout.load_model
-        self.inference_engine = inference_engine
-        # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if
-        # inference_engine else None
-
-        self.model_runner = (
-            self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner
-            if self.inference_engine
-            else None
-        )
-
-        self.model_config = model_config
-        self.rollout_config = rollout_config
-        self.device_mesh = device_mesh
-        self.offload_param = offload_param
-        self.load_format = load_format
-        self.layered_summon = layered_summon
-
-        # Full params
-        self.full_params = full_params
-        if full_params and fsdp_version(self.module) == 1:
-            FSDP.set_state_dict_type(
-                self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()
-            )
-        elif fsdp_version(self.module) == 1:
-            FSDP.set_state_dict_type(
-                self.module,
-                state_dict_type=StateDictType.SHARDED_STATE_DICT,
-                state_dict_config=ShardedStateDictConfig(),
-            )
-
-        self.tp_size = self.device_mesh["infer_tp"].size()
-        self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
-
-        # Note that torch_random_states may be different on each dp rank
-        self.torch_random_states = get_torch_device().get_rng_state()
-        # get a random rng state
-        if self.device_mesh is not None:
-            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
-            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
-        else:
-            self.gen_random_states = None
-
-        self.base_sync_done: bool = "dummy" not in load_format
-        if is_version_ge(pkg="vllm", minver="0.7.3"):
-            VLLMHijack.hijack()
-
-    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
-    def __enter__(self):
-        def __collect_lora_params() -> OrderedDict:
-            """
-            Collect LoRA params, or full params if the base model is not yet loaded in vllm.
-            Works when self.module._fsdp_wrapped_module is a PeftModel.
-            """
-            from peft.utils.save_and_load import get_peft_model_state_dict
-
-            lora_params = OrderedDict()
-            peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
-            if fsdp_version(self.module) > 0:
-                if self.layered_summon:
-                    if not self.base_sync_done:
-                        raise ValueError(
-                            "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
-                            "rollout.load_format=safetensors"
-                        )
-                    lora_params = layered_summon_lora_params(self.module)
-                else:
-                    with FSDP.summon_full_params(self.module, writeback=False):
-                        if self.base_sync_done:
-                            lora_params = get_peft_model_state_dict(peft_model)
-                            lora_params = {
-                                name: param.full_tensor().detach().cpu()
-                                if hasattr(param, "full_tensor")
-                                else param.detach().cpu()
-                                for name, param in lora_params.items()
-                            }
-                        else:
-                            model = peft_model.base_model.model
-                            orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
-                            model = model.to("cpu")
-                            for name, param in model.state_dict().items():
-                                if any(x in name for x in ["_flat_param", "lora_"]):
-                                    continue
-                                name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
-                                lora_params[name] = (
-                                    param.full_tensor().detach().cpu()
-                                    if hasattr(param, "full_tensor")
-                                    else param.detach().cpu()
-                                )
-                            model = model.to(orig_dev)
-                    get_torch_device().empty_cache()
-            else:
-                if self.base_sync_done:
-                    lora_params = get_peft_model_state_dict(peft_model)
-                else:
-                    model = peft_model.base_model.model
-                    orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
-                    model = model.to("cpu")
-                    for name, param in model.state_dict().items():
-                        if any(x in name for x in ["_flat_param", "lora_"]):
-                            continue
-                        name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
-                        lora_params[name] = param.detach().cpu()
-                    model = model.to(orig_dev)
-            return lora_params
-
-        # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and
-        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
-        # Outside of vllm's scope, we should avoid emptying the cache to let pytorch use its caching memory
-        # to speed up memory allocations.
-        #
-        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
-        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
-        self.timing = {}
-        with simple_timer("reshard", self.timing):
-            get_torch_device().empty_cache()
-
-            log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
-            if self.offload_param:
-                load_fsdp_model_to_gpu(self.module)
-
-            peft_config = None
-            peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
-            if hasattr(peft_model, "peft_config"):
-                peft_config = peft_model.peft_config.get("default", None)
-                params = __collect_lora_params()
-            else:
-                params = self.module.state_dict()
-                params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
-
-            if self.offload_param:
-                offload_fsdp_model_to_cpu(self.module)
-            log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
-
-            # vllm needs to set _set_allocator_settings to False
-            logger.debug("fsdp vllm sharding_manager _set_allocator_settings to False")
-            set_expandable_segments(False)
-
-            if self.rollout_config.free_cache_engine:
-                if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
-                    self.inference_engine.wake_up(tags=["weights"])
-                else:
-                    self.inference_engine.wake_up()
-
-            # update model params
-            self.update_params(params, peft_config=peft_config)
-            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
-            del params
-            get_torch_device().empty_cache()
-
-            if (
-                self.rollout_config.free_cache_engine
-                and "tags" in inspect.signature(self.inference_engine.wake_up).parameters
-            ):
-                self.inference_engine.wake_up(tags=["kv_cache"])
-
-            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
-
-            # important: need to manually set the random states of each tp to be identical.
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
-
-    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
-    def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
-
-        self.module.train()
-
-        # add empty cache after each compute
-        get_torch_device().empty_cache()
-
-        # _set_allocator_settings to True is required by fsdp2 to avoid oom
-        logger.debug("fsdp vllm sharding_manager _set_allocator_settings to True")
-        set_expandable_segments(True)
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
-
-    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
-    def preprocess_data(self, data: DataProto) -> DataProto:
-        """All gather across tp group so that each rank has identical input."""
-        if self.tp_size == 1:
-            return data
-
-        # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        group = vllm_ps.get_tensor_model_parallel_group().device_group
-
-        all_gather_data_proto(data=data, process_group=group)
-        return data
-
-    @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger)
-    def postprocess_data(self, data: DataProto) -> DataProto:
-        """Get chunk data of this tp rank since we do all gather in preprocess."""
-        if self.tp_size == 1:
-            return data
-
-        return data.chunk(chunks=self.tp_size)[self.tp_rank]
-
-    def update_params(self, updated_params, peft_config=None):
-        """Update model parameters in the vLLM inference engine.
-
-        Synchronizes parameters from the FSDP training model to the vLLM inference
-        engine, handling both full model parameters and LoRA adapters with proper
-        device placement and memory management.
-
-        Args:
-            updated_params (dict): Dictionary of parameter names to tensor values.
-            peft_config (optional): PEFT configuration for LoRA adapters.
-        """
-        model = self.model_runner.model
-        if peft_config:
-            if self.base_sync_done:
-                lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
-                lora_request = TensorLoRARequest(
-                    lora_name=f"{lora_int_id}",
-                    lora_int_id=lora_int_id,
-                    lora_path="simon_lora_path",
-                    peft_config=asdict(peft_config),
-                    lora_tensors=updated_params,
-                )
-                self.inference_engine.llm_engine.add_lora(lora_request)
-                logger.info(f"vLLM load weights, loaded_params: {len(updated_params)}")
-                return
-            else:
-
-                def replace_lora_wrapper(k):
-                    """Replace LoRA parameter keys with base layer equivalents.
-
-                    Transforms LoRA parameter names to their corresponding base layer
-                    names for proper weight loading in vLLM when base model sync is not done.
-
-                    Args:
-                        k (str): Original parameter key name.
-
-                    Returns:
-                        str: Transformed parameter key for base layer.
-                    """
-                    stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
-                    if k.endswith(".weight"):
-                        module_k = k[: -len(".weight")]
-                        if check_exclude_modules(peft_config, module_k):
-                            return k
-                        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(
-                            peft_config, module_k
-                        ):
-                            return f"{module_k}.base_layer.weight"
-                    if k.endswith(".bias"):
-                        module_k = k[: -len(".bias")]
-                        if check_exclude_modules(peft_config, module_k):
-                            return k
-                        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(
-                            peft_config, module_k
-                        ):
-                            return f"{module_k}.base_layer.bias"
-                    return k
-
-                updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}
-
-        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
-
-        patch_vllm_moe_model_weight_loader(model)
-        device = get_device_id()  # used when fsdp2 sets cpu_offload_policy
-        loaded_params = model.load_weights(
-            (
-                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
-                for name, param in updated_params.items()
-            )
-        )
-
-        self.base_sync_done = True
-        logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}")
diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py
deleted file mode 100644
index b945b858c71..00000000000
--- a/verl/workers/sharding_manager/megatron_sglang.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-# Copyright 2023-2024 SGLang Team
-# Copyright 2025 ModelBest Inc. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This file contains a Megatron-style Hybrid Engine that shares the weights of the actor with the inference engine.
-""" - -import asyncio -import logging -import os - -from omegaconf import DictConfig -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights -from torch import nn -from torch.distributed.device_mesh import DeviceMesh - -from verl.protocol import DataProto, all_gather_data_proto -from verl.utils.device import get_torch_device, set_expandable_segments -from verl.utils.import_utils import deprecated -from verl.utils.megatron_utils import ( - load_megatron_model_to_gpu, - offload_megatron_model_to_cpu, - per_tensor_generator, -) -from verl.utils.memory_utils import aggressive_empty_cache -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer -from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets - -from .base import BaseShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -""" -Megatron Hybrid Engine: -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all - the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - - -@deprecated() -class MegatronSGLangShardingManager(BaseShardingManager): - """A sharding manager for Megatron-style training & inference with SGLang. - - This class manages the sharding of model parameters between training and inference - phases in a Megatron-style parallel setup. It handles: - - Loading/offloading parameters between CPU/GPU - - Updating inference engine weights - - Managing random states for reproducibility - - Data preprocessing for distributed inference - - Args: - actor_module (nn.ModuleList): The actor model modules - inference_engine (Engine): The SGLang inference engine - model_config: Configuration for the actor's model - rollout_config: Configuration for rollout generation - transformer_config: Transformer-specific configuration - layer_name_mapping: Mapping between layer names and parameters - weight_converter: Utility for converting weights between formats - device_mesh (DeviceMesh | None): PyTorch device mesh for distributed training - offload_param (bool): Whether to offload parameters to CPU when not in use - """ - - def __init__( - self, - actor_module: nn.ModuleList, - inference_engine: Engine, - model_config: DictConfig, - rollout_config: DictConfig, - transformer_config, - layer_name_mapping, - weight_converter, - device_mesh: DeviceMesh | None = None, - offload_param: bool = False, - bridge=None, - ): - self.actor_module = actor_module - self.inference_engine = inference_engine - self.model_config = model_config - self.rollout_config = rollout_config - self.transformer_config = transformer_config - self.layer_name_mapping = layer_name_mapping - self.weight_converter = weight_converter - self.device_mesh = device_mesh - self.bridge = bridge - self.offload_param = offload_param - - if self.device_mesh is not None: - self.infer_tp_size = self.device_mesh["infer_tp"].mesh.size()[0] - else: - self.infer_tp_size = self.inference_engine._tp_size - - # Note that torch_random_states may be different on each dp rank - self.torch_random_states = get_torch_device().get_rng_state() - # get a random rng states - if self.device_mesh is not None: - gen_dp_rank = 
self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) - def __enter__(self): - self.timing = {} - with simple_timer("reshard", self.timing): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wake_up()) - - @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) - def __exit__(self, exc_type, exc_value, traceback): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.sleep()) - - async def update_weights(self, params): - """ - Update model weights using tensor buckets, similar to THUDM/slime's implementation. - - Notes: - - For the best performance of `rebuild_cuda_tensor`, it is recommended to: - 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`. - 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` - when using Tensor Parallelism (TP >= 8). - - See reference implementations in SLIME: - - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452 - - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 - """ - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.resume_memory_occupation() - named_tensors = params - - update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 - for params_batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): - await sgl_update_weights( - engine=self.inference_engine, - params_batch=params_batch, - device_mesh_key="infer_tp", - device_mesh=self.device_mesh, - ) - - if self.device_mesh["infer_tp"].get_local_rank() == 0: - await self.inference_engine.flush_cache() - - async def release_memory(self): - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.release_memory_occupation() - - @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) - async def wake_up(self): - aggressive_empty_cache(force_sync=True) - - if self.offload_param: - load_megatron_model_to_gpu(self.actor_module, load_grad=False) - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) - else: - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - - set_expandable_segments(False) - - await self.update_weights(per_tensor_param) - if self.offload_param: - offload_megatron_model_to_cpu(self.actor_module) - aggressive_empty_cache(force_sync=True) - # important: need to manually set the random states of each tp to be identical. 
- if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) - async def sleep(self): - if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) - - for model in self.actor_module: - model.train() - # add empty cache after each compute - aggressive_empty_cache(force_sync=True) - - set_expandable_segments(True) - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def preprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - all_gather_data_proto(data, self.device_mesh["infer_tp"].get_group()) - return data - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def postprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["infer_tp"].get_local_rank()] diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py deleted file mode 100644 index 6adc89c0985..00000000000 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. 
-""" - -import inspect -import logging -import os - -import torch -import torch.distributed -from megatron.core import parallel_state as mpu -from omegaconf import DictConfig -from torch import nn - -from verl import DataProto -from verl.models.mcore.weight_converter import McoreToHFWeightConverterBase -from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL -from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.device import get_torch_device, set_expandable_segments -from verl.utils.import_utils import deprecated -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator -from verl.utils.memory_utils import aggressive_empty_cache -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.profiler.performance import simple_timer -from verl.utils.torch_functional import check_device_is_available - -from .base import BaseShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -""" -Megatron Hybrid Engine: -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank - to all other pp ranks (all pp ranks holds all the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - - -@deprecated() -class MegatronVLLMShardingManager(BaseShardingManager): - """A sharding manager that bridges Megatron-LM training with vLLM inference. - - This class handles the parameter sharding and communication between: - - Megatron-LM's tensor/expert parallel training setup - - vLLM's tensor parallel inference setup - - Key responsibilities: - - Manages parameter broadcasting between training and inference configurations - - Handles weight conversion between Megatron and HuggingFace formats - - Coordinates memory management between training and inference phases - - Maintains random state consistency across different parallel groups - - Args: - actor_module (nn.ModuleList): The Megatron-LM model being trained - inference_engine (LLM): The vLLM inference engine - model_config: Configuration for the actor's model - transformer_config: Transformer-specific configuration for the model - rollout_config: Configuration for rollout - layer_name_mapping: Mapping between Megatron and HF layer names - weight_converter (McoreToHFWeightConverterBase): Converts weights between formats - device_mesh: Device mesh for parallel operations - offload_param (bool): Whether to offload parameters when not in use - """ - - @check_device_is_available() - def __init__( - self, - actor_module: nn.ModuleList, - inference_engine: LLM, - model_config: DictConfig, - transformer_config, - rollout_config: DictConfig, - layer_name_mapping, - weight_converter: McoreToHFWeightConverterBase, - device_mesh, - offload_param: bool = True, - bridge=None, - ): - self.actor_module = actor_module - self.inference_engine = inference_engine - self.offload_param = offload_param - - # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model - self.model_runner = ( - self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner - if self.inference_engine - else None - ) - - self.model_config = model_config - self.transformer_config = transformer_config - 
self.rollout_config = rollout_config - self.layer_name_mapping = layer_name_mapping - self.weight_converter = weight_converter - self.bridge = bridge - # initialize groups for vllm inference - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - - self.device_mesh = device_mesh - self.infer_tp_size = self.device_mesh["infer_tp"].size() - self.infer_tp_rank = self.device_mesh["infer_tp"].get_local_rank() - - self.train_tp_size = mpu.get_tensor_model_parallel_world_size() - self.train_tp_rank = mpu.get_tensor_model_parallel_rank() - self.train_tp_group = mpu.get_tensor_model_parallel_group() - self.train_ep_size = mpu.get_expert_model_parallel_world_size() - self.train_ep_rank = mpu.get_expert_model_parallel_rank() - self.train_ep_group = mpu.get_expert_model_parallel_group() - self.train_etp_size = mpu.get_expert_tensor_parallel_world_size() - self.train_etp_rank = mpu.get_expert_tensor_parallel_rank() - self.train_etp_group = mpu.get_expert_tensor_parallel_group() - self.need_tp_reshard = self.train_tp_size != self.infer_tp_size - self.train_tp_larger = self.train_tp_size > self.infer_tp_size - - self.torch_random_states = get_torch_device().get_rng_state() - if self.device_mesh is not None: - gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) - def __enter__(self): - self.timing = {} - with simple_timer("reshard", self.timing): - aggressive_empty_cache(force_sync=True) - - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_megatron_model_to_gpu(self.actor_module, load_grad=False) - - set_expandable_segments(False) - - if self.rollout_config.free_cache_engine: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - self.inference_engine.wake_up(tags=["weights"]) - else: - self.inference_engine.wake_up() - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) - else: - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - model = self.model_runner.model - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - patch_vllm_moe_model_weight_loader(model) - loaded_params = model.load_weights(per_tensor_param) - info = f"vLLM load weights, loaded_params: {len(loaded_params)}" - logger.info(info) - - if self.offload_param: - offload_megatron_model_to_cpu(self.actor_module) - aggressive_empty_cache(force_sync=True) - - if ( - self.rollout_config.free_cache_engine - and "tags" in inspect.signature(self.inference_engine.wake_up).parameters - ): - self.inference_engine.wake_up(tags=["kv_cache"]) - - # important: need to manually set the random states of each tp to be identical. 
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
-
-    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
-    def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
-        for model in self.actor_module:
-            model.train()
-
-        aggressive_empty_cache(force_sync=True)
-
-        set_expandable_segments(True)
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
-
-    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
-    def preprocess_data(self, data: DataProto) -> DataProto:
-        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
-        if self.infer_tp_size == 1:
-            return data
-
-        # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        group = vllm_ps.get_tensor_model_parallel_group().device_group
-
-        all_gather_data_proto(data=data, process_group=group)
-        return data
-
-    @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)
-    def postprocess_data(self, data: DataProto) -> DataProto:
-        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
-        if self.infer_tp_size == 1:
-            return data
-        return data.chunk(chunks=self.infer_tp_size)[self.infer_tp_rank]

From ca8abbbd7ea87f934d65e5bbff2adc5cc4d82682 Mon Sep 17 00:00:00 2001
From: "hzji210@gmail.com"
Date: Mon, 1 Dec 2025 10:36:40 +0800
Subject: [PATCH 2/2] remove stale sharding manager references from docs

---
 docs/workers/fsdp_workers.rst     | 8 ++------
 docs/workers/megatron_workers.rst | 9 +--------
 2 files changed, 3 insertions(+), 14 deletions(-)

diff --git a/docs/workers/fsdp_workers.rst b/docs/workers/fsdp_workers.rst
index b158fb265df..03bde11376c 100644
--- a/docs/workers/fsdp_workers.rst
+++ b/docs/workers/fsdp_workers.rst
@@ -1,12 +1,10 @@
 PyTorch FSDP Backend
 ======================
 
-Last updated: 02/12/2025.
+Last updated: 12/01/2025.
 
 We support PyTorch FSDP Backend by implementing various workers for
-actor, critic, reference, rollout and reward models. We also implement
-the ``FSDPVLLMShardingManager`` that reshard weight between FSDP and
-vLLM in `fsdp_vllm.py `_.
+actor, critic, reference, rollout and reward models.
 
 **Pros**
 
@@ -58,8 +56,6 @@ highlighted below:
 2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM
    Engine and make it executed under SPMD to fit into our
    ``WorkerGroup`` design.
-3. ``FSDPVLLMShardingManager`` a context manager to perform actual
-   resharding between actor and rollout. See `source code `_.
 for more information.
diff --git a/docs/workers/megatron_workers.rst b/docs/workers/megatron_workers.rst
index bd02836a624..91452c7dc51 100644
--- a/docs/workers/megatron_workers.rst
+++ b/docs/workers/megatron_workers.rst
@@ -1,7 +1,7 @@
 Megatron-LM Backend
 ===================
 
-Last updated: 06/24/2025.
+Last updated: 12/01/2025.
 
 We support Megatron Backend by implementing various workers for actor,
 critic, reference, rollout and reward models. We also implement the
@@ -121,8 +121,6 @@ highlighted below:
 2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM
    Engine and make it executed under SPMD to fit into our
    ``WorkerGroup`` design.
-3. ``MegatronVLLMShardingManager`` a context manager to perform actual
-   resharding between actor and rollout. See `source code `_
 for more information.
@@ -143,11 +141,6 @@ See `source code