Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,5 +184,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. When this argument is provided, the `attention_backend` argument is overridden.
* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
2 changes: 1 addition & 1 deletion docs/references/deepseek.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage)
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, Flashinfer MLA can improve performance significantly. Optimized Triton kernels will be used when Flashinfer MLA is turned off.

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/layers/attention/base_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def init_forward_metadata_replay_cuda_graph(
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def init_forward_metadata_replay_cuda_graph(
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
Expand Down Expand Up @@ -1058,6 +1059,7 @@ def call_fn(i, forward_batch):
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)

self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
Expand Down
148 changes: 116 additions & 32 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union

import torch
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(

# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device

global_config.enable_flashinfer_mla = True

Expand All @@ -84,10 +86,6 @@ def __init__(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)

self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)

self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
Expand Down Expand Up @@ -125,6 +123,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrapper=self.decode_wrapper,
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
Expand Down Expand Up @@ -160,13 +159,20 @@ def init_cuda_graph_state(
cuda_graph_kv_indices = kv_indices_buf

self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
self.cuda_graph_kv_lens = torch.ones(
(max_bs,), dtype=torch.int32, device=self.device
)
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
self.cuda_graph_qo_indptr = self.kv_indptr.clone()

# For fast decode plan in graph replaying
self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
self.fast_decode_kwargs = {
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
"kv_indices": self.cuda_graph_kv_indices,
}

def init_forward_metadata_capture_cuda_graph(
self,
Expand All @@ -182,10 +188,10 @@ def init_forward_metadata_capture_cuda_graph(
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.qo_indptr[: num_tokens + 1],
kv_indptr=self.kv_indptr[: num_tokens + 1],
qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.kv_last_page_len[:num_tokens],
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
backend="auto",
)

Expand All @@ -195,9 +201,11 @@ def init_forward_metadata_capture_cuda_graph(
seq_lens,
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")

Expand All @@ -210,13 +218,28 @@ def init_forward_metadata_replay_cuda_graph(
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
)
self.fast_decode_kwargs.update(
{
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
"kv_len_arr_cpu": kv_len_arr_cpu,
}
)

self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
**self.fast_decode_kwargs,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
Expand Down Expand Up @@ -316,7 +339,6 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode

Expand All @@ -326,6 +348,8 @@ def update(
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
self.call_begin_forward(
Expand All @@ -335,6 +359,8 @@ def update(
seq_lens_sum,
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
**fast_decode_kwargs,
)

def call_begin_forward(
Expand All @@ -345,14 +371,19 @@ def call_begin_forward(
paged_kernel_lens_sum: int,
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
)

kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling

Expand All @@ -365,21 +396,36 @@ def call_begin_forward(
kv_indices,
self.req_to_token.shape[1],
)

wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
if not init_metadata_replay:
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
else:
wrapper.plan(
fast_decode_kwargs["qo_indptr_cpu"],
fast_decode_kwargs["kv_indptr_cpu"],
kv_indices,
fast_decode_kwargs["kv_len_arr_cpu"],
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)


class FlashInferMLAIndicesUpdaterPrefill:
Expand All @@ -399,7 +445,6 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
Expand Down Expand Up @@ -496,3 +541,42 @@ def call_begin_forward(
self.q_data_type,
self.data_type,
)


def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,
    kv_indptr_cpu: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_len_arr_cpu: torch.Tensor,
    num_heads: int,
    head_dim_ckv: int,
    head_dim_kpe: int,
    page_size: int,
    causal: bool,
    sm_scale: float,
    q_data_type: torch.dtype,
    kv_data_type: torch.dtype,
) -> None:
    """A faster version of BatchMLAPagedAttentionWrapper::plan,
    for skipping the stream synchronization in original plan function during
    cuda graph replaying.

    Intended to be bound onto a wrapper instance via ``functools.partial``
    (replacing the wrapper's ``plan`` method), so ``self`` is the
    BatchMLAPagedAttentionWrapper, not an object of this module.

    Unlike the original ``plan``, the indptr/len tensors are expected to
    already live on the CPU (the ``*_cpu`` parameters), which is what lets
    this version avoid a device-to-host sync on the replay path.

    NOTE(review): ``kv_indices``, ``head_dim_kpe``, ``q_data_type`` and
    ``kv_data_type`` are accepted (to mirror the original plan signature)
    but not forwarded to ``_cached_module.plan`` below — presumably because
    the kv_indices buffer and dtypes were fixed at wrapper construction
    time for the CUDA-graph case. Confirm against the flashinfer
    ``BatchMLAPagedAttentionWrapper.plan`` implementation.
    """
    # Cache the scalar config on the wrapper, as the original plan does;
    # the attention run() call reads these attributes back.
    self._causal = causal
    self._page_size = page_size
    self._sm_scale = sm_scale

    # Re-plan the kernel metadata on the wrapper's device using the current
    # CUDA stream, skipping the original plan's stream synchronization.
    with self.device as device:
        stream = torch.cuda.current_stream(device).cuda_stream
        self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_cpu,
            kv_indptr_cpu,
            kv_len_arr_cpu,
            num_heads,
            head_dim_ckv,
            causal,
            stream,
        )
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def init_forward_metadata_replay_cuda_graph(
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():
Expand Down Expand Up @@ -587,6 +588,7 @@ def call_fn(i, forward_batch):
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)

self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,8 +1187,13 @@ def merge_batch(self, other: "ScheduleBatch"):

def get_model_worker_batch(self) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
if global_server_args_dict["enable_flashinfer_mla"]:
decode_seq_lens = self.seq_lens.cpu()
else:
decode_seq_lens = None
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
decode_seq_lens = None
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
Expand All @@ -1215,6 +1220,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
decode_seq_lens=decode_seq_lens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
Expand Down Expand Up @@ -1291,6 +1297,9 @@ class ModelWorkerBatch:
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool

# For decode
decode_seq_lens: Optional[torch.Tensor]

# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def __init__(self, model_runner: ModelRunner):
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)

if self.enable_torch_compile:
set_torch_compile_config()
Expand Down Expand Up @@ -448,6 +451,10 @@ def replay(self, forward_batch: ForwardBatch):
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)

if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
Expand All @@ -466,6 +473,7 @@ def replay(self, forward_batch: ForwardBatch):
self.encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
)

# Replay
Expand Down
Loading
Loading