2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
@@ -194,5 +194,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Used to adjust the number of KV splits in Triton kernels. The default is 8.
* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashInfer MLA.**
-* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
+* `flashinfer_mla_disable_ragged`: Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing the non-prefix attention before saving the new KV cache into the page table. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of the chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend.
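The note on ragged prefill above can be made concrete with a short sketch of the ordering it describes: attention among the newly extended tokens is computed directly from the contiguous (ragged) Q/K/V buffers, and only afterwards are the new K/V entries appended to the paged KV cache, with attention against the already-cached prefix handled by the paged wrapper. This is a simplified plain-PyTorch illustration, not SGLang's or FlashInfer's actual kernels; `append_kv_to_pages` and `attend_over_prefix_pages` are hypothetical stand-ins.

```python
import math
import torch

def ragged_then_paged_prefill(q_new, k_new, v_new, append_kv_to_pages, attend_over_prefix_pages):
    """Simplified sketch of the ragged-prefill ordering (not the real kernels).

    q_new, k_new, v_new: [num_new_tokens, num_heads, head_dim] contiguous buffers.
    append_kv_to_pages, attend_over_prefix_pages: hypothetical stand-ins for the
    page-table append and the paged (prefix) attention pass.
    """
    t, _, d = q_new.shape

    # 1) Causal attention among the new tokens, computed straight from the
    #    ragged buffers, before anything is written to the page table.
    scores = torch.einsum("qhd,khd->hqk", q_new, k_new) / math.sqrt(d)
    causal_mask = torch.triu(torch.ones(t, t, dtype=torch.bool, device=q_new.device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    o_new = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v_new)

    # 2) Only now are the new K/V entries appended to the paged KV cache.
    append_kv_to_pages(k_new, v_new)

    # 3) Attention of the new tokens against the cached prefix goes through the
    #    paged wrapper; the two partial results are later merged via their LSEs.
    o_prefix, lse_prefix = attend_over_prefix_pages(q_new)
    return o_new, o_prefix, lse_prefix
```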
27 changes: 17 additions & 10 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -425,18 +425,25 @@ def forward_extend(
                v_scale=v_scale,
            )
        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
            if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = prefill_wrapper_paged.forward(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=k_scale,
+                    v_scale=v_scale,
+                )
            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
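Why both passes in the non-trivial branch above return an LSE: the ragged pass covers only attention among the new tokens and the paged pass covers only attention against the cached prefix, so the two partial results `(o1, s1)` and `(o2, s2)` must be combined into one softmax over the union of keys (the merge presumably sits just below the truncated hunk). A minimal PyTorch sketch of such a log-sum-exp merge, assuming natural-log LSEs of shape `[num_tokens, num_heads]` (an illustration, not FlashInfer's fused kernel):

```python
import torch

def merge_partial_attention(o1, s1, o2, s2):
    """Combine two partial attention outputs using their log-sum-exp values.

    o1, o2: [num_tokens, num_heads, head_dim] partial outputs over disjoint key sets.
    s1, s2: [num_tokens, num_heads] natural-log sum of exp(scores) for each partial pass.
    Returns what full attention over the union of both key sets would produce.
    """
    s_max = torch.maximum(s1, s2)   # stabilize the exponentials
    w1 = torch.exp(s1 - s_max)      # weight of partial result 1
    w2 = torch.exp(s2 - s_max)      # weight of partial result 2
    denom = (w1 + w2).unsqueeze(-1)
    return (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom
```

This also explains the single-pass branches: when there is no cached prefix (`extend_no_prefix`), only the paged call runs and nothing needs merging, and in the hunk below the ragged wrapper's LSE was already being discarded (`o, _ = ...`), so the plain `forward` call suffices.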
@@ -348,7 +348,7 @@ def forward_extend(

        if self.forward_metadata.use_ragged:
            # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                qall,
                k.view(-1, layer.tp_k_head_num, layer.head_dim),
                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),