From f7691722aae1941467e8a239e0a51cc34566dedc Mon Sep 17 00:00:00 2001
From: Chen Haiquan
Date: Fri, 4 Jul 2025 18:20:25 +0800
Subject: [PATCH 1/3] fix sglang async with Multi-stage Awake

---
 .github/workflows/e2e_ppo_trainer.yml        |  4 ++--
 verl/workers/sharding_manager/fsdp_sglang.py | 13 +++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml
index 70dba08b964..be0e800f58a 100644
--- a/.github/workflows/e2e_ppo_trainer.yml
+++ b/.github/workflows/e2e_ppo_trainer.yml
@@ -287,13 +287,13 @@ jobs:
       - name: Running GSM8K E2E training tests on sglang async
         run: |
           ray stop --force
-          ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+          TOTAL_TRAIN_STEPS=2 ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
       - name: Running GSM8K E2E training tests on vllm async
         run: |
           ray stop --force
           export VLLM_USE_V1=1
           ray start --head
-          ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+          TOTAL_TRAIN_STEPS=2 ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
 
   e2e_ppo_trainer_sglang_multiturn_with_tool:
     runs-on: [L20x8]
diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index 02cec3595a9..3d5b9b32a1d 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -207,6 +207,15 @@ async def wake_up(self):
         params = {
             k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
         }
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         # Copy, not share memory
         await self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
@@ -217,6 +226,10 @@ async def wake_up(self):
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
+        if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
+            await self.inference_engine.resume_memory_occupation(tags=["kv_cache"])
+            log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
+
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = get_torch_device().get_rng_state()
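
Note on the fsdp_sglang.py hunks above: the async rollout path drives the sharding manager's wake_up() coroutine directly, so it now performs the SGLang resume_memory_occupation calls itself (previously only the synchronous __enter__ resumed memory). With multi_stage_wake_up enabled, only the weight buffers are resumed before the weight sync, and the KV cache is resumed only after the training-side copy has been freed. A minimal, self-contained sketch of that ordering (FakeEngine and the helper below are illustrative stand-ins, not verl code):

    import asyncio

    class FakeEngine:
        # Stand-in for the SGLang engine held by the sharding manager.
        async def resume_memory_occupation(self, tags=None):
            print(f"resume_memory_occupation(tags={tags})")

        async def update_weights(self, named_tensors):
            print(f"update {len(named_tensors)} weight tensor(s)")

    async def wake_up(engine, params, multi_stage_wake_up=True):
        if multi_stage_wake_up:
            await engine.resume_memory_occupation(tags=["weights"])   # stage 1: weights only
        else:
            await engine.resume_memory_occupation()                   # weights + kv_cache together
        await engine.update_weights(list(params.items()))             # copy trainer weights into the engine
        params.clear()                                                 # drop the trainer-side copy first
        if multi_stage_wake_up:
            await engine.resume_memory_occupation(tags=["kv_cache"])  # stage 2: kv_cache only after params are freed

    asyncio.run(wake_up(FakeEngine(), {"w": 0}))
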

From a49c30075ec3662235774ece06231fcb1f40892c Mon Sep 17 00:00:00 2001
From: Chen Haiquan
Date: Mon, 7 Jul 2025 10:28:43 +0800
Subject: [PATCH 2/3] resolve duplicated code in __enter__ and __exit__

---
 verl/workers/sharding_manager/fsdp_sglang.py | 59 +-------------------
 1 file changed, 3 insertions(+), 56 deletions(-)

diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index 3d5b9b32a1d..fb0a5d77d69 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -33,7 +33,6 @@
 from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 from verl.utils.device import get_device_id, get_torch_device
 from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
-from verl.utils.model import convert_weight_keys
 from verl.utils.torch_functional import check_device_is_available
 
 from .base import BaseShardingManager
@@ -101,65 +100,13 @@ def __init__(
     @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
     def __enter__(self):
         self.timing = {}
         with simple_timer("reshard", self.timing):
-            get_torch_device().empty_cache()
-            loop = asyncio.get_event_loop()
-
-            if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-                if self.multi_stage_wake_up:
-                    loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
-                    log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-                else:
-                    loop.run_until_complete(self.inference_engine.resume_memory_occupation())
-                    log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
-                get_torch_device().empty_cache()
-
-            log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
-            if self.offload_param:
-                load_fsdp_model_to_gpu(self.module)
-            params = self.module.state_dict()
-            log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
-            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
-            params = {
-                k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
-            }
-            params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
-            # Copy, not share memory
-            loop.run_until_complete(self.update_weights(params))
-            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
-
-            del params
-            if self.offload_param:
-                offload_fsdp_model_to_cpu(self.module)
-            get_torch_device().empty_cache()
-            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
-
-            if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
-                loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
-                log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
-
-            # important: need to manually set the random states of each tp to be identical.
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
+            loop.run_until_complete(self.wake_up())
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(self.release_memory())
-        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
-
-        self.module.train()
-
-        # add empty cache after each compute
-        get_torch_device().empty_cache()
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.sleep())
 
     async def update_weights(self, params):
         # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
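
Note on the hunk above: __enter__ and __exit__ now simply drive the async wake_up()/sleep() coroutines on the current event loop, so the synchronous context-manager path and the async rollout path share one implementation. A minimal, self-contained sketch of that delegation pattern (ShardingManagerSketch is a hypothetical stand-in; the real wake_up()/sleep() also handle SGLang memory, the weight sync, and RNG state):

    import asyncio

    class ShardingManagerSketch:
        # Hypothetical stand-in for FSDPSGLangShardingManager.
        async def wake_up(self):
            print("resume SGLang memory and sync weights")

        async def sleep(self):
            print("release SGLang memory and restore RNG state")

        def __enter__(self):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.wake_up())
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.sleep())

    asyncio.set_event_loop(asyncio.new_event_loop())
    with ShardingManagerSketch():
        print("generate rollout sequences here")
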

From b4e7ddd15bf4cb54ddda9a0620c58431e855c1d1 Mon Sep 17 00:00:00 2001
From: Chen Haiquan
Date: Mon, 7 Jul 2025 11:17:53 +0800
Subject: [PATCH 3/3] fix megatron with sglang async

---
 verl/workers/megatron_workers.py         |  7 ++-
 .../sharding_manager/megatron_sglang.py  | 47 ++++---------------
 2 files changed, 16 insertions(+), 38 deletions(-)

diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 277ebb9733f..bab7991e117 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -19,7 +19,7 @@
 import logging
 import os
 import time
-from typing import Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import psutil
 import torch
@@ -692,6 +692,11 @@ async def chat_completion(self, json_request):
         ret = await self.rollout.chat_completion(json_request)
         return ret
 
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
+    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
+        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)
+        return ret
+
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     async def wake_up(self):
         if self.config.rollout.free_cache_engine:
diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py
index 61cf8aed4da..d89363afdf9 100644
--- a/verl/workers/sharding_manager/megatron_sglang.py
+++ b/verl/workers/sharding_manager/megatron_sglang.py
@@ -114,45 +114,13 @@ def __init__(
     def __enter__(self):
         self.timing = {}
         with simple_timer("reshard", self.timing):
-            if self.offload_param:
-                load_megatron_model_to_gpu(self.actor_module)
-            if self.bridge is not None:
-                per_tensor_param = self.bridge.export_weights(self.actor_module)
-            else:
-                per_tensor_param = per_tensor_generator(
-                    self.actor_module,
-                    self.model_config,
-                    self.weight_converter,
-                    self.transformer_config,
-                    self.layer_name_mapping,
-                )
             loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.update_weights(per_tensor_param))
-            if self.offload_param:
-                offload_megatron_model_to_cpu(self.actor_module)
-            get_torch_device().empty_cache()
-            # important: need to manually set the random states of each tp to be identical.
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
+            loop.run_until_complete(self.wake_up())
 
     @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(self.release_memory())
-        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
-
-        for model in self.actor_module:
-            model.train()
-        # add empty cache after each compute
-        get_torch_device().empty_cache()
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.sleep())
 
     async def update_weights(self, params):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
@@ -182,8 +150,10 @@ async def release_memory(self):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.release_memory_occupation()
 
-    @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
     async def wake_up(self):
+        if self.offload_param:
+            load_megatron_model_to_gpu(self.actor_module)
         if self.bridge is not None:
             per_tensor_param = self.bridge.export_weights(self.actor_module)
         else:
             per_tensor_param = per_tensor_generator(
@@ -195,12 +165,15 @@ async def wake_up(self):
                 self.layer_name_mapping,
             )
         await self.update_weights(per_tensor_param)
+        if self.offload_param:
+            offload_megatron_model_to_cpu(self.actor_module)
+        get_torch_device().empty_cache()
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = get_torch_device().get_rng_state()
             get_torch_device().set_rng_state(self.gen_random_states)
 
-    @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
     async def sleep(self):
         if self.rollout_config.free_cache_engine:
             log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
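
Note on the hunks above: megatron_sglang.py now mirrors the FSDP manager (offload/onload handled inside wake_up(), with __enter__/__exit__ just driving wake_up()/sleep()), and megatron_workers.py gains a token-in/token-out generate RPC for the async rollout server. A small, self-contained sketch of that calling convention (the caller and the sampling-parameter keys are hypothetical; only the generate signature comes from the diff above):

    import asyncio
    from typing import Any, Dict, List

    async def generate(prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
        # Stand-in for the worker method above, which forwards to self.rollout.generate(...).
        return prompt_ids + [7, 8, 9]  # dummy response token ids

    async def main():
        response_ids = await generate([101, 2287, 345], {"temperature": 0.7, "top_p": 0.9}, "req-0")
        print(response_ids)

    asyncio.run(main())
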