Merged
Commits
35 commits
98afb69
rewrite attn forward dispatch
Fridge003 Apr 4, 2025
0cbf502
prefix lengths chunking logic
Fridge003 Apr 5, 2025
c09bc8a
small rename
Fridge003 Apr 5, 2025
746fbc2
use flash_attn_varlen_func at fa3 backend
Fridge003 Apr 6, 2025
5761087
attn mha with chunk prefix forward
Fridge003 Apr 7, 2025
94dc646
latent cache gathering
Fridge003 Apr 7, 2025
f3bb2c8
add test for forward batch chunking and debug
Fridge003 Apr 8, 2025
146e7e9
add kv indices test and fix bugs
Fridge003 Apr 8, 2025
db3f16d
fix bug in mha forward
Fridge003 Apr 8, 2025
43826ce
small fix
Fridge003 Apr 9, 2025
96b8ee4
fix accuracy
ispobock Apr 10, 2025
d0bf0e8
Merge branch 'main' into chunk
Fridge003 Apr 10, 2025
2756483
Merge branch 'main' into chunk
Fridge003 Apr 11, 2025
1ef2041
fix performance bug
Fridge003 Apr 11, 2025
a344f9e
add server arg for this feature
Fridge003 Apr 11, 2025
a16b5d8
Merge branch 'main' into chunk
Fridge003 Apr 11, 2025
b125b83
enable by default
Fridge003 Apr 11, 2025
961a588
typo
Fridge003 Apr 11, 2025
420bb77
Merge branch 'main' into chunk
zhyncs Apr 11, 2025
fee6f77
add flash_attn_varlen_func to sgl-kernel and add docs
Fridge003 Apr 12, 2025
7b39dad
add test for fa3+spec+mla
Fridge003 Apr 12, 2025
49915b3
revert kernel modification
Fridge003 Apr 12, 2025
be9496c
Merge branch 'main' into chunk
Fridge003 Apr 12, 2025
74ad626
Merge branch 'main' into chunk
Fridge003 Apr 12, 2025
a79569e
Merge branch 'main' into chunk
Fridge003 Apr 13, 2025
aab2bfe
Merge branch 'main' into chunk
Fridge003 Apr 15, 2025
2462ecc
Merge branch 'main' into chunk
Fridge003 Apr 15, 2025
b8c89c9
update merge state
Fridge003 Apr 15, 2025
df02a9d
Merge branch 'main' into chunk
Fridge003 Apr 15, 2025
0f0faf2
shift to merge state v2
Fridge003 Apr 15, 2025
2ddf341
Merge branch 'main' into chunk
zhyncs Apr 15, 2025
0204bed
move merge_state_v2 to top
Fridge003 Apr 16, 2025
53f2c51
set threshold for triggering
Fridge003 Apr 16, 2025
cfc4733
fix bug
Fridge003 Apr 16, 2025
90c778e
Merge branch 'main' into chunk
zhyncs Apr 16, 2025
1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
@@ -195,3 +195,4 @@ Please consult the documentation below to learn more about the parameters you may
* `triton_attention_num_kv_splits`: Adjusts the number of KV splits in Triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashInfer MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of the chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend.
4 changes: 3 additions & 1 deletion docs/references/deepseek.md
@@ -92,13 +92,15 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes.

- **Chunked Prefix Cache**: The chunked prefix cache optimization can increase throughput by cutting the prefix cache into chunks, processing them with multi-head attention, and merging their states (see the sketch below). The improvement can be significant when doing chunked prefill on long sequences. Currently this optimization is only available for the FlashAttention3 backend.

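The merging step combines partial attention results using their log-sum-exp (LSE) statistics. Below is a minimal PyTorch sketch of that idea, assuming each chunk's attention pass returns a partial output of shape `(tokens, heads, head_dim)` and a per-token LSE of shape `(tokens, heads)`; the backend itself relies on fused merge-state kernels (see `merge_state_v2` in the commit history) rather than this Python path.

```python
import torch

def merge_attn_states(o_a, lse_a, o_b, lse_b):
    # o_a, o_b:     (tokens, heads, head_dim) partial attention outputs over disjoint KV chunks
    # lse_a, lse_b: (tokens, heads) log-sum-exp of the attention scores for each chunk
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of the first partial result
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of the second partial result
    return w_a * o_a + w_b * o_b, lse
```
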
Overall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.

<p align="center">
<img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
</p>

**Usage**: MLA optimization is enabled by default, to disable, use `--disable-mla`.
**Usage**: MLA optimization is enabled by default. To disable it, use `--disable-mla`. To disable the chunked prefix cache feature for MLA, use `--disable-chunked-prefix-cache`.

**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.

114 changes: 80 additions & 34 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -16,7 +16,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass
@@ -593,41 +593,87 @@ def forward_extend(
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
and forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache

if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None

chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0

output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=True,
)
return output, lse
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

def forward_decode(
self,
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -83,6 +83,7 @@
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}

logger = logging.getLogger(__name__)
181 changes: 181 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -181,6 +181,28 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

# For MLA chunked prefix cache used in chunked prefill
# Tell the attention backend whether the prefix kv cache needs to be attended in the current pass
attn_attend_prefix_cache: Optional[bool] = None
# Number of prefix cache chunks
num_prefix_chunks: Optional[int] = None
# Index of current chunk, used by attention backend
prefix_chunk_idx: Optional[int] = None
# Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
prefix_chunk_len: Optional[int] = None
# Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_starts: Optional[torch.Tensor] = None
# Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_seq_lens: Optional[torch.Tensor] = None
# Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
# Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
prefix_chunk_max_seq_lens: Optional[List[int]] = None
# Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None

# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None

@@ -484,6 +506,128 @@ def _compute_mrope_positions(
)
self.mrope_positions = self.mrope_positions.to(torch.int64)

def get_max_chunk_capacity(self):
# Maximum number of tokens in each chunk
# TODO: Should be changed to a better value, maybe passed through server args
return 128 * 1024

def set_prefix_chunk_idx(self, idx: int):
self.prefix_chunk_idx = idx

def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
self.attn_attend_prefix_cache = attn_attend_prefix_cache

def prepare_chunked_kv_indices(self, device: torch.device):
self.prefix_chunk_kv_indices = []
for idx in range(self.num_prefix_chunks):
chunk_starts = self.prefix_chunk_starts[idx]
chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
num_chunk_tokens = self.prefix_chunk_num_tokens[idx]

chunk_kv_indices = torch.empty(
num_chunk_tokens, dtype=torch.int32, device=device
)

create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)

# Here we assume all chunks have equal length.
# For example, if we have 4 sequences with prefix lengths [256, 512, 768, 1024], then prefix_chunk_len = 256 and
# num_prefix_chunks = cdiv(1024, 256) = 4
# prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
# prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
# prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
# TODO: Implement a better way to allocate chunk lengths that uses memory space more efficiently.
def get_prefix_chunk_seq_lens(
self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
):
device = prefix_lens.device
prefix_chunk_starts = (
torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, self.batch_size)
* prefix_chunk_len
)
prefix_chunk_ends = torch.min(
prefix_lens.unsqueeze(0),
prefix_chunk_starts + prefix_chunk_len,
).to(torch.int32)

prefix_chunk_seq_lens = (
(prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
)

return prefix_chunk_starts, prefix_chunk_seq_lens

# Called before each attention module if using chunked kv cache for prefill
# Some of the code is adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
def prepare_chunked_prefix_cache_info(self, device: torch.device):

from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

assert isinstance(
self.token_to_kv_pool, MLATokenToKVPool
), "Currently chunked prefix cache can only be used by Deepseek models"

if self.prefix_chunk_len is not None:
# Chunked kv cache info already prepared by prior modules
return

self.prefix_chunk_idx = -1

# chunk_capacity is the maximum number of tokens in each chunk
chunk_capacity = self.get_max_chunk_capacity()
self.prefix_chunk_len = chunk_capacity // self.batch_size

self.num_prefix_chunks = (
max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
) // self.prefix_chunk_len

# Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
self.get_prefix_chunk_seq_lens(
self.extend_prefix_lens,
self.num_prefix_chunks,
self.prefix_chunk_len,
)
)
_, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
torch.tensor(self.extend_prefix_lens_cpu),
self.num_prefix_chunks,
self.prefix_chunk_len,
)
self.prefix_chunk_starts = prefix_chunk_starts_cuda
self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda

# Metadata for attention backend
self.prefix_chunk_cu_seq_lens = torch.zeros(
self.num_prefix_chunks,
self.batch_size + 1,
device=device,
dtype=torch.int32,
)
self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
dim=1
).to(torch.int32)
self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
dim=1
).values.tolist()

self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()

# Precompute the kv indices for each chunk
self.prepare_chunked_kv_indices(device)
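
# Illustrative-only sketch (not part of this diff): one way a caller could drive
# the chunk metadata prepared above. `attn_forward` and `merge_states` are
# hypothetical placeholders for the backend's extend-attention call and its
# merge-state kernel; they are not real APIs from this repository.
def run_chunked_prefix_attention(forward_batch, q, attn_forward, merge_states, device):
    forward_batch.prepare_chunked_prefix_cache_info(device)

    # 1) Extend tokens attend to themselves (causal), without the prefix cache.
    forward_batch.set_attn_attend_prefix_cache(False)
    o_acc, lse_acc = attn_forward(q, forward_batch)

    # 2) The same queries attend to each prefix-cache chunk (non-causal);
    #    each partial result is folded into the accumulator.
    forward_batch.set_attn_attend_prefix_cache(True)
    for idx in range(forward_batch.num_prefix_chunks):
        forward_batch.set_prefix_chunk_idx(idx)
        o_chunk, lse_chunk = attn_forward(q, forward_batch)
        o_acc, lse_acc = merge_states(o_acc, lse_acc, o_chunk, lse_chunk)
    return o_acc, lse_acc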


def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)


@triton.jit
def create_chunked_prefix_cache_kv_indices(
req_to_token_ptr, # (max_batch, max_context_len,)
req_pool_indices_ptr, # (batch_size,)
chunk_start_idx_ptr, # (batch_size,)
chunk_seq_lens_ptr, # (batch_size,)
chunk_cu_seq_lens_ptr, # (batch_size + 1,)
chunk_kv_indices_ptr, # (num_chunk_tokens,)
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)

# Find the request pool index, which maps this batch slot to its row in req_to_token
req_pool_index = tl.load(req_pool_indices_ptr + pid)
chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)

# get the token positions of current chunk
chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)

num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < chunk_seq_len
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ chunk_start_pos
+ offset,
mask=mask,
)
tl.store(
chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
)
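
For readability, here is a pure-PyTorch sketch of the gather that the Triton kernel above performs; the function name and the per-request Python loop are illustrative only, while the kernel is what the diff actually ships.

```python
import torch

def chunked_prefix_cache_kv_indices_ref(
    req_to_token: torch.Tensor,       # (max_batch, max_context_len)
    req_pool_indices: torch.Tensor,   # (batch_size,)
    chunk_starts: torch.Tensor,       # (batch_size,)
    chunk_seq_lens: torch.Tensor,     # (batch_size,)
    chunk_cu_seq_lens: torch.Tensor,  # (batch_size + 1,)
) -> torch.Tensor:
    # Gather, for every request, the kv-cache slot indices of its current prefix
    # chunk and pack them contiguously according to the cumulative chunk lengths.
    num_tokens = int(chunk_cu_seq_lens[-1].item())
    kv_indices = torch.empty(num_tokens, dtype=torch.int32, device=req_to_token.device)
    for i in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[i].item())
        start = int(chunk_starts[i].item())
        seq_len = int(chunk_seq_lens[i].item())
        out_off = int(chunk_cu_seq_lens[i].item())
        kv_indices[out_off : out_off + seq_len] = req_to_token[pool_idx, start : start + seq_len]
    return kv_indices
```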
11 changes: 11 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -167,6 +167,7 @@ def __init__(
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
)
@@ -318,6 +319,16 @@ def model_specific_adjustment(self):
if server_args.enable_deepep_moe:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True

if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.")

def init_torch_distributed(self):
logger.info("Init torch distributed begin.")

Expand Down