Skip to content

Evict swa kv cache during decoding#17220

Merged
ispobock merged 14 commits intomainfrom
ke/evict-swa
Jan 19, 2026
Merged

Evict swa kv cache during decoding#17220
ispobock merged 14 commits intomainfrom
ke/evict-swa

Conversation

@ispobock
Copy link
Collaborator

@ispobock ispobock commented Jan 16, 2026

Motivation

  • Before: SWARadixCache only inserts new tokens into the radix tree when (chunked) prefill finishes or a whole request finishes, so out of window tokens can’t be evicted during decoding.
  • Issue: Long generations waste memory and limit swa’s benefit.
  • Now: Only cache the last sliding_window_size tokens for finished requests (for multi-turn prefix reuse). Out of window tokens can be evicted early without hurting cache hit rate.

Modifications

  • Unify SWA eviction for SWARadixCache and SWAChunkCache. Since these tokens are not inserted into RadixCache during decoding, move the eviction logic from the cache implementation into the schedule batch.
  • Update cache_finished_req to consider the evicted tokens.

Benchmarking and Profiling

python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-V2-Flash --tp-size 4 --trust-remote-code --mem-fraction-static 0.7 --model-loader-extra-config '{"enable_multithread_load": "true","num_threads": 64}' --attention-backend fa3 --page-size 64 --swa-full-tokens-ratio 0.2

python3 benchmark/hicache/bench_multiturn.py --model-path XiaomiMiMo/MiMo-V2-Flash --disable-random-sample \
--output-length 1 --request-length 300 \
--num-clients 10 --num-rounds 10 --max-parallel 10 --request-rate 16 \
--ready-queue-policy random --disable-auto-run --enable-round-barrier --output-length 3000

main:

[2026-01-19 02:58:09 TP0] Decode batch, #running-req: 10, #full token: 330688, full token usage: 0.37, #swa token: 48384, swa token usage: 0.27, cuda graph: True, gen throughput (token/s): 836.43, #queue-req: 0, 
[2026-01-19 02:58:10 TP0] Decode batch, #running-req: 10, #full token: 331072, full token usage: 0.37, #swa token: 48768, swa token usage: 0.27, cuda graph: True, gen throughput (token/s): 836.62, #queue-req: 0, 
[2026-01-19 02:58:10 TP0] Decode batch, #running-req: 10, #full token: 331392, full token usage: 0.37, #swa token: 49088, swa token usage: 0.27, cuda graph: True, gen throughput (token/s): 837.37, #queue-req: 0, 
[2026-01-19 02:58:11 TP0] Decode batch, #running-req: 10, #full token: 331840, full token usage: 0.37, #swa token: 49536, swa token usage: 0.28, cuda graph: True, gen throughput (token/s): 836.79, #queue-req: 0, 


Performance metrics summary:
  Total requests: 100 at 16.0 requests per second
  Average Prompt Length: 15242.46 tokens
  Average Output Length: 2999.55 tokens
  Average TTFT: 0.49
  P90 TTFT: 1.03
  P99 TTFT: 2.85
  Median TTFT: 0.27
  Max TTFT: 2.85
  Average latency: 36.79
  P90 latency: 38.20
  P99 latency: 38.67
  Median latency: 36.69
  Max latency: 38.67
  Input token throughput: 4104.34 tokens per second
  Output token throughput: 807.69 tokens per second
  Request Throughput: 0.27 requests per second
  Cache Hit Rate: 0.874651
Per-round metrics:
  Round 0: Average TTFT = 0.33s, Cache Hit Rate = 0.000000 (10 requests)
  Round 1: Average TTFT = 0.19s, Cache Hit Rate = 0.827510 (10 requests)
  Round 2: Average TTFT = 0.23s, Cache Hit Rate = 0.922246 (10 requests)
  Round 3: Average TTFT = 0.35s, Cache Hit Rate = 0.914384 (10 requests)
  Round 4: Average TTFT = 0.21s, Cache Hit Rate = 0.954352 (10 requests)
  Round 5: Average TTFT = 0.26s, Cache Hit Rate = 0.963636 (10 requests)
  Round 6: Average TTFT = 0.46s, Cache Hit Rate = 0.856698 (10 requests)
  Round 7: Average TTFT = 0.69s, Cache Hit Rate = 0.874354 (10 requests)
  Round 8: Average TTFT = 0.29s, Cache Hit Rate = 0.987110 (10 requests)
  Round 9: Average TTFT = 1.87s, Cache Hit Rate = 0.691597 (10 requests)

