
[RL] Call torch.cuda.empty_cache() for in-place pause mode to avoid OOM #24854

Merged
ByronHsu merged 5 commits into sgl-project:main from ByronHsu:byron/upstream-empty-cache-on-resume on May 10, 2026

Conversation

@ByronHsu (Collaborator) commented May 9, 2026

Motivation

Post-weight-update processing (e.g. DeepSeek MLA w_kc/w_vc derivation, FP8 scale rebuild) creates transient CUDA allocations that fragment PyTorch's block cache. Without empty_cache(), reserved memory grows each iteration and eventually OOMs.

| Pause mode | flush_cache called? | empty_cache before this PR | empty_cache after this PR |
| --- | --- | --- | --- |
| abort | Yes | Yes (via flush_cache) | Yes (via flush_cache + resume) |
| retract | Yes | Yes (via flush_cache) | Yes (via flush_cache + resume) |
| in_place | No | No ← the gap | Yes (via resume) |

abort and retract are safer because they already call empty_cache() as part of flush_cache. The in_place path never calls flush_cache (to preserve KV cache), so empty_cache() was never triggered — this PR closes that gap.

With this change, abort and retract will call empty_cache() twice (once in flush_cache, once on resume), but the second call is benign — it's a no-op when there are no cached blocks.
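
For reference, PyTorch's standard memory counters make the live-vs-cached distinction (and the no-op behaviour) easy to see. This is a standalone snippet, not code from this PR:

```python
import torch

# Standalone illustration (assumes a CUDA device; not code from this PR).
# memory_allocated(): bytes held by live tensors.
# memory_reserved():  bytes the caching allocator keeps from the driver,
#                     including cached-but-unused blocks.
if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")    # live tensor (~4 MB)
    y = torch.empty(4096, 4096, device="cuda")    # transient tensor (~64 MB)
    del y                                          # its block stays cached
    print("allocated MB:", torch.cuda.memory_allocated() / 2**20)
    print("reserved  MB:", torch.cuda.memory_reserved() / 2**20)

    torch.cuda.empty_cache()                       # returns cached blocks only
    print("reserved after empty_cache MB:", torch.cuda.memory_reserved() / 2**20)
    # memory_allocated() is unchanged: x is still live. With no cached
    # blocks at all, empty_cache() would simply be a no-op.
```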

Changes

  • Add empty_cache: bool = True to ContinueGenerationReqInput (a field-level sketch follows this list). The scheduler calls torch.cuda.empty_cache() while still paused (no race with active streams).
  • Log reserved-memory delta at INFO level for observability.
  • Callers can opt out with empty_cache=False.
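
A minimal sketch of the request-side change, assuming ContinueGenerationReqInput is a dataclass in io_struct.py with its other fields elided (a review thread below suggests aligning the name with the torch_empty_cache convention used by the update_* structs):

```python
from dataclasses import dataclass


@dataclass
class ContinueGenerationReqInput:
    # ... existing fields elided ...

    # Return cached-but-unused allocator blocks (accumulated during
    # post-weight-update processing) to the driver before inference
    # resumes. Set to False to skip the empty_cache call.
    empty_cache: bool = True
```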

Before


OOM after repeated weight updates:

Full traceback
      ret, can_run_graph = self.forward_extend(
                           ^^^^^^^^^^^^^^^^^^^^
    File "sglang/srt/model_executor/model_runner.py", line 2780, in forward_extend
      self.model.forward(
    File "torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "sglang/srt/models/deepseek_v2.py", line 2298, in forward
      hidden_states = self.model(
    File "sglang/srt/models/deepseek_v2.py", line 2061, in forward
      hidden_states, residual, topk_indices = layer(
    File "sglang/srt/models/deepseek_v2.py", line 1750, in forward
      hidden_states = self.mlp(
    File "sglang/srt/models/deepseek_v2.py", line 574, in forward
      return self.forward_deepep(hidden_states, forward_batch)
    File "sglang/srt/models/deepseek_v2.py", line 782, in forward_deepep
      shared_output = self._forward_shared_experts(hidden_states)
    File "sglang/srt/models/deepseek_v2.py", line 979, in _forward_shared_experts
      return self.shared_experts(
    File "sglang/srt/models/deepseek_v2.py", line 258, in forward
      x, _ = self.down_proj(
    File "sglang/srt/layers/linear.py", line 1509, in forward
      output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
    File "sglang/srt/layers/quantization/fp8.py", line 737, in apply
      return self.w8a8_block_fp8_linear(
    File "sglang/srt/layers/quantization/fp8_utils.py", line 678, in deepgemm_w8a8_block_fp8_linear_with_fallback
      output = w8a8_block_fp8_matmul_deepgemm(
    File "sglang/srt/layers/quantization/fp8_kernel.py", line 1107, in w8a8_block_fp8_matmul_deepgemm
      deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
    File "sglang/srt/layers/quantization/fp8_kernel.py", line 123, in deep_gemm_fp8_fp8_bf16_nt
      deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
    File "sglang/srt/layers/deep_gemm_wrapper/entrypoint.py", line 98, in gemm_nt_f8f8bf16
      deep_gemm.fp8_gemm_nt(
    File "deep_gemm/__init__.py", line 50, in _fn
      return func(*args, **kwargs)
  RuntimeError: CUDA driver error (/sgl-kernel/build/_deps/repo-deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/handle.hpp:84): 2
  (CUDA_ERROR_OUT_OF_MEMORY, out of memory)

After


Stable memory, no OOM.

Test plan

  • 1-node Qwen3-30B-A3B agg recipe with in-place pause + routing replay: the log fires on every resume; ~2 MB reclaimed on the first resume, ~0 MB thereafter.
  • ESS, loss, exit code unchanged.

Post-weight-update code paths (e.g. DeepSeek MLA w_kc/w_vc derivation,
FP8 block-quant scale rebuild) do many alloc/free cycles. Same-shape
later allocations don't always reuse the freed blocks because of
allocator split/merge heuristics and transient peaks during the cycle,
so the PyTorch caching allocator's working footprint grows over
weight-update cycles until it hits steady state. Live tensor count is
stable; the growth is cached-but-unused blocks held by the allocator.

Add an empty_cache field on ContinueGenerationReqInput, defaulted True.
When set, the scheduler calls torch.cuda.empty_cache() while the engine
is still paused, before flipping _engine_paused = False, returning the
cached blocks to the driver with no race against active streams.

An INFO log reports CUDA reserved memory before/after and how much was
freed, making it easy to verify the empty_cache step is firing and to
see how much transient memory it returns.

Set empty_cache=False on the request to opt out.

Co-authored-by: Cursor <cursoragent@cursor.com>
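
A minimal sketch of the resume-side handling described above; it follows the snippet shown in the review thread below, but the logger, message wording, and surrounding scheduler state are assumptions, and the merged version later switched to the empty_device_cache() helper:

```python
import logging

import torch

from sglang.srt.managers.io_struct import ContinueGenerationReqInput

logger = logging.getLogger(__name__)


class Scheduler:
    def continue_generation(self, recv_req: ContinueGenerationReqInput):
        if recv_req.empty_cache:
            before_mb = torch.cuda.memory_reserved() / (1024 * 1024)
            # The engine is still paused: no kernels are in flight on active
            # streams, so returning cached blocks to the driver cannot race
            # with them.
            torch.cuda.empty_cache()
            after_mb = torch.cuda.memory_reserved() / (1024 * 1024)
            logger.info(
                "continue_generation: reserved %.1f MB -> %.1f MB "
                "(freed %.1f MB via empty_cache)",
                before_mb, after_mb, before_mb - after_mb,
            )
        # Resume only after the cache has been returned.
        self._engine_paused = False
```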
@gemini-code-assist (Contributor) commented

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

ByronHsu changed the title from "[Scheduler] Call torch.cuda.empty_cache() before resuming generation" to "[RL] Call torch.cuda.empty_cache() before resuming generation to avoid memory increase from fragmentation" on May 9, 2026
ByronHsu changed the title from "[RL] Call torch.cuda.empty_cache() before resuming generation to avoid memory increase from fragmentation" to "[RL] Call torch.cuda.empty_cache() for in-place pause mode" on May 9, 2026
ByronHsu changed the title from "[RL] Call torch.cuda.empty_cache() for in-place pause mode" to "[RL] Call torch.cuda.empty_cache() for in-place pause mode to avoid OOM" on May 9, 2026
Comment thread python/sglang/srt/managers/io_struct.py Outdated
# during post-weight-update processing) back to the driver before
# inference resumes, with no race against active streams. Set to
# False to skip the empty_cache call.
empty_cache: bool = True

Comment thread python/sglang/srt/managers/io_struct.py Outdated
Align with the naming convention used in update_* request structs
(e.g. UpdateWeightsFromDistributedReqInput.torch_empty_cache).

Co-authored-by: Cursor <cursoragent@cursor.com>

def continue_generation(self, recv_req: ContinueGenerationReqInput):
    if recv_req.empty_cache:
        before_mb = torch.cuda.memory_reserved() / (1024 * 1024)
Collaborator

maybe make it compatible with AMD and other accelerator?

def get_available_gpu_memory(

Collaborator Author

it is a bit too messy. also it contradicts with #24854 (comment). i will keep it simple and just support torch for now

@hebiao064 (Collaborator) left a comment

LGTM with two minor comments

@ispobock (Collaborator)

/tag-and-rerun-ci

Byron Hsu and others added 3 commits May 10, 2026 04:29
Replace direct torch.cuda.empty_cache() / memory_reserved() calls in
continue_generation with the empty_device_cache() helper from sgl-project#24861,
making the in-place pause resume path work on all device backends.

Co-authored-by: Cursor <cursoragent@cursor.com>
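
The empty_device_cache() helper itself lives in #24861 and is not reproduced in this thread; purely as an illustration of the idea, a backend-dispatching flush might look roughly like the hypothetical sketch below (function name and branches are illustrative, not the actual helper):

```python
import torch


def empty_device_cache_sketch() -> None:
    """Hypothetical sketch; the real helper is defined in sgl-project#24861."""
    if torch.cuda.is_available():
        # torch.cuda covers both NVIDIA CUDA and AMD ROCm builds of PyTorch.
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel XPU backend in recent PyTorch releases.
        torch.xpu.empty_cache()
    # Other accelerators (e.g. NPU plugins) expose their own cache APIs.
```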
@ByronHsu (Collaborator Author)

/tag-and-rerun-ci

@ByronHsu ByronHsu merged commit cfd3fd0 into sgl-project:main May 10, 2026
104 of 133 checks passed
ByronHsu added a commit that referenced this pull request May 10, 2026
…in-place pause mode to avoid OOM (#24905)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Cursor <cursoragent@cursor.com>
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 11, 2026
* main: (87 commits)
  [Fix] Disable FlashInfer allreduce fusion under deterministic inference (sgl-project#24629)
  fix: STANDALONE spec-decode hidden-size mismatch crash (sgl-project#24217)
  Followup fix for Custom AR V2 in non NVL scenarios (sgl-project#24742)
  Fix reduce_scatterv producer contract for SUM_LEN (sgl-project#24785)
  [NPU]Documentation update for communications quantization feature (sgl-project#24668)
  [Session R3] Add routed_experts_start_len for absolute routing slice control (sgl-project#24851)
  [Model] Add MiniCPM-V 4.6 support (sgl-project#24855)
  Support Intern-S2-Preview (sgl-project#24875)
  [PD] Unify dsv4 dispatch with swa (sgl-project#24888)
  Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head (sgl-project#24775)
  Fix PD bootstrap failure handling (sgl-project#24772)
  [Spec] Cleanup idle stub and shape-check patterns (sgl-project#24881)
  [Bug] Add dsv4 state_type branch to mooncake disaggregation (sgl-project#24878)
  [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput` (sgl-project#24859)
  [Gemma4] Optimize Gemm4 with fused Q/K/V RMSNorm + per-expert FP8 ckpt loader (sgl-project#24696)
  [spec decoding] support kimi-k2.5-eagle3-mla (sgl-project#24826)
  [SPEC V2] fix: skip stale state updates in spec-v2 overlap (sgl-project#23456)
  [RL] Call torch.cuda.empty_cache() for `in-place` pause mode to avoid OOM (sgl-project#24854)
  [diffusion] CI: add cache-dit CI tests (sgl-project#19213)
  [Utils] Make request dump robust to unpicklable server_args and large meta_info (sgl-project#24767)
  ...

# Conflicts:
#	python/sglang/srt/utils/common.py