Skip to content

Commit 13da23d

Browse files
author
崔博
committed
fix bug
1 parent 719aa05 commit 13da23d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
ResumeMemoryOccupationReqInput,
3636
UpdateWeightsFromTensorReqInput,
3737
)
38-
from sglang.srt.openai_api.protocol import Tool
38+
39+
# from sglang.srt.openai_api.protocol import Tool
3940
from sglang.srt.sampling.sampling_params import SamplingParams
4041
from sglang.srt.server_args import ServerArgs
4142
from sglang.srt.utils import (
@@ -135,9 +136,6 @@ def __init__(self, **kwargs):
135136

136137
async def release_memory_occupation(self, tags: Optional[list[str]] = None):
137138
"""Release GPU occupation temporarily."""
138-
if self._need_reload:
139-
await self.release_memory_occupation()
140-
self._need_reload = False
141139
if tags is None:
142140
obj = ReleaseMemoryOccupationReqInput()
143141
else:
@@ -149,7 +147,9 @@ async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
149147
# because __init__ is a sync method, it cannot call the async release_memory_occupation
150148
# so the release_memory_occupation call had to be moved from __init__ to here
151149
# For multi-stage awake, we release both the weights and the kv_cache the first time weights are resumed.
152-
await self.release_memory_occupation()
150+
if self._need_reload:
151+
await self.release_memory_occupation()
152+
self._need_reload = False
153153

154154
if tags is None:
155155
obj = ResumeMemoryOccupationReqInput()

verl/workers/sharding_manager/megatron_sglang.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def offload_manager(self):
110110
if self.offload_param:
111111
offload_megatron_model_to_cpu(self.actor_module)
112112
get_torch_device().empty_cache()
113+
torch.distributed.barrier()
113114

114115
if self.multi_stage_wake_up:
115116
loop.run_until_complete(self.resume_memory(tags=["kv_cache"]))

0 commit comments

Comments
 (0)