Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
* `enable_flashinfer_mla`: Enable the FlashInfer MLA wrapper backend, which accelerates DeepSeek models.
* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
2 changes: 1 addition & 1 deletion docs/references/deepseek.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). (Experimental)

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

Expand Down
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def init_forward_metadata_capture_cuda_graph(
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
Expand All @@ -41,8 +42,9 @@ def init_forward_metadata_replay_cuda_graph(
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
Expand Down
6 changes: 2 additions & 4 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,9 @@ def init_forward_metadata_capture_cuda_graph(
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
Expand Down Expand Up @@ -340,10 +339,9 @@ def init_forward_metadata_replay_cuda_graph(
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
Expand Down
156 changes: 38 additions & 118 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union

import torch
Expand All @@ -28,12 +27,14 @@
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state


@dataclass
Expand Down Expand Up @@ -62,7 +63,6 @@ def __init__(

# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device

global_config.enable_flashinfer_mla = True

Expand All @@ -85,6 +85,10 @@ def __init__(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)

self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)

self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
Expand Down Expand Up @@ -122,7 +126,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrapper=self.decode_wrapper,
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
Expand Down Expand Up @@ -158,38 +161,32 @@ def init_cuda_graph_state(
cuda_graph_kv_indices = kv_indices_buf

self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
self.cuda_graph_kv_lens = torch.ones(
(max_bs,), dtype=torch.int32, device=self.device
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)

# For fast decode plan in graph replaying
self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
self.fast_decode_kwargs = {
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
"kv_indices": self.cuda_graph_kv_indices,
}
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
self.cuda_graph_qo_indptr = self.kv_indptr.clone()

def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
**kwargs,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
qo_indptr=self.qo_indptr[: num_tokens + 1],
kv_indptr=self.kv_indptr[: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
kv_len_arr=self.kv_last_page_len[:num_tokens],
backend="auto",
)

Expand All @@ -199,11 +196,9 @@ def init_forward_metadata_capture_cuda_graph(
seq_lens,
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")

Expand All @@ -213,30 +208,16 @@ def init_forward_metadata_replay_cuda_graph(
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
seq_lens_cpu: torch.Tensor,
**kwargs,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
)
self.fast_decode_kwargs.update(
{
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
"kv_len_arr_cpu": kv_len_arr_cpu,
}
)

self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
**self.fast_decode_kwargs,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
Expand Down Expand Up @@ -336,6 +317,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode

Expand All @@ -345,8 +327,6 @@ def update(
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
self.call_begin_forward(
Expand All @@ -356,8 +336,6 @@ def update(
seq_lens_sum,
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
**fast_decode_kwargs,
)

def call_begin_forward(
Expand All @@ -368,19 +346,14 @@ def call_begin_forward(
paged_kernel_lens_sum: int,
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)

kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling

Expand All @@ -393,36 +366,21 @@ def call_begin_forward(
kv_indices,
self.req_to_token.shape[1],
)
if not init_metadata_replay:
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
else:
wrapper.plan(
fast_decode_kwargs["qo_indptr_cpu"],
fast_decode_kwargs["kv_indptr_cpu"],
kv_indices,
fast_decode_kwargs["kv_len_arr_cpu"],
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)

wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)


class FlashInferMLAIndicesUpdaterPrefill:
Expand All @@ -442,6 +400,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
Expand Down Expand Up @@ -538,42 +497,3 @@ def call_begin_forward(
self.q_data_type,
self.data_type,
)


def fast_mla_decode_plan(
self,
qo_indptr_cpu: torch.Tensor,
kv_indptr_cpu: torch.Tensor,
kv_indices: torch.Tensor,
kv_len_arr_cpu: torch.Tensor,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,
) -> None:
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
for skipping the stream synchronization in original plan function during
cuda graph replaying.
"""
self._causal = causal
self._page_size = page_size
self._sm_scale = sm_scale

with self.device as device:
stream = torch.cuda.current_stream(device).cuda_stream
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
stream,
)
6 changes: 2 additions & 4 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,9 @@ def init_forward_metadata_capture_cuda_graph(
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
assert encoder_lens is None, "Not supported"

Expand Down Expand Up @@ -309,10 +308,9 @@ def init_forward_metadata_replay_cuda_graph(
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():
Expand Down
Loading
Loading