this PR:

[2026-01-19 02:49:00 TP0] Decode batch, #running-req: 10, #full token: 330368, full token usage: 0.37, #swa token: 5248, swa token usage: 0.03, cuda graph: True, gen throughput (token/s): 814.85, #queue-req: 0, 
[2026-01-19 02:49:00 TP0] Decode batch, #running-req: 10, #full token: 330624, full token usage: 0.37, #swa token: 5376, swa token usage: 0.03, cuda graph: True, gen throughput (token/s): 813.99, #queue-req: 0, 
[2026-01-19 02:49:01 TP0] Decode batch, #running-req: 10, #full token: 331072, full token usage: 0.37, #swa token: 5824, swa token usage: 0.03, cuda graph: True, gen throughput (token/s): 814.25, #queue-req: 0, 
[2026-01-19 02:49:01 TP0] Decode batch, #running-req: 10, #full token: 331584, full token usage: 0.37, #swa token: 6336, swa token usage: 0.04, cuda graph: True, gen throughput (token/s): 813.87, #queue-req: 0, 
[2026-01-19 02:49:02 TP0] Decode batch, #running-req: 10, #full token: 331904, full token usage: 0.37, #swa token: 5376, swa token usage: 0.03, cuda graph: True, gen throughput (token/s): 815.36, #queue-req: 0,

All requests completed
Performance metrics summary:
  Total requests: 100 at 16.0 requests per second
  Average Prompt Length: 15253.59 tokens
  Average Output Length: 2999.55 tokens
  Average TTFT: 0.25
  P90 TTFT: 0.40
  P99 TTFT: 0.75
  Median TTFT: 0.23
  Max TTFT: 0.75
  Average latency: 37.21
  P90 latency: 38.12
  P99 latency: 38.31
  Median latency: 37.41
  Max latency: 38.31
  Input token throughput: 4057.60 tokens per second
  Output token throughput: 797.91 tokens per second
  Request Throughput: 0.27 requests per second
  Cache Hit Rate: 0.965354
Per-round metrics:
  Round 0: Average TTFT = 0.16s, Cache Hit Rate = 0.000000 (10 requests)
  Round 1: Average TTFT = 0.16s, Cache Hit Rate = 0.818212 (10 requests)
  Round 2: Average TTFT = 0.22s, Cache Hit Rate = 0.861090 (10 requests)
  Round 3: Average TTFT = 0.29s, Cache Hit Rate = 0.966613 (10 requests)
  Round 4: Average TTFT = 0.23s, Cache Hit Rate = 0.974975 (10 requests)
  Round 5: Average TTFT = 0.28s, Cache Hit Rate = 0.979822 (10 requests)
  Round 6: Average TTFT = 0.26s, Cache Hit Rate = 0.967138 (10 requests)
  Round 7: Average TTFT = 0.21s, Cache Hit Rate = 0.971974 (10 requests)
  Round 8: Average TTFT = 0.38s, Cache Hit Rate = 0.976123 (10 requests)
  Round 9: Average TTFT = 0.28s, Cache Hit Rate = 0.988668 (10 requests)

This PR can keep much lower swa token usage during decoding and have better cache hit rate in muti-turn benchmark.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @ispobock, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the memory management strategy for Sliding Window Attention (SWA) KV caches. By centralizing and refining the eviction logic within the scheduling batch, it ensures that only the most relevant tokens within the sliding window are retained, thereby significantly reducing memory consumption during long token generation processes. This change enhances the efficiency and scalability of SWA-enabled models by preventing unnecessary memory accumulation.

Highlights

  • Improved SWA KV Cache Eviction: The pull request introduces early eviction of out-of-window tokens from the Sliding Window Attention (SWA) KV cache during the decoding phase, addressing previous memory waste issues during long generation sequences.
  • Unified Eviction Logic: The SWA eviction logic has been unified for both SWARadixCache and SWAChunkCache and moved from the individual cache implementations to the schedule_batch manager, centralizing control.
  • Enhanced Radix Cache Insertion: The SWARadixCache's insertion mechanism has been updated to consider swa_evicted_seqlen, allowing for more granular handling of tokens that fall outside the sliding window during insertion or re-insertion of finished requests.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the SWA KV cache eviction logic by centralizing it in ScheduleBatch, which is a good improvement for code clarity and consistency. The changes allow for evicting out-of-window tokens during the decoding phase, which should improve memory efficiency for long generations. The new logic in SWARadixCache to handle partially evicted sequences is complex but seems correct. I have a few suggestions to improve maintainability and robustness.

