Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
35a0016
Add CuTe DSL GDN Verify Kernel Implementation
xutizhou Jan 11, 2026
0e46713
feat: optimize K-last SSM layout for MTP in GDN
xutizhou Jan 12, 2026
40938f8
fix: remove unconditional debug log in forward_decode
xutizhou Jan 12, 2026
046789a
fix: wrap all h0_source memory access inside cache_idx >= 0 check to …
xutizhou Jan 12, 2026
a01e7e2
feat: Add CuTe DSL GDN verify kernel for K-last SSM layout
xutizhou Jan 13, 2026
db17cb0
Optimize GDN attention backend: reduce GPU-CPU sync and pre-allocate …
xutizhou Jan 14, 2026
498d680
feat(gdn): add tensor logging for K-last vs V-last precision comparison
Jan 14, 2026
6577f3f
chore: remove debug tensor logging
Jan 15, 2026
59eafd5
fix: guard k-last decode padding and kv mismatch
xutizhou Jan 15, 2026
72fb76b
perf: move has_prefix_cache computation to init_forward_metadata
xutizhou Jan 15, 2026
9fa9933
fix(gdn-verify): align CuTe K-last verify kernel precision with Triton
xutizhou Jan 19, 2026
649a8cf
refactor(gdn-verify): remove precomputed g/beta feed for K-last CuTe …
xutizhou Jan 19, 2026
e65cda7
refactor(gdn-verify): simplify shared memory allocation for tensors
xutizhou Jan 23, 2026
102fb47
pass
xutizhou Jan 29, 2026
c231643
fix(gdn-prefill): correct flashinfer K-last state layout
xutizhou Jan 29, 2026
a8e62f0
perf(gdn-prefill): remove GPU-CPU sync in K-last prefill path
xutizhou Jan 30, 2026
9b47393
refactor(gdn): remove K-last Triton transpose fallback
xutizhou Jan 30, 2026
64c57d9
refactor(gdn): make K-last only when K-last kernels are available
xutizhou Jan 30, 2026
bbf0c43
refactor(gdn): trim K-last integration overhead
xutizhou Jan 30, 2026
716f58d
refactor(gdn): remove unused intermediate_state_indices plumbing
xutizhou Jan 30, 2026
fd022c3
feat(benchmark): add K-last vs V-last comparison script
xutizhou Jan 31, 2026
74e2761
feat(benchmark): add torch profiling support to K-last vs V-last comparison
xutizhou Jan 31, 2026
11fb4fe
feat(gdn): add kernel precompilation and comm debugging utilities
xutizhou Feb 2, 2026
2b95039
fix(benchmark): adjust memory fraction and enhance profiling options
xutizhou Feb 2, 2026
c4ec38c
feat(benchmark): enhance profiling capabilities and server launch opt…
xutizhou Feb 2, 2026
3364f66
perf(gdn): optimize K-last prefill path to eliminate cudaFree overhead
xutizhou Feb 2, 2026
8a58c16
refactor(gdn): simplify K-last prefill path by removing unnecessary c…
xutizhou Feb 3, 2026
20dc3c5
clean code
xutizhou Feb 3, 2026
b91a495
code clean
xutizhou Feb 3, 2026
0e6a37f
add flashinfer gdn
xutizhou Feb 5, 2026
54c956b
perf(gdn): add MTP kernel warmup to eliminate JIT compilation overhead
xutizhou Feb 6, 2026
f043e4e
perf(gdn): integrate FlashInfer K-last decode kernel
xutizhou Feb 6, 2026
394d295
refactor(gdn): remove unused in_place_transpose and cutedsl_gdn_verify
xutizhou Feb 6, 2026
b0a21ea
refactor(gdn): remove unused compare_klast_vlast_precision script
xutizhou Feb 6, 2026
a19c813
fix(gdn): remap padding indices to sentinel slot for K-last index_copy_
xutizhou Feb 6, 2026
ed7adcf
fix(gdn): use pretranspose decode kernel for K-last pool layout
xutizhou Feb 7, 2026
e57a481
perf(gdn): use pooled decode kernel to eliminate gather/scatter in K-last
xutizhou Feb 7, 2026
adcd365
perf(gdn): remove torch.where remap for K-last pooled decode, pass ra…
xutizhou Feb 7, 2026
851186f
fix(gdn): remove hardcoded FlashInfer repo path and clarify warmup te…
xutizhou Feb 7, 2026
a585826
refactor(gdn): use unified decode function with state_indices parameter
xutizhou Feb 7, 2026
91d2907
fix(gdn): use explicit k_dim/v_dim in warmup tensor layout
xutizhou Feb 7, 2026
ae24e51
fix(gdn): correct k_last shape semantics and hardcoded device
xutizhou Feb 12, 2026
4ae137b
fix(gdn): warn when K-last prefill is used with prefix caching enabled
xutizhou Feb 12, 2026
775d513
style(gdn): fix black/isort/ruff formatting issues
xutizhou Feb 12, 2026
2ad0329
style(gdn): move logger after imports, remove dead _get_flashinfer_gdn_kernels
xutizhou Feb 12, 2026
5d248e3
fix(gdn): clean up warmup error handling and remove unnecessary empty…
xutizhou Feb 12, 2026
e315ca7
fix(gdn): move logger definition before _get_flashinfer_gdn_kernels
xutizhou Feb 12, 2026
ec1c05a
fix(gdn): don't clear kernel handles on warmup failure
xutizhou Feb 12, 2026
7c20f0c
fix(gdn): ensure K-last warnings are only printed once
xutizhou Feb 12, 2026
aa21c17
merge: resolve conflicts after merging origin/main into feat/gdn_mtp
xutizhou Feb 26, 2026
603571d
Add FlashInfer GDN kernel implementation for K-last SSM layout
xutizhou Feb 26, 2026
943f2fd
Wire FlashInfer backend into GDN kernel dispatcher
xutizhou Feb 26, 2026
9352321
Optimize FlashInfer GDN kernels: fix exp precision, consolidate warmu…
xutizhou Feb 27, 2026
c66c08f
Remove gdn_benchmark skill and grpc module from tracking
xutizhou Feb 27, 2026
d6d6c37
Restore grpc module to match origin/main
xutizhou Feb 27, 2026
6e5d267
Remove redundant intermediate_state_indices pre-alloc and FlashInfer …
xutizhou Feb 27, 2026
ff22a9e
Use gather/scatter decode for stock FlashInfer compatibility
xutizhou Feb 27, 2026
ab7a51d
Add TODO/NOTE for FlashInfer upgrade path and decode pool indexing
xutizhou Feb 27, 2026
87e3bee
style(gdn): fix black formatting and trailing whitespace in gdn_flash…
xutizhou Feb 28, 2026
f9896c6
Merge remote-tracking branch 'origin/main' into feat/gdn_mtp
xutizhou Mar 1, 2026
95ac5e7
fix(gdn): remove k_last param from Mamba2StateShape to fix Nemotron CI
xutizhou Mar 1, 2026
b58871d
Merge remote-tracking branch 'origin/main' into feat/gdn_mtp
xutizhou Mar 2, 2026
4d5349d
Merge remote-tracking branch 'origin/main' into feat/gdn_mtp
xutizhou Mar 3, 2026
b0c9bd2
Merge remote-tracking branch 'origin/main' into feat/gdn_mtp
xutizhou Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions python/sglang/srt/layers/attention/linear/gdn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def __init__(
)

