Merged
Commits (61)
21aea66
init draft
yyihuang Apr 20, 2025
9573c49
upd
yyihuang Apr 20, 2025
527662a
upd
yyihuang Apr 21, 2025
5b22d19
upd
yyihuang Apr 21, 2025
0b6b114
upd
yyihuang Apr 21, 2025
120d57e
fmt
yyihuang Apr 21, 2025
c6b157f
upd
yyihuang Apr 21, 2025
a4623aa
add ci (todo: cuda graph shape error)
yyihuang Apr 21, 2025
df8a324
upd disable cuda graph
yyihuang Apr 21, 2025
25c5b6c
add print
yyihuang Apr 21, 2025
06db9d0
kv fp8 only for flashinfer mla
neiltian-tencent Apr 19, 2025
4e723d5
fix extend for main and draft model
neiltian-tencent Apr 22, 2025
e925426
fix flashmla bug (#5272)
sleepcoo Apr 22, 2025
6b62696
target_verify use flashinfer_mla, no cudagraph, result ok
quinnrong94 Apr 24, 2025
5f6e167
target_verify uses flashmla, precision is low
quinnrong94 Apr 25, 2025
011eff1
add flashmla fp8
Apr 15, 2025
4acdb6a
update for conflict
neiltian-tencent Apr 27, 2025
f4b0265
flash mla decode fp8
neiltian-tencent Apr 27, 2025
3d111cd
flashmla backend support mtp cuda graph
Apr 27, 2025
72b96c2
fix block_kv_indices cuda graph in mtp decode
Apr 27, 2025
3ef7731
fix flash_mla seq_lens error
quinnrong94 Apr 25, 2025
3c94f0e
fix MTP + FlashMLA seq_len bug
mpjlu Apr 27, 2025
dcc7d17
fix multi draft crash
neiltian-tencent Apr 28, 2025
ab91da0
fix multi-batch flashmla error
quinnrong94 Apr 28, 2025
87e58a4
protect for none type
neiltian-tencent Apr 30, 2025
75c7637
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 6, 2025
634e033
remove debug info
neiltian-tencent May 7, 2025
0332618
remove unused flashmla backend code
neiltian-tencent May 7, 2025
8244395
update remove todo
neiltian-tencent May 7, 2025
8addc4c
remove debug info
quinnrong94 May 8, 2025
f0d160d
fix flashinfer mla kv cache dtype
quinnrong94 May 8, 2025
585737f
clean code
quinnrong94 May 8, 2025
802750f
fix type check error
quinnrong94 May 8, 2025
1a45d06
fix some merge error
quinnrong94 May 8, 2025
231e7c9
remove hardcode Q_LEN
quinnrong94 May 8, 2025
81dc6ab
update
quinnrong94 May 8, 2025
922694a
format code
neiltian-tencent May 8, 2025
5e4846a
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 8, 2025
080438c
refactor and fix flashmla mtp test
neiltian-tencent May 9, 2025
3121262
refactor for judge attention backend flashmla
neiltian-tencent May 9, 2025
d79776c
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 9, 2025
8676e72
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 9, 2025
e6dc23f
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 10, 2025
e6658e5
Merge branch 'main' into mla_spec_dev
sleepcoo May 10, 2025
83af17e
fix Qwen/Qwen2.5-VL-3B-Instruct timeout 16470 > 16000
neiltian-tencent May 10, 2025
3571d0b
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 11, 2025
c1dc2c2
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 11, 2025
4c2cb13
Merge branch 'main' into mla_spec_dev
Fridge003 May 11, 2025
e0973d8
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 12, 2025
8b65fa4
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 12, 2025
2c8dd88
Merge branch 'main' into mla_spec_dev
sleepcoo May 12, 2025
f7227c9
Merge branch 'main' into mla_spec_dev
sleepcoo May 12, 2025
a50525f
Merge branch 'main' into mla_spec_dev
neiltian-tencent May 13, 2025
60cd84f
remove unused page size test
neiltian-tencent May 13, 2025
4fc17bd
Merge branch 'main' into mla_spec_dev
sleepcoo May 13, 2025
0734a19
update doc for flashmla mtp and kv fp8
neiltian-tencent May 13, 2025
6870222
Merge branch 'main' into mla_spec_dev
sleepcoo May 14, 2025
800584e
update test for flashmla 112
neiltian-tencent May 14, 2025
5c08818
Merge branch 'main' into mla_spec_dev
quinnrong94 May 14, 2025
dccfd40
update avg_spec_accept_length
quinnrong94 May 14, 2025
0259eb6
fix
quinnrong94 May 14, 2025
12 changes: 8 additions & 4 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -346,7 +346,6 @@ def forward_extend(
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         # Save kv cache
         if save_kv_cache and k is not None:

@@ -381,6 +380,9 @@ def forward_extend(
             )
         else:
             # mla paged prefill
+            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (

@@ -442,7 +444,9 @@ def forward_decode(
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]

-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            q.dtype
+        )

         o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
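The hunks above all apply the same fix: the MLA KV cache may be stored in fp8, so the key buffer is upcast to the query dtype right before it is handed to the kernel (and, in forward_extend, the cast is moved onto the paged-prefill path, the only place that needs it). Below is a minimal, runnable sketch of that pattern; the shapes and the naive softmax attention are illustrative assumptions, not FlashInfer's actual MLA API.

```python
import torch

def mla_decode_sketch(q: torch.Tensor, k_buffer_fp8: torch.Tensor) -> torch.Tensor:
    # q:            [batch, heads, head_dim] in the model dtype (bf16/fp16)
    # k_buffer_fp8: [cache_len, head_dim] stored in fp8
    # The one-line fix from the diff: cast the KV buffer to the query dtype.
    k_buffer = k_buffer_fp8.to(q.dtype)
    # Naive stand-in for the FlashInfer MLA call; in MLA the shared latent
    # buffer serves as both K and V.
    scores = torch.einsum("bhd,nd->bhn", q, k_buffer) * q.shape[-1] ** -0.5
    return torch.einsum("bhn,nd->bhd", scores.softmax(dim=-1), k_buffer)

q = torch.randn(2, 16, 64, dtype=torch.bfloat16)
kv = torch.randn(128, 64).to(torch.float8_e4m3fn)  # fp8 KV cache storage
print(mla_decode_sketch(q, kv).shape)  # torch.Size([2, 16, 64])
```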
@@ -467,7 +471,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.kv_cache_dtype
+        self.data_type = model_runner.dtype
         self.attn_backend = attn_backend

         # Buffers and wrappers

@@ -577,7 +581,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.v_head_dim = model_runner.model_config.v_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.kv_cache_dtype
+        self.data_type = model_runner.dtype
         self.q_data_type = model_runner.dtype
         self.attn_backend = attn_backend
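These last two hunks are the companion change: the indices updaters now advertise the model dtype rather than the fp8 storage dtype, presumably so the KV dtype the wrappers are planned with matches the upcast buffer that forward_extend/forward_decode now actually feed them. A toy illustration of that invariant follows; `PlannedWrapper` is a stand-in invented for this sketch, not FlashInfer's API.

```python
import torch

class PlannedWrapper:
    """Stand-in for a kernel wrapper that is planned with a fixed KV dtype."""

    def plan(self, kv_data_type: torch.dtype) -> None:
        self.kv_data_type = kv_data_type

    def run(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # A real kernel would miscompute or assert on mismatched dtypes.
        assert q.dtype == k.dtype == self.kv_data_type, "dtype mismatch"
        return q

model_dtype, kv_cache_dtype = torch.bfloat16, torch.float8_e4m3fn
wrapper = PlannedWrapper()
wrapper.plan(kv_data_type=model_dtype)  # the diff: model dtype, not fp8
q = torch.randn(4, 8, dtype=model_dtype)
k_fp8 = torch.randn(16, 8).to(kv_cache_dtype)
wrapper.run(q, k_fp8.to(q.dtype))  # ok: runtime buffer matches the plan
```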