diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index a71153a6203..48e08d8478b 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -8,6 +8,7 @@ | **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ | | **Triton** | ❌ | ✅ | ✅ | ❌ | ❌ | | **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | ## User guide @@ -30,10 +31,15 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --trust-r ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend triton python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend triton --trust-remote-code - ``` - Torch Native ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend torch_native ``` + +- FlashMLA +```bash +python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code +python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code +``` diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 612885bc56b..6f0d9afd223 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -158,7 +158,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8 ``` - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. -- FlashAttention3 and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the FlashMLA backend and CutlassMLA backend is still under development. +- The FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development. - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value. - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
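As a companion to the FlashMLA launch commands added in the docs hunk above, here is a minimal client sketch for sanity-checking a running server. It assumes sglang's default port (30000) and its OpenAI-compatible `/v1/chat/completions` route; the URL, model name, and prompt are illustrative and should be adjusted to match your deployment.

```python
# Minimal smoke test against a server launched with one of the commands above, e.g.
#   python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 \
#       --attention-backend flashmla --trust-remote-code
# Assumes the default port 30000 and the OpenAI-compatible route; adjust as needed.
import requests

BASE_URL = "http://127.0.0.1:30000"  # change if the server was started with --port

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 32,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```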
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index cd777841807..a6a255b3bb6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -346,7 +346,6 @@ def forward_extend( cache_loc = forward_batch.out_cache_loc logits_soft_cap = layer.logit_cap prefill_wrapper_paged = self.forward_metadata.prefill_wrapper - k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) # Save kv cache if save_kv_cache and k is not None: @@ -381,6 +380,9 @@ def forward_extend( ) else: # mla paged prefill + k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to( + q.dtype + ) if q_rope is None: qall = q.view(-1, layer.tp_q_head_num, layer.head_dim) q, q_rope = ( @@ -442,7 +444,9 @@ def forward_decode( q_nope = reshaped_q[:, :, : layer.v_head_dim] q_rope = reshaped_q[:, :, layer.v_head_dim :] - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to( + q.dtype + ) o = q_nope.new_empty(q_nope.shape) # Direct call to run without the wrapper @@ -467,7 +471,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.scaling = model_runner.model_config.scaling - self.data_type = model_runner.kv_cache_dtype + self.data_type = model_runner.dtype self.attn_backend = attn_backend # Buffers and wrappers @@ -577,7 +581,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.v_head_dim = model_runner.model_config.v_head_dim self.scaling = model_runner.model_config.scaling - self.data_type = model_runner.kv_cache_dtype + self.data_type = model_runner.dtype self.q_data_type = model_runner.dtype self.attn_backend = attn_backend diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index 0823239a71f..1198ddda464 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -8,7 +8,7 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union import torch import triton @@ -30,8 +30,8 @@ # FlashMLA only supports pagesize=64 PAGE_SIZE = 64 -# TODO The current setup is hard-coded and will be changed after integrating with MTP. 
-Q_LEN = 1 + +# FlashMLA FP8 issue: https://github.com/deepseek-ai/FlashMLA/issues/56 @dataclass @@ -52,7 +52,7 @@ def __init__( class FlashMLABackend(FlashInferMLAAttnBackend): - """Flashinfer attention kernels.""" + """Flashmla attention kernels.""" def __init__( self, @@ -82,42 +82,72 @@ def __init__( self.q_data_type = model_runner.dtype self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + def init_forward_metadata(self, forward_batch: ForwardBatch): bs = forward_batch.batch_size - spec_info = forward_batch.spec_info if forward_batch.forward_mode.is_decode_or_idle(): - if spec_info is None: - max_seqlen_pad = triton.cdiv( - forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE - ) - block_kv_indices = torch.full( - (bs, max_seqlen_pad), - -1, - dtype=torch.int32, - device=forward_batch.seq_lens.device, - ) - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - None, - block_kv_indices, - self.req_to_token.stride(0), - max_seqlen_pad, - ) - mla_metadata, num_splits = get_mla_metadata( - forward_batch.seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads, - 1, - ) - self.forward_metadata = FlashMLADecodeMetadata( - mla_metadata, - num_splits, - block_kv_indices, - ) - else: - super().init_forward_metadata(forward_batch) + max_seqlen_pad = triton.cdiv( + forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE + ) + block_kv_indices = torch.full( + (bs, max_seqlen_pad), + -1, + dtype=torch.int32, + device=forward_batch.seq_lens.device, + ) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + ) + mla_metadata, num_splits = get_mla_metadata( + forward_batch.seq_lens.to(torch.int32), + self.num_q_heads, + 1, + ) + self.forward_metadata = FlashMLADecodeMetadata( + mla_metadata, + num_splits, + block_kv_indices, + ) + elif forward_batch.forward_mode.is_target_verify(): + seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens + seq_lens = forward_batch.seq_lens + self.num_draft_tokens + + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) + block_kv_indices = torch.full( + (bs, max_seqlen_pad), + -1, + dtype=torch.int32, + device=seq_lens.device, + ) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + self.num_draft_tokens * self.num_q_heads, + 1, + ) + + # Use FlashMLADecodeMetadata which has the attributes forward_extend expects + self.forward_metadata = FlashMLADecodeMetadata( + mla_metadata, + num_splits, + block_kv_indices, + ) else: super().init_forward_metadata(forward_batch) @@ -136,11 +166,22 @@ def init_cuda_graph_state( else: cuda_graph_kv_indices = block_kv_indices - self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata( - torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device), - Q_LEN * self.num_q_heads, - 1, - ) + if self.num_draft_tokens: + self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata( + torch.ones( + max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device + ), + self.num_draft_tokens * self.num_q_heads, + 1, + ) + else: + self.cuda_graph_mla_metadata, 
self.cuda_graph_num_splits = get_mla_metadata( + torch.ones( + max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device + ), + self.num_q_heads, + 1, + ) self.cuda_graph_kv_indices = cuda_graph_kv_indices def init_forward_metadata_capture_cuda_graph( @@ -154,31 +195,54 @@ def init_forward_metadata_capture_cuda_graph( spec_info: Optional[SpecInfo], ): if forward_mode.is_decode_or_idle(): - if spec_info is None: - max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) - - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - None, - self.cuda_graph_kv_indices, - self.req_to_token.stride(0), - self.cuda_graph_kv_indices.stride(0), - ) - mla_metadata, num_splits = get_mla_metadata( - seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads, - 1, - ) - self.cuda_graph_mla_metadata.copy_(mla_metadata) - self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) - self.forward_metadata = FlashMLADecodeMetadata( - self.cuda_graph_mla_metadata, - self.cuda_graph_num_splits[: bs + 1], - self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], - ) + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + self.num_q_heads, + 1, + ) + self.cuda_graph_mla_metadata.copy_(mla_metadata) + self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) + self.forward_metadata = FlashMLADecodeMetadata( + self.cuda_graph_mla_metadata, + self.cuda_graph_num_splits[: bs + 1], + self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], + ) + elif forward_mode.is_target_verify(): + seq_lens = seq_lens + self.num_draft_tokens + max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + self.cuda_graph_kv_indices.stride(0), + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + self.num_draft_tokens * self.num_q_heads, + 1, + ) + self.cuda_graph_mla_metadata.copy_(mla_metadata) + self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) + self.forward_metadata = FlashMLADecodeMetadata( + self.cuda_graph_mla_metadata, + self.cuda_graph_num_splits[: bs + 1], + self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], + ) else: super().init_forward_metadata_capture_cuda_graph( bs, @@ -218,7 +282,32 @@ def init_forward_metadata_replay_cuda_graph( ) mla_metadata, num_splits = get_mla_metadata( seq_lens.to(torch.int32), - Q_LEN * self.num_q_heads, + self.num_q_heads, + 1, + ) + self.cuda_graph_mla_metadata.copy_(mla_metadata) + self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) + self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata + self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1] + self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[ + :bs, :max_seqlen_pad + ] + elif forward_mode.is_target_verify(): + seq_lens = seq_lens[:bs] + self.num_draft_tokens + seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens + max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE) + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens, + None, + self.cuda_graph_kv_indices, + self.req_to_token.stride(0), + 
self.cuda_graph_kv_indices.stride(0), + ) + mla_metadata, num_splits = get_mla_metadata( + seq_lens.to(torch.int32), + self.num_draft_tokens * self.num_q_heads, 1, ) self.cuda_graph_mla_metadata.copy_(mla_metadata) @@ -228,7 +317,6 @@ def init_forward_metadata_replay_cuda_graph( self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[ :bs, :max_seqlen_pad ] - else: super().init_forward_metadata_replay_cuda_graph( bs, @@ -268,17 +356,191 @@ def forward_decode( k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) + if self.data_type == torch.float8_e4m3fn: + reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn) + o, _ = flash_mla_with_kvcache( + q=reshape_q_fp8, + k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), + block_table=self.forward_metadata.block_kv_indices[:bs], + cache_seqlens=forward_batch.seq_lens.to(torch.int32), + head_dim_v=self.kv_lora_rank, # TODO Retrieve from config. + tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, + num_splits=self.forward_metadata.num_splits, + softmax_scale=layer.scaling, + causal=True, + descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device), + descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device), + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + else: + # todo: need check all causal True or False? + o, _ = flash_mla_with_kvcache( + q=reshape_q, + k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), + block_table=self.forward_metadata.block_kv_indices[:bs], + cache_seqlens=forward_batch.seq_lens.to(torch.int32), + head_dim_v=self.kv_lora_rank, # TODO Retrieve from config. + tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, + num_splits=self.forward_metadata.num_splits, + softmax_scale=layer.scaling, + causal=True, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + if ( + forward_batch.forward_mode == ForwardMode.EXTEND + or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND + ): + return super().forward_extend(q, k, v, layer, forward_batch, save_kv_cache) + else: + cache_loc = forward_batch.out_cache_loc + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + + bs = forward_batch.batch_size + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) + if self.data_type == torch.float8_e4m3fn: + reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn) + o, _ = flash_mla_with_kvcache( + q=reshape_q_fp8, + k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), + block_table=self.forward_metadata.block_kv_indices[:bs], + cache_seqlens=forward_batch.seq_lens.to(torch.int32) + + self.num_draft_tokens, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, + num_splits=self.forward_metadata.num_splits, + softmax_scale=layer.scaling, + causal=True, + descale_q=torch.ones( + (1), dtype=torch.float32, device=reshape_q.device + ), + descale_k=torch.ones( + (1), dtype=torch.float32, device=reshape_q.device + ), + ) + else: + o, _ = flash_mla_with_kvcache( + q=reshape_q, + k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), + 
block_table=self.forward_metadata.block_kv_indices[:bs], + cache_seqlens=forward_batch.seq_lens.to(torch.int32) + + self.num_draft_tokens, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, + num_splits=self.forward_metadata.num_splits, + softmax_scale=layer.scaling, + causal=True, + ) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + - o, _ = flash_mla_with_kvcache( - q=reshape_q, - k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), - block_table=self.forward_metadata.block_kv_indices, - cache_seqlens=forward_batch.seq_lens.to(torch.int32), - head_dim_v=self.kv_lora_rank, # TODO Retrieve from config. - tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, - num_splits=self.forward_metadata.num_splits, - softmax_scale=layer.scaling, - causal=False, +# TODO: multi step kv indices optimization +class FlashMLAMultiStepDraftBackend: + """ + Wrap multiple flashmla attention backends as one for multiple consecutive + draft decoding steps. + """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + if topk > 1: + raise ValueError( + f"Currently FlashMLA only supports topk=1 for speculative decoding" + ) + self.topk = topk + self.speculative_num_steps = speculative_num_steps + max_bs = model_runner.req_to_token_pool.size * self.topk + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, ) - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashMLABackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + kv_last_page_len_buf=None, + ) + ) + + def common_template( + self, + forward_batch: ForwardBatch, + call_fn: Callable, + ): + assert forward_batch.spec_info is not None + + for i in range(self.speculative_num_steps - 1): + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + assert forward_batch.spec_info is not None + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, call_fn) + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + bs, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.seq_lens_cpu, + ) + + self.common_template(forward_batch, call_fn) diff --git a/python/sglang/srt/layers/attention/utils.py 
b/python/sglang/srt/layers/attention/utils.py index 816fbf08af4..71633d12dce 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -77,8 +77,8 @@ def create_flashmla_kv_indices_triton( ) * PAGED_SIZE paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK - mask = paged_offset <= num_paged * PAGED_SIZE - mask_out = paged_offset_out <= num_paged + mask = paged_offset < num_paged * PAGED_SIZE + mask_out = paged_offset_out < num_paged data = tl.load( req_to_token_ptr diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 025c75392be..e88022beb97 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -30,6 +30,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native from sglang.srt.layers.torchao_utils import save_gemlite_cache +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -210,7 +211,10 @@ def __init__(self, model_runner: ModelRunner): # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs - self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) + if global_server_args_dict["attention_backend"] == "flashmla": + self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) + else: + self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7ea48102df5..08904dbfca4 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -199,6 +199,19 @@ def init_attention_backend(self): self.draft_extend_attn_backend = None self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False + elif self.server_args.attention_backend == "flashmla": + from sglang.srt.layers.attention.flashmla_backend import ( + FlashMLAMultiStepDraftBackend, + ) + + self.draft_attn_backend = FlashMLAMultiStepDraftBackend( + self.draft_model_runner, + self.topk, + self.speculative_num_steps, + ) + self.draft_extend_attn_backend = None + self.padded_static_len = self.speculative_num_steps + 1 + self.has_prefill_wrapper_verify = False else: raise ValueError( f"EAGLE is not supported in attention backend {self.server_args.attention_backend}" diff --git a/test/srt/test_flashmla.py b/test/srt/test_flashmla.py index b8076246513..bc17b311903 100644 --- a/test/srt/test_flashmla.py +++ b/test/srt/test_flashmla.py @@ -6,6 +6,7 @@ import unittest from types import SimpleNamespace +import requests import torch from sglang.srt.utils import kill_process_tree @@ -14,6 +15,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + CustomTestCase, is_in_ci, popen_launch_server, run_bench_one_batch, @@ -81,5 +83,71 @@ def test_latency(self): self.assertGreater(output_throughput, 100) +class TestFlashMLAMTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmsys/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if 
torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--cuda-graph-max-bs", + "4", + "--disable-radix", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + "lmsys/sglang-ci-dsv3-test-NextN", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--attention-backend", + "flashmla", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + print(f"{server_info=}") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 1.8) + + if __name__ == "__main__": unittest.main()
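A note on the `utils.py` hunk above: `create_flashmla_kv_indices_triton` writes one page-table entry per KV page, and with 0-based page offsets the valid range is `[0, num_paged)`, so the bound must be `<` rather than `<=`. The NumPy sketch below restates that per-request gather to show the off-by-one; the concrete numbers (`seq_len = 150`, the fake `req_to_token` row, the padded width) are made up for illustration and are not taken from the kernel.

```python
# Illustrative restatement of the per-request page-index gather performed by
# create_flashmla_kv_indices_triton. A sketch of the mask fix, not the real kernel.
import numpy as np

PAGE_SIZE = 64                              # FlashMLA only supports page size 64
seq_len = 150                               # tokens cached for this request (made up)
num_paged = -(-seq_len // PAGE_SIZE)        # ceil(150 / 64) = 3 pages in use
max_pages = 8                               # padded width of the block_kv_indices row

# Fake req_to_token row: token i lives in global KV slot 4096 + i.
req_to_token = np.arange(4096, 4096 + seq_len)

paged_offset = np.arange(max_pages) * PAGE_SIZE   # 0, 64, 128, 192, ...
mask = paged_offset < num_paged * PAGE_SIZE       # fixed bound: exactly num_paged entries
# The old bound `paged_offset <= num_paged * PAGE_SIZE` also passes offset 192 here,
# i.e. one out-of-range load/store per request.

block_kv_indices = np.full(max_pages, -1, dtype=np.int32)
block_kv_indices[mask] = req_to_token[paged_offset[mask]] // PAGE_SIZE

print(block_kv_indices)  # -> [64 65 66 -1 -1 -1 -1 -1]
```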