self.decode_kernel = CuteDSLGDNKernel()
elif decode_backend.is_flashinfer():
if not is_cuda():
raise ValueError("FlashInfer backend requires CUDA")
from sglang.srt.layers.attention.linear.kernels.gdn_flashinfer import (
FlashInferGDNKernel,
)

flashinfer_kernel = FlashInferGDNKernel()
self.decode_kernel = flashinfer_kernel
else:
raise ValueError(f"Unsupported GDN decode backend: {decode_backend}")

Expand All @@ -78,10 +87,27 @@ def __init__(
"CuTe DSL backend only supports decode, not prefill. "
"Use --linear-attn-prefill-backend triton instead."
)
elif prefill_backend.is_flashinfer():
if not is_cuda():
raise ValueError("FlashInfer backend requires CUDA")
# Reuse the FlashInfer kernel if already created for decode
if decode_backend.is_flashinfer():
self.extend_kernel = flashinfer_kernel
else:
from sglang.srt.layers.attention.linear.kernels.gdn_flashinfer import (
FlashInferGDNKernel,
)

flashinfer_kernel = FlashInferGDNKernel()
self.extend_kernel = flashinfer_kernel
else:
raise ValueError(f"Unsupported GDN prefill backend: {prefill_backend}")

self.verify_kernel = triton_kernel
# Verify kernel: use FlashInfer if either decode or prefill selected it
if decode_backend.is_flashinfer() or prefill_backend.is_flashinfer():
self.verify_kernel = flashinfer_kernel
else:
self.verify_kernel = triton_kernel

rank0_log(
f"GDN kernel dispatcher: decode={self.decode_kernel.__class__.__name__}, "
Expand Down Expand Up @@ -354,6 +380,11 @@ def forward_extend(
intermediate_state_indices=intermediate_state_indices,
cache_steps=forward_batch.spec_info.draft_token_num,
retrieve_parent_token=retrieve_parent_token,
# Pass raw pre-gating values for FlashInfer MTP kernel
a_raw=a,
b_raw=b,
A_log=layer.A_log,
dt_bias=layer.dt_bias,
)
else:
core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend(
Expand All @@ -366,14 +397,15 @@ def forward_extend(
cache_indices=cache_indices,
query_start_loc=query_start_loc,
)
if is_npu() or is_cpu():
if (is_npu() or is_cpu()) and last_recurrent_state is not None:
last_recurrent_state = last_recurrent_state.to(
ssm_states.dtype, copy=False
)
ssm_states[cache_indices] = last_recurrent_state

self._track_mamba_state_extend(
forward_batch, h, ssm_states, forward_metadata
)
if h is not None:
self._track_mamba_state_extend(
forward_batch, h, ssm_states, forward_metadata
)

return core_attn_out
Loading
Loading