19 changes: 12 additions & 7 deletions verl/workers/sharding_manager/fsdp_sglang.py
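This merged change touches only the SGLang sharding manager: wake_up() now resumes SGLang's GPU memory occupation before the FSDP state_dict is gathered instead of after it, and the gathered parameter names are normalized with the newly imported convert_weight_keys before being copied into the inference engine.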
@@ -32,6 +32,7 @@
 from verl.protocol import all_gather_data_proto
 from verl.utils.device import get_device_id, get_torch_device
 from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
+from verl.utils.model import convert_weight_keys
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 from verl.utils.torch_functional import check_device_is_available

@@ -145,6 +146,15 @@ async def release_memory(self):
     @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
     async def wake_up(self):
         get_torch_device().empty_cache()
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
         if self.offload_param:
             load_fsdp_model_to_gpu(self.module)
@@ -155,13 +165,8 @@ async def wake_up(self):
             k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
         }
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-            if self.multi_stage_wake_up:
-                await self.inference_engine.resume_memory_occupation(tags=["weights"])
-                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-            else:
-                await self.inference_engine.resume_memory_occupation()
-                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+        # convert weight keys to match the model config
+        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
 
         # Copy, not share memory
         await self.update_weights(params)
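For orientation, here is a condensed sketch of the post-patch wake_up() control flow, assembled from the hunks above. The params = self.module.state_dict() and device = get_device_id() lines are assumptions filled in from the surrounding context (the diff elides them), so treat this as a reading aid rather than the verbatim method.

    @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
    async def wake_up(self):
        get_torch_device().empty_cache()

        # (1) Re-acquire SGLang's GPU memory first, so the engine's weight
        # (and optionally kv-cache) buffers exist before the FSDP weights
        # are materialized on the GPU.
        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
            if self.multi_stage_wake_up:
                await self.inference_engine.resume_memory_occupation(tags=["weights"])
            else:
                await self.inference_engine.resume_memory_occupation()

        # (2) Gather the training weights onto the GPU.
        if self.offload_param:
            load_fsdp_model_to_gpu(self.module)
        params = self.module.state_dict()  # assumption: call elided from the diff
        device = get_device_id()           # assumption: call elided from the diff
        params = {
            k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v
            for k, v in params.items()
        }

        # (3) Rename FSDP state_dict keys to the names the rollout engine expects.
        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

        # (4) Copy (not share) the weights into SGLang.
        await self.update_weights(params)

The effect of moving step (1) ahead of step (2), presumably, is that the engine re-acquires its memory while the full FSDP weights are still offloaded or sharded, so the two allocation peaks no longer overlap.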