7 changes: 6 additions & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -1378,7 +1378,12 @@ def __post_init__(self):

 @dataclass
 class ContinueGenerationReqInput(BaseReq):
-    pass
+    # If True, call torch.cuda.empty_cache() before un-pausing: blocks
+    # cached by the PyTorch allocator (left over from transient allocs
+    # during post-weight-update processing) are returned to the driver
+    # before inference resumes; the engine is still paused, so there is
+    # no race against active streams. Set to False to skip the call.
+    torch_empty_cache: bool = True
 
 
 @dataclass
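For reference, a hedged sketch of how a caller could use the new field; the construction below follows directly from the dataclass definition above, but how the request is dispatched to the scheduler is outside this diff.

    # Default: flush the allocator cache when resuming.
    req = ContinueGenerationReqInput()
    # Opt out: resume without calling torch.cuda.empty_cache().
    req_no_flush = ContinueGenerationReqInput(torch_empty_cache=False)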
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -3646,6 +3646,15 @@ def pause_generation(self, recv_req: PauseGenerationReqInput):
         self.chunked_req = None
 
     def continue_generation(self, recv_req: ContinueGenerationReqInput):
+        if recv_req.torch_empty_cache:
+            before_mb = torch.cuda.memory_reserved() / (1024 * 1024)
+            torch.cuda.empty_cache()
+            after_mb = torch.cuda.memory_reserved() / (1024 * 1024)
+            logger.info(
+                f"[continue_generation] torch.cuda.empty_cache() called: "
+                f"reserved {before_mb:.1f} MB -> {after_mb:.1f} MB "
+                f"(freed {before_mb - after_mb:.1f} MB)"
+            )
         self._engine_paused = False
 
     def load_lora_adapter(
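For context, a self-contained sketch of the measurement pattern used in continue_generation above: torch.cuda.memory_reserved() reports bytes held by PyTorch's caching allocator, and torch.cuda.empty_cache() hands the unused portion back to the driver. The helper name below is illustrative, not part of SGLang.

    import torch

    def flush_cached_blocks() -> float:
        # Release unused cached blocks; return MB handed back to the driver.
        before_mb = torch.cuda.memory_reserved() / (1024 * 1024)
        torch.cuda.empty_cache()  # only frees blocks no live tensor occupies
        after_mb = torch.cuda.memory_reserved() / (1024 * 1024)
        return before_mb - after_mb

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device="cuda")  # transient ~64 MB alloc
        del x  # tensor is freed, but its blocks stay cached by the allocator
        print(f"freed {flush_cached_blocks():.1f} MB back to the driver")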