diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py index 24dfca42d..be5a26ff1 100644 --- a/iree/turbine/kernel/wave/templates/paged_decode_attention.py +++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py @@ -54,11 +54,15 @@ def get_paged_decode_attention_kernels( input_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16, layer_scaling: Optional[float] = None, - mha: bool = False, logit_cap: float = 0.0, ): - if mha: - assert shape.num_query_heads == shape.num_kv_heads + """ + Supports multi-head attention (MHA), multi-query attention (MQA), and + grouped-query attention (GQA) depending on the number of query heads + compared to the number of key-value heads. + """ + + multi_head_attention = shape.num_query_heads == shape.num_kv_heads wave_input_dtype = torch_dtype_to_wave(input_dtype) wave_output_dtype = torch_dtype_to_wave(output_dtype) @@ -76,7 +80,7 @@ def get_paged_decode_attention_kernels( SPLIT_LEN = tkl.sym.SPLIT_LEN SPLITS_ACTIVE = tkl.sym.SPLITS_ACTIVE U = tkl.sym.U # Num splits - if mha: + if multi_head_attention: BH = B else: BH = tkl.sym.BH @@ -100,7 +104,7 @@ class Phase(Enum): PHASE_1_BLOCK_N = 16 head_ratio = shape.num_query_heads // shape.num_kv_heads MMA_VEC_SIZE = 16 # TODO: Actual value depends in mma type - if mha: + if multi_head_attention: B_WAVES = 1 else: B_WAVES = clamp(head_ratio // MMA_VEC_SIZE, 1, 4) @@ -213,7 +217,7 @@ def phase_1_constraints() -> list[tkw.Constraint]: def get_constraints(phase: Phase) -> list[tkw.Constraint]: if phase == Phase.PHASE_0: - if mha: + if multi_head_attention: return phase_0_constraints_mha() else: return phase_0_constraints() @@ -413,7 +417,7 @@ def repeat( elements_per_thread=1, # TODO: cannot remove this yet as vector shapes are inferred incorrectly ) - if mha: + if multi_head_attention: symbols_0 = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, BLOCK_B: 1, diff --git a/tests/kernel/wave/attention/paged_attention_test.py b/tests/kernel/wave/attention/paged_attention_test.py index 8c309f58f..b8eb8d44d 100644 --- a/tests/kernel/wave/attention/paged_attention_test.py +++ b/tests/kernel/wave/attention/paged_attention_test.py @@ -447,7 +447,6 @@ def testPagedFlashDecodingMHA( num_kv_splits, input_dtype=dtype, output_dtype=dtype, - mha=True, ) hyperparams_0.update(get_default_scheduling_params()) hyperparams_1.update(get_default_scheduling_params())