diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 560b4deb24d..1dd9ab0074a 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. -* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. +* `flashinfer_mla_disable_ragged`: Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use this flag when FlashInfer is being used as the MLA backend. * `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. 
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 2cb91f0947b..c2927510350 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -425,18 +425,25 @@ def forward_extend( v_scale=v_scale, ) else: - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( - q.view(-1, layer.tp_q_head_num, layer.head_dim), - k.view(-1, layer.tp_k_head_num, layer.head_dim), - v.view(-1, layer.tp_v_head_num, layer.head_dim), - causal=True, - sm_scale=layer.scaling, - logits_soft_cap=logits_soft_cap, - ) - if self.forward_metadata.extend_no_prefix: - o = o1 + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale, + ) else: + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q.view(-1, layer.tp_q_head_num, layer.head_dim), + k.view(-1, layer.tp_k_head_num, layer.head_dim), + v.view(-1, layer.tp_v_head_num, layer.head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) o2, s2 = prefill_wrapper_paged.forward_return_lse( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 81afcb9dac5..a43dd0f86d7 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -348,7 +348,7 @@ def forward_extend( if self.forward_metadata.use_ragged: # ragged prefill - o, _ = self.prefill_wrapper_ragged.forward_return_lse( + o = self.prefill_wrapper_ragged.forward( qall, 
k.view(-1, layer.tp_k_head_num, layer.head_dim), v.view(-1, layer.tp_k_head_num, layer.v_head_dim),