Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion vllm/v1/attention/backends/mla/triton_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MLACommonMetadata,
)
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import triton
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionLayer,
Expand Down Expand Up @@ -115,6 +116,8 @@ def __init__(
if is_quantized_kv_cache(self.kv_cache_dtype):
self.supports_quant_query_input = False

self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count

def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
Expand Down Expand Up @@ -149,7 +152,24 @@ def forward_mqa(
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)

# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
if envs.VLLM_BATCH_INVARIANT:
num_kv_splits = 1
else:
# Minimum amount of work (in sequence positions) per KV split.
# The best value is hardware dependent.
min_work_per_split = 512

ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)

# Round up to a power of 2 to avoid excessive kernel instantiations.
ideal_splits = triton.next_power_of_2(ideal_splits)

# Cap the number of splits at the SM count times an occupancy multiplier.
# A multiplier of 2-4x allows multiple blocks per SM for latency hiding;
# the best value is hardware dependent.
occupancy_multiplier = 2
max_splits = self._sm_count * occupancy_multiplier
num_kv_splits = min(ideal_splits, max_splits)

# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
Expand Down Expand Up @@ -186,6 +206,7 @@ def forward_mqa(
PAGE_SIZE,
k_scale=layer._k_scale,
v_scale=layer._k_scale,
is_mla=True,
)

return o, lse
72 changes: 47 additions & 25 deletions vllm/v1/attention/ops/triton_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
IS_MLA: tl.constexpr = False,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
Expand All @@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
cur_batch_req_idx = cur_batch

offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
q = tl.load(
Q + offs_q,
mask=(mask_h[:, None]) & (mask_d[None, :]),
other=0.0,
cache_modifier=".ca",
)

if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
Expand All @@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
Q + off_qpe,
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
other=0.0,
cache_modifier=".ca",
)

kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
Expand All @@ -331,41 +340,44 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

if split_kv_end > split_kv_start:
base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
if BLOCK_DPE > 0:
base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]

ks = tl.load(k_scale)
vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
+ stride_req_to_tokens_b * cur_batch_req_idx
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
cache_modifier=".ca",
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)

# explicitly facilitate overlapping load/compute
offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
cache_modifier=".cg",
)

if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
cache_modifier=".cg",
)
if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
Expand All @@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)

offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
if not IS_MLA:
offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
if v.dtype.is_fp8():
v = (v.to(tl.float32) * vs).to(q.dtype)
else:
# MLA uses a single c_kv, so loading the same c_kv a second time
# just to interpret it as v is not necessary.
# Instead, transpose the already-loaded c_kv (aka k) for the dot product.
v = tl.trans(k)

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
Expand Down Expand Up @@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
logit_cap,
k_scale,
v_scale,
is_mla=False,
):
# With is_mla there is only a single c_kv in smem,
# so BLOCK or num_stages could potentially be increased.
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
Expand Down Expand Up @@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
IS_MLA=is_mla,
**extra_kargs,
)

Expand Down Expand Up @@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
logit_cap=0.0,
k_scale=None,
v_scale=None,
is_mla=False,
):
_decode_grouped_att_m_fwd(
q,
Expand All @@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
logit_cap,
k_scale,
v_scale,
is_mla=is_mla,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
Expand All @@ -708,6 +728,7 @@ def decode_attention_fwd(
logit_cap=0.0,
k_scale=None,
v_scale=None,
is_mla=False,
):
assert num_kv_splits == attn_logits.shape[2]

Expand Down Expand Up @@ -753,4 +774,5 @@ def decode_attention_fwd(
logit_cap,
k_scale,
v_scale,
is_mla=is_mla,
)
Loading