diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
old mode 100644
new mode 100755
index ea611848b0e8..8f2af10754d0
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -14,32 +14,35 @@
     MultipleOf,
 )
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
-from vllm.utils.platform_utils import get_cu_count
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    split_decodes_prefills_and_extends,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

 _PARTITION_SIZE_ROCM = 256
+_CP_TOKENS_PER_ITER_ROCM = 32 * 1024

 if current_platform.is_rocm():
     import aiter
+    from aiter.ops.triton.utils.device_info import get_num_sms

     from vllm.triton_utils import tl, triton

     def block_size(x, head_dim):
         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

-    def num_programs(total_tokens):
-        return min(total_tokens, get_cu_count())
+    def num_programs(head_dim):
+        return min(head_dim, get_num_sms())

     @triton.jit
     def cp_mha_gather_cache_kernel(
@@ -58,11 +61,11 @@ def cp_mha_gather_cache_kernel(
         x,
         max_block_num,
         num_tokens,
-        num_programs,
         DEQUANT: tl.constexpr,
         PAGE_SIZE: tl.constexpr,
         CACHE_FORMAT: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
+        NUM_PRGMS: tl.constexpr,
     ):
         bid = tl.program_id(0)
         col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -70,7 +73,7 @@ def cp_mha_gather_cache_kernel(
         k_scale = tl.load(k_scale_ptr)
         v_scale = tl.load(v_scale_ptr)

-        for token_id in tl.range(bid, num_tokens, num_programs):
+        for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
             key_ptr_offset = key_ptr + token_id * head_size * num_heads
             value_ptr_offset = value_ptr + token_id * head_size * num_heads
             batch_idx = tl.load(token_to_batch_ptr + token_id)
@@ -162,17 +165,55 @@ def cp_mha_gather_cache(
         x,
         block_tables.size(1),
         total_tokens,
-        NUM_PRGMS,
         DEQUANT=dequant,
         PAGE_SIZE=page_size,
         CACHE_FORMAT=kv_cache_layout,
         BLOCK_SIZE=BLOCK_SIZE,
+        NUM_PRGMS=NUM_PRGMS,
     )


 logger = init_logger(__name__)


+@dataclass
+class AiterFlashAttentionDecodeMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+
+
+@dataclass
+class AiterFlashAttentionPrefillMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+
+
+@dataclass
+class AiterChunkContextMetadata:
+    workspace: torch.Tensor
+    cu_seq_lens_chunk: torch.Tensor
+    chunk_starts: torch.Tensor
+    token_to_batch: torch.Tensor
+    seq_tot: list[int]
+    max_seq_lens: list[int]
+    seq_lens: torch.Tensor
+    num_chunks: int
+    total_token_per_batch: list[int]
+
+
+@dataclass
+class AiterFlashAttentionChunkPrefillMetadata:
+    max_query_len: int
+    min_query_len: int
+    max_seq_len: int
+    query_start_loc: torch.Tensor
+    chunk_context_metadata: AiterChunkContextMetadata
@@ -242,6 +283,18 @@ class AiterFlashAttentionMetadata:
+    # decode / extend / prefill split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
+    num_prefill_tokens: int
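+    # "extends" are chunked-prefill continuations: requests that already have
+    # computed context in the KV cache and are processing further prompt tokens.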
+    num_extends: int
+    num_extend_tokens: int
+
+    decode_metadata: AiterFlashAttentionDecodeMetadata | None
+    prefill_metadata: AiterFlashAttentionPrefillMetadata | None
+    extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None
+
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
@@ -251,7 +304,7 @@
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
-    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     reorder_batch_threshold: int = 1

     def __init__(
@@ -284,6 +337,12 @@ def __init__(
             device=device,
         )

+        self.extend_workspace = torch.empty(
+            [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
+            dtype=self.model_config.dtype,
+            device=device,
+        )
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ):
@@ -321,6 +380,103 @@ def build(
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

+        split_ret = split_decodes_prefills_and_extends(
+            common_attn_metadata,
+            decode_threshold=self.reorder_batch_threshold,
+        )
+        (
+            num_decodes,
+            num_extends,
+            num_prefills,
+            num_decode_tokens,
+            num_extend_tokens,
+            num_prefill_tokens,
+        ) = split_ret
+
+        seq_lens = common_attn_metadata.seq_lens_cpu
+
+        decode_metadata = None
+        if num_decodes > 0:
+            decode_metadata = AiterFlashAttentionDecodeMetadata(
+                max_query_len=query_lens_cpu[:num_decodes].max().item(),
+                min_query_len=query_lens_cpu[:num_decodes].min().item(),
+                max_seq_len=seq_lens[:num_decodes].max().item(),
+                query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
+            )
+
+        prefill_metadata = None
+        if num_prefills > 0:
+            query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
+            query_start_loc_device = common_attn_metadata.query_start_loc[
+                num_decodes + num_extends :
+            ]
+            prefill_metadata = AiterFlashAttentionPrefillMetadata(
+                max_query_len=query_lens_for_prefill.max().item(),
+                min_query_len=query_lens_for_prefill.min().item(),
+                max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
+                query_start_loc=query_start_loc_device - query_start_loc_device[0],
+            )
+
+        extend_metadata = None
+        if num_extends > 0:
+            num_extends_slice = slice(num_decodes, num_decodes + num_extends)
+            query_lens_for_extend = query_lens_cpu[num_extends_slice]
+            seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
+            computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
+
+            # allocate an equal amount of workspace for
+            # each chunked-prefill request
+            max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
+            num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)
+
+            chunk_starts = (
+                torch.arange(num_chunks, dtype=torch.int32)
+                .unsqueeze(1)
+                .expand(-1, num_extends)
+                * max_context_chunk
+            )
+            chunk_ends = torch.min(
+                computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
+            )
+            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
+                min=0
+            )  # [num_chunks, num_extends]
+            cu_seq_lens_cpu = torch.zeros(
+                [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
+            )
+            torch.cumsum(
+                chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
+            )
+            max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()
+
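+            # Map every flattened gathered-context token to the request it
+            # belongs to: mark each request's cumulative end offset within the
+            # chunk, then a running sum over positions yields the batch index.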
+            range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
+            idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
+            idx_to_batch_tensor = idx_to_batch_tensor.sum(
+                dim=1
+            )  # [num_chunks, max_cum_tokens]
+            token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)
+
+            chunk_context_metadata = AiterChunkContextMetadata(
+                workspace=self.extend_workspace,
+                cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
+                chunk_starts=chunk_starts.to(self.device, non_blocking=True),
+                seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                seq_lens=chunk_seq_lens,
+                token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
+                num_chunks=num_chunks,
+                total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
+            )
+
+            query_start_loc_device = common_attn_metadata.query_start_loc[
+                num_decodes : num_decodes + num_extends + 1
+            ]
+            seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
@@ -400,7 +556,20 @@ def build(
+            cu_seq_lens = torch.zeros(
+                num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
+            )
+            torch.cumsum(
+                seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
+            )
+            extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
+                max_query_len=query_lens_for_extend.max().item(),
+                min_query_len=query_lens_for_extend.min().item(),
+                max_seq_len=seq_lens[num_extends_slice].max().item(),
+                query_start_loc=query_start_loc_device - query_start_loc_device[0],
+                chunk_context_metadata=chunk_context_metadata,
+            )
+
+        num_actual_kv_tokens = torch.sum(seq_lens).item()
@@ -417,6 +586,7 @@ def build(
         use_cascade = common_prefix_len > 0

         attn_metadata = AiterFlashAttentionMetadata(
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
             num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=common_attn_metadata.max_query_len,
@@ -434,6 +604,21 @@ def build(
             query_start_loc=common_attn_metadata.query_start_loc,
             max_seq_len=common_attn_metadata.max_seq_len,
             seq_lens=common_attn_metadata.seq_lens,
             block_table=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_extends=num_extends,
+            num_extend_tokens=num_extend_tokens,
+            decode_metadata=decode_metadata,
+            prefill_metadata=prefill_metadata,
+            extend_metadata=extend_metadata,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -476,6 +661,7 @@ def get_kv_cache_shape(
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
+
         return (2, num_blocks, block_size, num_kv_heads, head_size)


@@ -618,6 +804,111 @@ def extend_forward(
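+    # Chunked-prefill ("extend") attention: the new query tokens first attend
+    # causally to themselves; the already-cached context is then gathered and
+    # attended to chunk by chunk, and the partial results are combined with
+    # merge_attn_states using the per-chunk log-sum-exp values.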
+    def extend_forward(
+        self,
+        attn_metadata: AiterFlashAttentionMetadata,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        min_seqlen_q: int,
+        block_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        k_scale: float,
+        v_scale: float,
+    ):
+        out, lse = aiter.flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_q,
+            min_seqlen_q=min_seqlen_q,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            return_lse=True,
+        )
+        assert attn_metadata.extend_metadata is not None
+        chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
+        num_chunks = chunk_context_metadata.num_chunks
+        workspace = chunk_context_metadata.workspace
+        cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
+        max_seqlens = chunk_context_metadata.max_seq_lens
+        chunk_starts = chunk_context_metadata.chunk_starts
+        token_to_batch = chunk_context_metadata.token_to_batch
+        total_token_per_batch = chunk_context_metadata.total_token_per_batch
+        key_fetched, value_fetched = workspace[0], workspace[1]
+        chunked_output = None
+        chunked_lse = None
+        for chunk_idx in range(num_chunks):
+            cp_mha_gather_cache(
+                key_cache=key_cache,
+                value_cache=value_cache,
+                key=key_fetched,
+                value=value_fetched,
+                block_tables=block_table,
+                k_scales=k_scale,
+                v_scales=v_scale,
+                cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
+                token_to_batch=token_to_batch[chunk_idx],
+                seq_starts=chunk_starts[chunk_idx],
+                dequant=False,
+                kv_cache_layout="NHD",
+                total_tokens=total_token_per_batch[chunk_idx],
+            )
+
+            suf_out, suf_lse = aiter.flash_attn_varlen_func(
+                q=query,
+                k=key_fetched,
+                v=value_fetched,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_kv[chunk_idx],
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlens[chunk_idx],
+                min_seqlen_q=min_seqlen_q,
+                dropout_p=0.0,
+                softmax_scale=self.scale,
+                causal=False,
+                window_size=self.sliding_window,
+                alibi_slopes=self.alibi_slopes,
+                return_lse=True,
+            )
+            if chunked_output is None:
+                chunked_output = suf_out
+                chunked_lse = suf_lse
+            else:
+                tmp_output = torch.empty_like(out)
+                tmp_lse = torch.empty_like(lse)
+                merge_attn_states(
+                    output=tmp_output,
+                    output_lse=tmp_lse,
+                    prefix_output=chunked_output,
+                    prefix_lse=chunked_lse,
+                    suffix_output=suf_out,
+                    suffix_lse=suf_lse,
+                )
+                chunked_output = tmp_output
+                chunked_lse = tmp_lse
+
+        merge_attn_states(
+            output=output,
+            prefix_output=chunked_output,
+            prefix_lse=chunked_lse,
+            suffix_output=out,
+            suffix_lse=lse,
+        )
@@ -710,6 +1001,19 @@ def forward(
+        # requests are reordered as [decodes | extends | prefills]
+        query = query[:num_actual_tokens]
+        key = key[:num_actual_tokens]
+        value = value[:num_actual_tokens]
+
+        output_actual_tokens = output[:num_actual_tokens]
+
+        num_decodes = attn_metadata.num_decodes
+        num_prefills = attn_metadata.num_prefills
+        num_extends = attn_metadata.num_extends
+
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_extend_tokens = attn_metadata.num_extend_tokens

         if not attn_metadata.use_cascade:
@@ -771,6 +1075,15 @@ def forward(
             )

         # calculate for decodes
+        if num_decodes > 0:
+            assert attn_metadata.decode_metadata is not None
+            _, num_heads, head_size = query.shape
+            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
+            num_seqs = attn_metadata.seq_lens.shape[0]
+            max_num_partitions = (
+                attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
+            ) // _PARTITION_SIZE_ROCM
@@ -780,6 +1093,13 @@ def forward(
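+            # Scratch for the split-KV decode kernel: per-partition outputs in
+            # the query dtype plus two float32 reduction buffers per partition
+            # (presumably the running exp-sums and max-logits).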
+            workspace_buffer = torch.empty(
+                (num_seqs * num_heads * max_num_partitions * head_size)
+                * nbytes_per_qo_elem
+                + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
+                dtype=torch.uint8,
+                device=output.device,
+            )
@@ -807,6 +1127,7 @@ def forward(
                 layer._v_scale,
                 None,
                 _PARTITION_SIZE_ROCM,
+                sliding_window=self.sliding_window[0] + 1,
             )
         else:
             raise NotImplementedError(