77 changes: 65 additions & 12 deletions python/sglang/srt/layers/attention/flashmla_backend.py
@@ -14,6 +14,7 @@
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -75,6 +76,11 @@ def __init__(
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+        # Check if the KV cache is FP8 (supports both e4m3 and e5m2)
+        self.is_fp8_kvcache = self.data_type in {
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        }

         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

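A quick note on the detection added above: the membership test covers both FP8 formats because they trade off differently, e4m3fn keeping more mantissa (max ±448) and e5m2 keeping more exponent range (max ±57344). A minimal, self-contained sketch of the same check (the helper name is illustrative, not from this PR):

```python
import torch

# The two FP8 formats an FP8 KV cache may use. Per torch.finfo:
#   float8_e4m3fn: max ~448   (more precision, less range)
#   float8_e5m2:   max ~57344 (more range, less precision)
FP8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

def is_fp8_kvcache(kv_cache_dtype: torch.dtype) -> bool:
    """Mirrors the membership check added in __init__ above."""
    return kv_cache_dtype in FP8_DTYPES

for dt in sorted(FP8_DTYPES, key=str):
    print(dt, "max =", torch.finfo(dt).max)
assert is_fp8_kvcache(torch.float8_e4m3fn) and not is_fp8_kvcache(torch.bfloat16)
```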
@@ -104,6 +110,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 forward_batch.seq_lens.to(torch.int32),
                 self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,
@@ -134,6 +141,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )

             # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -168,6 +176,7 @@ def init_cuda_graph_state(
                 ),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
         else:
             self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
@@ -176,6 +185,7 @@
                 ),
                 self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -206,6 +216,7 @@ def init_forward_metadata_capture_cuda_graph(
                 seq_lens.to(torch.int32),
                 num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -231,6 +242,7 @@ def init_forward_metadata_capture_cuda_graph(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -281,6 +293,7 @@ def init_forward_metadata_replay_cuda_graph(
                 seq_lens.to(torch.int32),
                 num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -306,6 +319,7 @@ def init_forward_metadata_replay_cuda_graph(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
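Every `get_mla_metadata` call site above receives the same one-line change, so the flag is threaded identically through decode, target-verify, CUDA-graph capture, and CUDA-graph replay. As a hedged sketch of what one call looks like after this PR (the import and return signature follow FlashMLA's public API; `build_decode_metadata` is an illustrative wrapper, not code from this diff):

```python
import torch
from flash_mla import get_mla_metadata  # FlashMLA's tile-scheduling helper

def build_decode_metadata(seq_lens: torch.Tensor, num_q_heads: int, fp8: bool):
    # Returns (tile_scheduler_metadata, num_splits), which the backend
    # caches in FlashMLADecodeMetadata for subsequent kernel launches.
    mla_metadata, num_splits = get_mla_metadata(
        seq_lens.to(torch.int32),  # per-sequence KV lengths
        num_q_heads,               # query heads per KV head
        1,                         # num_heads_k: MLA uses one latent KV head
        is_fp8_kvcache=fp8,        # the flag this PR threads through
    )
    return mla_metadata, num_splits
```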
@@ -353,8 +367,28 @@ def forward_decode(
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-        if self.data_type == torch.float8_e4m3fn:
-            reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+        if self.is_fp8_kvcache:
+            # For an FP8 KV cache, Q must be quantized to FP8 for the FlashMLA kernel.
+            # In SGLang, layer.k_scale is used for both the Q and K scales.
+            if layer.k_scale is not None:
+                q_scale = layer.k_scale
+                descale_q = layer.k_scale.reshape(1)
+                descale_k = layer.k_scale.reshape(1)
+            else:
+                # Fall back to a scale of 1.0 if k_scale is not initialized.
+                q_scale = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)
+                descale_q = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+                descale_k = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+
+            # Reshape to 2D for scaled_fp8_quant, which requires a 2D input.
+            q_shape = reshape_q.shape
+            reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+            reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+            reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q_fp8,
                 k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -365,8 +399,8 @@
                 num_splits=self.forward_metadata.num_splits,
                 softmax_scale=layer.scaling,
                 causal=True,
-                descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
-                descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+                descale_q=descale_q,
+                descale_k=descale_k,
             )

             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
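The scale/descale pairing in `forward_decode` is worth spelling out: the scale passed to `scaled_fp8_quant` is a dequantization scale, so quantization divides by it before the FP8 cast, and the kernel multiplies by `descale_q`/`descale_k` on the way back. Reusing `layer.k_scale` on both sides keeps the round trip consistent. A self-contained per-tensor sketch of that arithmetic (plain PyTorch, not the fused kernel; `quant_fp8`/`dequant_fp8` are illustrative names):

```python
import torch

def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the (dequantization) scale, clamp to the representable
    # range, then cast to FP8, as scaled_fp8_quant does with a static scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

def dequant_fp8(x_fp8: torch.Tensor, descale: torch.Tensor) -> torch.Tensor:
    # The kernel applies `descale` internally; it is the same scalar as `scale`.
    return x_fp8.to(torch.float32) * descale

q = torch.randn(4, 576)       # toy Q rows (576 = 512 lora + 64 rope dims)
scale = torch.tensor([0.05])  # stand-in for layer.k_scale
q_roundtrip = dequant_fp8(quant_fp8(q, scale), scale)
print((q - q_roundtrip).abs().max())  # small, bounded quantization error
```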
@@ -412,8 +446,31 @@ def forward_extend(
             k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

             reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-            if self.data_type == torch.float8_e4m3fn:
-                reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+            if self.is_fp8_kvcache:
+                # For an FP8 KV cache, Q must be quantized to FP8 for the FlashMLA kernel.
+                # In SGLang, layer.k_scale is used for both the Q and K scales.
+                if layer.k_scale is not None:
+                    q_scale = layer.k_scale
+                    descale_q = layer.k_scale.reshape(1)
+                    descale_k = layer.k_scale.reshape(1)
+                else:
+                    # Fall back to a scale of 1.0 if k_scale is not initialized.
+                    q_scale = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+                    descale_q = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+                    descale_k = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+
+                # Quantize Q with scaled_fp8_quant (matching vLLM's approach);
+                # reshape to 2D first, since scaled_fp8_quant requires a 2D input.
+                q_shape = reshape_q.shape
+                reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+                reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+                reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
                 o, _ = flash_mla_with_kvcache(
                     q=reshape_q_fp8,
                     k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -425,12 +482,8 @@
                     num_splits=self.forward_metadata.num_splits,
                     softmax_scale=layer.scaling,
                     causal=True,
-                    descale_q=torch.ones(
-                        (1), dtype=torch.float32, device=reshape_q.device
-                    ),
-                    descale_k=torch.ones(
-                        (1), dtype=torch.float32, device=reshape_q.device
-                    ),
+                    descale_q=descale_q,
+                    descale_k=descale_k,
                 )
             else:
                 o, _ = flash_mla_with_kvcache(
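One review-style observation: the quantization block added to `forward_extend` is a near-verbatim copy of the one in `forward_decode`. A hypothetical follow-up (not part of this diff) could factor it into a shared helper; `_quantize_q_for_fp8` and its signature are illustrative only:

```python
def _quantize_q_for_fp8(self, reshape_q: torch.Tensor, k_scale):
    """Hypothetical shared helper; returns (q_fp8, descale_q, descale_k)."""
    if k_scale is not None:
        q_scale = k_scale
        descale = k_scale.reshape(1)
    else:
        # Fall back to a scale of 1.0 if k_scale is not initialized.
        q_scale = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)
        descale = q_scale
    # scaled_fp8_quant requires a 2D input, so flatten and restore the shape.
    q_shape = reshape_q.shape
    q_fp8_2d, _ = scaled_fp8_quant(reshape_q.reshape(-1, q_shape[-1]), q_scale)
    return q_fp8_2d.reshape(q_shape), descale, descale
```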