diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py
index bc9f6ba980..3dc273accf 100644
--- a/aiter/ops/attention.py
+++ b/aiter/ops/attention.py
@@ -122,6 +122,116 @@ def pa_fwd_asm(
 ) -> torch.Tensor: ...
 
 
+def _should_use_asm_kernel(
+    num_seqs: int,
+    num_heads: int,
+    kv_cache_tensor_dtype: torch.dtype,
+) -> bool:
+
+    if kv_cache_tensor_dtype == torch.int8:
+        return True
+
+    # Get GPU compute units (CUs)
+    gpu = torch.cuda.current_device()
+    device_properties = torch.cuda.get_device_properties(gpu)
+    cu_num = device_properties.multi_processor_count
+    # ASM kernel becomes relevant, once the total_heads is sufficiently large compared to CUs
+    total_heads = num_seqs * num_heads
+    return total_heads > 2 * cu_num
+
+
+def paged_attention_common(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    exp_sums: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    block_tables: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables_stride0: int,
+    scale: float,
+    max_qlen: int = 1,
+    max_seq_len: int = 1,
+    K_QScale_hip: Optional[torch.Tensor] = None,  # [num_seqs, num_heads]
+    V_QScale_hip: Optional[torch.Tensor] = None,
+    K_QScale_asm: Optional[
+        torch.Tensor
+    ] = None,  # [num_blocks, num_kv_heads, block_size]
+    V_QScale_asm: Optional[torch.Tensor] = None,
+    out_: Optional[torch.Tensor] = None,
+    qo_indptr: Optional[torch.Tensor] = None,
+    high_precision: Optional[
+        int
+    ] = 1,  # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
+    kernelName: Optional[str] = None,
+    kv_cache_dtype: str = "auto",
+    kv_cache_tensor_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """
+    Paged attention forward pass with automatic kernel selection.
+    ASM is favored for int8 kv caches, for short ctx_len, or when the workload exceeds
+    the heuristic thresholds for larger ctx_len values.
+    PA is normally using per tensor quant and this is what has been tested, however,
+    per head quant can be supported as well in principle, but not tested.
+    """
+    kv_cache_tensor_dtype = (
+        kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else K.dtype
+    )
+    num_seqs, num_heads, head_size = Q.shape
+
+    use_asm_kernel = (
+        _should_use_asm_kernel(num_seqs, num_heads, kv_cache_tensor_dtype)
+        or high_precision == 2
+    )
+
+    if use_asm_kernel:
+        output = pa_fwd_asm(
+            Q,
+            K,
+            V,
+            block_tables,
+            context_lens,
+            block_tables_stride0,
+            max_qlen,
+            K_QScale_asm,
+            V_QScale_asm,
+            out_,
+            qo_indptr,
+            high_precision,
+            kernelName,
+        )
+        return output
+
+    # Use ROCm paged attention kernel for smaller workloads / common path.
+    output = out_ if out_ is not None else torch.empty_like(Q)
+
+    paged_attention_rocm(
+        out=output,
+        exp_sums=exp_sums,
+        max_logits=max_logits,
+        tmp_out=tmp_out,
+        query=Q,
+        key_cache=K,
+        value_cache=V,
+        num_kv_heads=int(K.size(1)),
+        scale=scale,
+        block_tables=block_tables,
+        context_lens=context_lens,
+        block_size=int(K.size(3)),
+        max_context_len=max_seq_len,
+        alibi_slopes=None,
+        kv_cache_dtype=kv_cache_dtype,
+        k_scale=K_QScale_hip,
+        v_scale=V_QScale_hip,
+        fp8_out_scale=None,
+        partition_size=256,
+        mtp=1,
+        q_scale=None,
+    )
+    return output
+
+
 def gen_pa_ps_fwd_asm(
     Q: torch.Tensor,
     K: torch.Tensor,
diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py
index c9ae6890db..7c218a29ae 100644
--- a/op_tests/test_pa.py
+++ b/op_tests/test_pa.py
@@ -15,6 +15,8 @@
     benchmark,
 )
 from aiter import pertoken_quant
+from aiter.ops import attention
+
 import argparse
 import pandas as pd
 
@@ -404,6 +406,110 @@ def run_aiter_asm(
     )
 
 
+@perftest()
+def run_aiter_common(
+    query,
+    k_cache,
+    v_cache,
+    block_tables,
+    seq_lens,
+    max_seq_len,
+    kv_cache_dtype,
+    num_kv_heads,
+    scale,
+    alibi_slopes,
+    block_tables_stride0,
+    # ROCm/HIP (scalar) scales
+    k_scale_hip=None,
+    v_scale_hip=None,
+    # ASM (expanded) scales
+    k_scale_asm=None,
+    v_scale_asm=None,
+    high_precision=0,
+    kv_cache_tensor_dtype=None,
+):
+    """
+    Test paged_attention_common which automatically switches between ASM and HIP kernels.
+    """
+
+    num_seqs, num_heads, head_size = query.shape
+    # Client-side allocations required by ROCm paged attention path.
+    _PARTITION_SIZE_ROCM = 256
+    max_num_partitions = (
+        max_seq_len + _PARTITION_SIZE_ROCM - 1
+    ) // _PARTITION_SIZE_ROCM
+    tmp_out = torch.empty(
+        (num_seqs, num_heads, max_num_partitions, head_size),
+        dtype=query.dtype,
+        device=query.device,
+    )
+    exp_sums = torch.empty(
+        (num_seqs, num_heads, max_num_partitions),
+        dtype=dtypes.fp32,
+        device=query.device,
+    )
+    max_logits = torch.empty_like(exp_sums)
+
+    def _normalize_scale(s):
+        if s is None:
+            return None
+        if isinstance(s, torch.Tensor):
+            return s.to(device=query.device, dtype=dtypes.fp32)
+        # python scalar
+        return torch.tensor(float(s), device=query.device, dtype=dtypes.fp32)
+
+    k_scale_hip_tensor = _normalize_scale(k_scale_hip)
+    v_scale_hip_tensor = _normalize_scale(v_scale_hip)
+    # ASM scales are already tensors in the expected layout; just ensure fp32 on device.
+    k_scale_asm_tensor = _normalize_scale(k_scale_asm)
+    v_scale_asm_tensor = _normalize_scale(v_scale_asm)
+
+    # Determine kv_cache_dtype string.
+    def _is_fp8_storage(dt: torch.dtype) -> bool:
+        if dt == torch.int8 or dt == torch.uint8:
+            return True
+        # torch float8 dtypes (guard for older torch builds)
+        for name in (
+            "float8_e4m3fnuz",
+            "float8_e4m3fn",
+            "float8_e5m2fnuz",
+            "float8_e5m2",
+        ):
+            if hasattr(torch, name) and dt == getattr(torch, name):
+                return True
+        return False
+
+    cache_dt = (
+        kv_cache_tensor_dtype if kv_cache_tensor_dtype is not None else k_cache.dtype
+    )
+    kv_cache_dtype_str = "fp8" if _is_fp8_storage(cache_dt) else "auto"
+
+    return attention.paged_attention_common(
+        Q=query.contiguous(),
+        K=k_cache,
+        V=v_cache,
+        exp_sums=exp_sums,
+        max_logits=max_logits,
+        tmp_out=tmp_out,
+        block_tables=block_tables,
+        context_lens=seq_lens,
+        block_tables_stride0=block_tables_stride0,
+        scale=scale,
+        max_qlen=1,
+        max_seq_len=max_seq_len,
+        K_QScale_hip=k_scale_hip_tensor,
+        V_QScale_hip=v_scale_hip_tensor,
+        K_QScale_asm=k_scale_asm_tensor,
+        V_QScale_asm=v_scale_asm_tensor,
+        out_=None,
+        qo_indptr=None,
+        high_precision=high_precision,
+        kernelName=None,
+        kv_cache_dtype=kv_cache_dtype_str,
+        kv_cache_tensor_dtype=kv_cache_tensor_dtype,
+    )
+
+
 def dump_input(
     path,
     query: torch.Tensor,
@@ -587,6 +693,32 @@ def test_paged_attention(
     )
     # tensor_dump(out_aiter, 'out_aiter')
 
+    # Test paged_attention_common which automatically switches between ASM and HIP
+    # The routing is internal, so we just test the common API regardless of which path it takes
+    time_aiter_common = None
+    if dtype == dtypes.bf16:
+        try:
+            out_aiter_common, time_aiter_common = run_aiter_common(
+                query.contiguous(),
+                k_cache,
+                asm_V_shuffle(v_cache),  # Shuffle V cache, same as run_aiter_asm
+                block_tables,
+                seq_lens,
+                max_seq_len,
+                kv_cache_dtype,
+                num_kv_heads,
+                scale,
+                alibi_slopes,
+                block_tables.stride(0),
+            )
+            checkAllclose(
+                out_golden,
+                out_aiter_common,
+                msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......",
+            )
+        except Exception as e:
+            print(f"Warning: Could not test aiter_common: {e}")
+
     for quant_algo_, cache_type_ in [
         (0, k_cache.dtype),
         (2, dtypes.fp8),
@@ -705,6 +837,31 @@ def test_paged_attention(
                 msg=f"golden vs aiter_asm:{time_aiter_asm:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})",
             )
 
+        if quant_algo_ == 4:
+            # Test paged_attention_common with quantized cache
+            out_aiter_common, time_aiter_common = run_aiter_common(
+                query.contiguous(),
+                k_quant_,
+                asm_V_shuffle(v_quant_),
+                block_tables,
+                seq_lens,
+                max_seq_len,
+                kv_cache_dtype,
+                num_kv_heads,
+                scale,
+                alibi_slopes,
+                block_tables.stride(0),
+                k_scale_hip=k_scale_,
+                v_scale_hip=v_scale_,
+                k_scale_asm=k_scale_asm,
+                v_scale_asm=v_scale_asm,
+            )
+            checkAllclose(
+                out_golden,
+                out_aiter_common,
+                msg=f"golden vs aiter_common:{time_aiter_common:>8.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})",
+            )
+
         if (
             dtype in [dtypes.bf16, dtypes.fp16]
             and quant_algo_ == 2
@@ -809,7 +966,11 @@ def test_paged_attention(
     print(
         f"finish~ {ctx_lens=}, {num_seqs=}, {num_heads=}, {head_size=}, {use_alibi=}, {block_size=}, {dtype=}, {kv_cache_dtype=}\n"
    )
-    return {"aiter_shomy": time_aiter, "aiter_asm": time_aiter_asm}
+    return {
+        "aiter_shomy": time_aiter,
+        "aiter_asm": time_aiter_asm,
+        "aiter_common": time_aiter_common,
+    }
 
 
 df = []