@hzh0425 hzh0425 self-assigned this Jan 16, 2026
key: RadixKey,
value=None,
prev_prefix_len: int = 0,
swa_evicted_seqlen: int = 0,
Copy link
Collaborator

@hanming-lu hanming-lu Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. you are actually evicting them instead of unlocking

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the generated tokens are not in the tree until the sequence finished. So we cannot unlock it.

assert (
swa_evicted_seqlen % self.page_size == 0
), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}"
if swa_evicted_seqlen <= total_prefix_length:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments for each branch please

Copy link
Collaborator

@hanming-lu hanming-lu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly paid attention to swa radix cache logic and evict frequency. Looks good. Added some comments, mainly on naming and add some comments please.

Also, please share how you tested the change, and add test coverage for the long decode. I am not aware of any test covering the swa radix cache? Maybe we can remove some self.disable_hybrid_swa_memory = True in the server_args? e.g. GptOssForCausalLM

# We set evict_swa condition here with two reasons:
# 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running.
# 2. Evict swa every window_size tokens to reduce the overhead.
if req.decode_batch_idx % sliding_window_size == 1:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat

@ispobock
Copy link
Collaborator Author

/tag-and-rerun-ci

@ispobock
Copy link
Collaborator Author

@sitabulaixizawaluduo
Copy link
Contributor

Is this optimization intended for non-PD Disaggregation scenarios? I see that there are already corresponding optimizations for PD disaggregation scenarios.

@yizhang2077
Copy link
Collaborator

Is this optimization intended for non-PD Disaggregation scenarios? I see that there are already corresponding optimizations for PD disaggregation scenarios.

I think it is mainly optimized for non-PD scenarios. This optimization is mainly focused on long-decode cases with radix-cache. While in PD settings decode node's radix cache is closed and the swa optimization in Chunk cache has been implemented

