diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index eec297d2862..be74bbd41b8 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -32,6 +32,7 @@
 from verl.protocol import all_gather_data_proto
 from verl.utils.device import get_device_id, get_torch_device
 from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
+from verl.utils.model import convert_weight_keys
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 from verl.utils.torch_functional import check_device_is_available
 
@@ -145,6 +146,15 @@ async def release_memory(self):
     @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
     async def wake_up(self):
         get_torch_device().empty_cache()
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
         if self.offload_param:
             load_fsdp_model_to_gpu(self.module)
@@ -155,13 +165,8 @@ async def wake_up(self):
             k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
         }
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-            if self.multi_stage_wake_up:
-                await self.inference_engine.resume_memory_occupation(tags=["weights"])
-                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-            else:
-                await self.inference_engine.resume_memory_occupation()
-                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+        # convert weight keys to match the model config
+        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
 
         # Copy, not share memory
         await self.update_weights(params)
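
Note on the new helper: the convert_weight_keys call (imported from verl.utils.model) renames the gathered FSDP state_dict keys to match the model config before update_weights runs, using the unwrapped module (getattr(self.module, "_fsdp_wrapped_module", self.module)) as the reference. The sketch below only illustrates the general idea of such key remapping; strip_fsdp_prefix is a hypothetical stand-in for illustration, not verl's actual helper, which may perform model-specific renames instead.

# Minimal sketch of state_dict key remapping (hypothetical helper, not verl's
# convert_weight_keys; the real function lives in verl/utils/model.py).
from typing import Dict

import torch


def strip_fsdp_prefix(params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Drop the FSDP wrapper prefix so keys match the underlying HF module."""
    prefix = "_fsdp_wrapped_module."
    return {k.replace(prefix, ""): v for k, v in params.items()}


if __name__ == "__main__":
    fake_state_dict = {
        "_fsdp_wrapped_module.model.embed_tokens.weight": torch.zeros(2, 2),
        "_fsdp_wrapped_module.lm_head.weight": torch.zeros(2, 2),
    }
    # Prints ['model.embed_tokens.weight', 'lm_head.weight']
    print(list(strip_fsdp_prefix(fake_state_dict).keys()))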