diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 6fa1bbf20874..e8b09a436567 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -14,6 +14,7 @@
     MLACommonMetadata,
 )
 from vllm.platforms.interface import DeviceCapability
+from vllm.triton_utils import triton
 from vllm.utils.torch_utils import is_quantized_kv_cache
 from vllm.v1.attention.backend import (
     AttentionLayer,
@@ -115,6 +116,8 @@ def __init__(
         if is_quantized_kv_cache(self.kv_cache_dtype):
             self.supports_quant_query_input = False
 
+        self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
+
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
     ):
@@ -149,7 +152,24 @@ def forward_mqa(
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
 
         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
+        if envs.VLLM_BATCH_INVARIANT:
+            num_kv_splits = 1
+        else:
+            # Minimum work per split
+            # hardware dependent
+            min_work_per_split = 512
+
+            ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)
+
+            # use power of 2 to avoid excessive kernel instantiations
+            ideal_splits = triton.next_power_of_2(ideal_splits)
+
+            # Calculate SM-based maximum splits with occupancy multiplier
+            # 2-4x allows multiple blocks per SM for latency hiding
+            # hardware dependent
+            occupancy_multiplier = 2
+            max_splits = self._sm_count * occupancy_multiplier
+            num_kv_splits = min(ideal_splits, max_splits)
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
@@ -186,6 +206,7 @@ def forward_mqa(
             PAGE_SIZE,
             k_scale=layer._k_scale,
             v_scale=layer._k_scale,
+            is_mla=True,
         )
 
         return o, lse
diff --git a/vllm/v1/attention/ops/triton_decode_attention.py b/vllm/v1/attention/ops/triton_decode_attention.py
index 63263bc92e24..8118db0da8cf 100644
--- a/vllm/v1/attention/ops/triton_decode_attention.py
+++ b/vllm/v1/attention/ops/triton_decode_attention.py
@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    IS_MLA: tl.constexpr = False,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_req_idx = cur_batch
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+    q = tl.load(
+        Q + offs_q,
+        mask=(mask_h[:, None]) & (mask_d[None, :]),
+        other=0.0,
+        cache_modifier=".ca",
+    )
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
         qpe = tl.load(
-            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            Q + off_qpe,
+            mask=(mask_h[:, None]) & (mask_dpe[None, :]),
+            other=0.0,
+            cache_modifier=".ca",
         )
 
     kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
@@ -331,9 +340,14 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
+        base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
+        if BLOCK_DPE > 0:
+            base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]
+
         ks = tl.load(k_scale)
         vs = tl.load(v_scale)
-        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
+        for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_page_number = tl.load(
                 Req_to_tokens
@@ -341,31 +355,29 @@ def _fwd_grouped_kernel_stage1(
                 + offs_n // PAGE_SIZE,
                 mask=offs_n < split_kv_end,
                 other=0,
+                cache_modifier=".ca",
             )
             kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
-            offs_buf_k = (
-                kv_loc[None, :] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_d[:, None]
-            )
+
+            # explicitly facilitate overlapping load/compute
+            offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
             k = tl.load(
                 K_Buffer + offs_buf_k,
                 mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                 other=0.0,
+                cache_modifier=".cg",
             )
+
             if k.dtype.is_fp8():
                 k = (k.to(tl.float32) * ks).to(q.dtype)
             qk = tl.dot(q, k.to(q.dtype))
             if BLOCK_DPE > 0:
-                offs_buf_kpe = (
-                    kv_loc[None, :] * stride_buf_kbs
-                    + cur_kv_head * stride_buf_kh
-                    + offs_dpe[:, None]
-                )
+                offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
                 kpe = tl.load(
                     K_Buffer + offs_buf_kpe,
                     mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                     other=0.0,
+                    cache_modifier=".cg",
                 )
                 if kpe.dtype.is_fp8():
                     kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )
 
-            offs_buf_v = (
-                kv_loc[:, None] * stride_buf_vbs
-                + cur_kv_head * stride_buf_vh
-                + offs_dv[None, :]
-            )
-            v = tl.load(
-                V_Buffer + offs_buf_v,
-                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
-                other=0.0,
-            )
-            if v.dtype.is_fp8():
-                v = (v.to(tl.float32) * vs).to(q.dtype)
+            if not IS_MLA:
+                offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
+                v = tl.load(
+                    V_Buffer + offs_buf_v,
+                    mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
+                    other=0.0,
+                )
+                if v.dtype.is_fp8():
+                    v = (v.to(tl.float32) * vs).to(q.dtype)
+            else:
+                # MLA uses a single c_kv.
+                # loading the same c_kv to interpret it as v is not necessary.
+                # transpose the existing c_kv (aka k) for the dot product.
+                v = tl.trans(k)
 
             n_e_max = tl.maximum(tl.max(qk, 1), e_max)
             re_scale = tl.exp(e_max - n_e_max)
@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
     logit_cap,
     k_scale,
     v_scale,
+    is_mla=False,
 ):
+    # with is_mla there is only a single c_kv in smem.
+    # could increase BLOCK or num_stages.
     BLOCK = 32
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
         num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
+        IS_MLA=is_mla,
         **extra_kargs,
     )
 
@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
     logit_cap=0.0,
    k_scale=None,
     v_scale=None,
+    is_mla=False,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
         logit_cap,
         k_scale,
         v_scale,
+        is_mla=is_mla,
     )
     _decode_softmax_reducev_fwd(
         attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -708,6 +728,7 @@ def decode_attention_fwd(
     logit_cap=0.0,
     k_scale=None,
     v_scale=None,
+    is_mla=False,
 ):
     assert num_kv_splits == attn_logits.shape[2]
 
@@ -753,4 +774,5 @@ def decode_attention_fwd(
             logit_cap,
             k_scale,
             v_scale,
+            is_mla=is_mla,
         )
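
Note on the num_kv_splits heuristic added to forward_mqa above: the sketch below restates the same calculation as a standalone function, purely for illustration. It is not part of the patch; the helper name pick_num_kv_splits is hypothetical, it imports triton directly rather than through vllm.triton_utils, and sm_count=132 is only an example value (the patch reads multi_processor_count from torch.cuda.get_device_properties(0) instead).

# Illustrative sketch of the split-count heuristic; assumes triton is installed.
import triton


def pick_num_kv_splits(
    max_seq_len: int, sm_count: int = 132, batch_invariant: bool = False
) -> int:
    if batch_invariant:
        # A single split keeps the cross-split reduction order deterministic.
        return 1
    # Minimum work per split (hardware dependent, value taken from the patch).
    min_work_per_split = 512
    ideal_splits = max(1, max_seq_len // min_work_per_split)
    # Round up to a power of two to limit the number of kernel instantiations.
    ideal_splits = triton.next_power_of_2(ideal_splits)
    # Cap by SM count times an occupancy multiplier (2-4x for latency hiding).
    occupancy_multiplier = 2
    return min(ideal_splits, sm_count * occupancy_multiplier)


# Example: max_seq_len=8192 -> 8192 // 512 = 16 -> 16 splits (cap is 132 * 2 = 264).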