@ispobock ispobock merged commit ce8a6ac into main Jan 19, 2026
89 of 94 checks passed
@ispobock ispobock deleted the ke/evict-swa branch January 19, 2026 14:36
DotSlash-A pushed a commit to DotSlash-A/sglang that referenced this pull request Jan 19, 2026
* fix(ci): recover from corrupted MMMU parquet cache (sgl-project#17256)

* [diffusion] feat: support default 4-step inference for Flux2-Klein distilled models (sgl-project#17225)

Signed-off-by: Lancer <maruixiang6688@gmail.com>

* Add runner utilization report workflow (sgl-project#17234)

* cli: support sglang version (sgl-project#17250)

* Use swa radix cache and memory pool for gpt-oss model (sgl-project#17261)

* [VLM][Reland] Refactor load_mm_data to improve performance (sgl-project#16152)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [Tiny] Improve docs (sgl-project#17264)

* [diffusion] fix: set guidance_scale default to None (sgl-project#17182)

* Tiny fix comment typo (sgl-project#17287)

* [SPEC_V2] Enable cudagraph draft_extend for trtllm_mla_backend and Acclen Fix for DP under cudagraph mode (sgl-project#16974)

* Add kl test for swa radix cache (sgl-project#17281)

* fix: Handle multiple named chat templates in HuggingFace tokenizers (sgl-project#17236)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>

* Move radix cache related tests (sgl-project#17295)

* [Refactor] Add `-fp4-gemm-backend` to replace `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (sgl-project#16534)

Co-authored-by: Vincent Zhong <207368749+vincentzed@users.noreply.github.com>

* [Bugfix] Fix PD accuracy when MTP is not configured on the prefill node (sgl-project#17212)

Co-authored-by: Shangming Cai <csmthu@gmail.com>

* [Diffusion] Apply jit qk_norm to flux1 (sgl-project#17296)

* [Refactor] Split out deepseek v2 weight loader function into mixin (sgl-project#16649)

* [NPU]Support GPT-OSS for NPU (sgl-project#14197)

* [jit-kernel] Add CuTe DSL GDN Decode Kernel (sgl-project#15631)

Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>

* [GLM 4.7] Add RTX 6000 Pro aka sm120 (sgl-project#17235)

Co-authored-by: root <root@ubuntu-nvidia.localdomain>

* Update CODEOWNERS for multimodal_gen (sgl-project#17308)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Feature] overlap LoRA weight loading with compute (sgl-project#15512)

* [PD] Optimize MHA models pp util calculation logic (sgl-project#17306)

* [Minor] Correct sglang version when installing from source (sgl-project#17315)

* Use dsv3 optimized routing `fused_topk_deepseek` instead of `moe_fused_gate` (sgl-project#15347)

* [DeepSeek v3.2] Opt MTP decode cuda batch sizes and nsa implementation (sgl-project#16961)

* Update code sync scripts (sgl-project#17319)

* [Auto Sync] Update tokenizer_manager.py (20260119) (sgl-project#17317)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* support new qwen3_coder_detector (sgl-project#16744)

Co-authored-by: liugaoji.lgj <liugaoji.lgj@alibaba-inc.com>

* Fix kernel selection in biased_grouped_topk_gpu (sgl-project#17325)

* KV Cache Events with Attention DP bug fix (sgl-project#16030) (sgl-project#16412)

* [Perf] fuse q, k norm for Flux2Attention (sgl-project#17241)

Co-authored-by: Minglei Zhu <zminglei@linkedin.com>

* [CI] Add partition to stage-b-test-large-1-gpu (11->12) (sgl-project#17245)

* fix(ci): rate limit and permission errors in trace publishing (sgl-project#17238)

* Revert "[Perf] fuse q, k norm for Flux2Attention (sgl-project#17241)" (sgl-project#17332)

* Migrate performance, accuracy, and quantization tests to CI registry (sgl-project#17177)

Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>

* Inclusion of nvfp4 blockscale in EPLB Rebalance (sgl-project#17158)

* [Refactor] Set `fp4-gemm-backend=auto` on SM100 and rename `fp4-gemm-backend` with `flashinfer_` prefix (sgl-project#17309)

* [Diffusion] Apply qknorm to flux2 and apply lightx2v rms_norm_one_pass kernel(without residual) (sgl-project#17305)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix v32 continue_final_message not work (sgl-project#16567)

* Evict swa kv cache during decoding (sgl-project#17220)

* [RadixTree][1/N Refactor]: Support unified match_prefix params (sgl-project#17142)

Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>

* [AMD CI] Migrate and Add More Testcases (sgl-project#17116)

Co-authored-by: yctseng0211 <yctseng@amd.com>

* [AMD] CI - add partitions for stage-b-test-small-1-gpu-amd (sgl-project#17345)

* Restore deepseek_v2.py to main's code, except the utils

* Ran `pre-commit`

---------

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Hudson Xing <1277646412@qq.com>
Co-authored-by: Lancer <402430575@qq.com>
Co-authored-by: Alison Shao <54658187+alisonshao@users.noreply.github.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Mohammad Miadh Angkad <mangkad.bsdsba2027@aim.edu>
Co-authored-by: Changyi Yang <112288487+ChangyiYang@users.noreply.github.com>
Co-authored-by: YAMY <74099316+YAMY1234@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Vincent Zhong <207368749+vincentzed@users.noreply.github.com>
Co-authored-by: Ch3ngY1 <91232537+Ch3ngY1@users.noreply.github.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Jerry Ji <jerryjilol@gmail.com>
Co-authored-by: Todobe <43903496+Todobe@users.noreply.github.com>
Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>
Co-authored-by: Koushik Dutta <koush@koushikdutta.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Glen Liu <62917497+glenliu21@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Lee Nau <lnau@nvidia.com>
Co-authored-by: Yongfei Xu <xuyongfei.xyf@antgroup.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Gaoji Liu <34803073+attack204@users.noreply.github.com>
Co-authored-by: liugaoji.lgj <liugaoji.lgj@alibaba-inc.com>
Co-authored-by: yudian0504 <138860534+yudian0504@users.noreply.github.com>
Co-authored-by: Kartik Ramesh <kartikx2000@gmail.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Minglei Zhu <zminglei@linkedin.com>
Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
Co-authored-by: zhangheng <hzh0425@apache.org>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Bingxu Chen <Bingxu.Chen@amd.com>
Co-authored-by: yctseng0211 <yctseng@amd.com>
michaelzhang-ai added a commit to michaelzhang-ai/sglang that referenced this pull request Feb 2, 2026
michaelzhang-ai added a commit to michaelzhang-ai/sglang that referenced this pull request Feb 2, 2026
michaelzhang-ai added a commit to michaelzhang-ai/sglang that referenced this pull request Feb 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants