From 892e93535f9d0941286fbf44774fbd555de02719 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 8 Jul 2025 12:27:54 -0500 Subject: [PATCH 1/3] [Wave] Remove `mha` param from paged decode attention The flag can be derived from `shape.num_query_heads == shape.num_kv_heads`, so the caller no longer needs to specify it. Signed-off-by: Paul Zhang --- .../kernel/wave/templates/paged_decode_attention.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py index d4e6eb795..18bfd0205 100644 --- a/iree/turbine/kernel/wave/templates/paged_decode_attention.py +++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py @@ -54,11 +54,9 @@ def get_paged_decode_attention_kernels( input_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16, layer_scaling: Optional[float] = None, - mha: bool = False, logit_cap: float = 0.0, ): - if mha: - assert shape.num_query_heads == shape.num_kv_heads + multi_head_attention = shape.num_query_heads == shape.num_kv_heads wave_input_dtype = torch_dtype_to_wave(input_dtype) wave_output_dtype = torch_dtype_to_wave(output_dtype) @@ -76,7 +74,7 @@ def get_paged_decode_attention_kernels( SPLIT_LEN = tkl.sym.SPLIT_LEN SPLITS_ACTIVE = tkl.sym.SPLITS_ACTIVE U = tkl.sym.U # Num splits - if mha: + if multi_head_attention: BH = B else: BH = tkl.sym.BH @@ -100,7 +98,7 @@ class Phase(Enum): PHASE_1_BLOCK_N = 16 head_ratio = shape.num_query_heads // shape.num_kv_heads MMA_VEC_SIZE = 16 # TODO: Actual value depends in mma type - if mha: + if multi_head_attention: B_WAVES = 1 else: B_WAVES = clamp(head_ratio // MMA_VEC_SIZE, 1, 4) @@ -219,7 +217,7 @@ def phase_1_constraints() -> list[tkw.Constraint]: def get_constraints(phase: Phase) -> list[tkw.Constraint]: if phase == Phase.PHASE_0: - if mha: + if multi_head_attention: return phase_0_constraints_mha() else: return phase_0_constraints() @@ -419,7 +417,7 @@ 
def repeat( elements_per_thread=1, # TODO: cannot remove this yet as vector shapes are inferred incorrectly ) - if mha: + if multi_head_attention: symbols_0 = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, BLOCK_B: 1, From f19c2d0aaed93bd53324d01b2cdf5d7004a2c382 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 8 Jul 2025 12:37:15 -0500 Subject: [PATCH 2/3] Document support for multi-head/multi-query/grouped-query attention Signed-off-by: Paul Zhang --- .../turbine/kernel/wave/templates/paged_decode_attention.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py index 18bfd0205..88e988016 100644 --- a/iree/turbine/kernel/wave/templates/paged_decode_attention.py +++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py @@ -56,6 +56,12 @@ def get_paged_decode_attention_kernels( layer_scaling: Optional[float] = None, logit_cap: float = 0.0, ): + """ + Supports multi-head attention (MHA), multi-query attention (MQA), and + grouped-query attention (GQA) depending on the number of query heads + compared to the number of key-value heads. 
+ """ + multi_head_attention = shape.num_query_heads == shape.num_kv_heads wave_input_dtype = torch_dtype_to_wave(input_dtype) From 5622e240a0af9e7ad57af969026b8aba34adbb4b Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 8 Jul 2025 13:08:48 -0500 Subject: [PATCH 3/3] Fix Wave decode attention test Stop passing `mha=True` in `testPagedFlashDecodingMHA`, since the `mha` parameter was removed from the paged decode attention template. Signed-off-by: Paul Zhang --- tests/kernel/wave/attention/paged_attention_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernel/wave/attention/paged_attention_test.py b/tests/kernel/wave/attention/paged_attention_test.py index 8c309f58f..b8eb8d44d 100644 --- a/tests/kernel/wave/attention/paged_attention_test.py +++ b/tests/kernel/wave/attention/paged_attention_test.py @@ -447,7 +447,6 @@ def testPagedFlashDecodingMHA( num_kv_splits, input_dtype=dtype, output_dtype=dtype, - mha=True, ) hyperparams_0.update(get_default_scheduling_params()) hyperparams_1.update(get_default_scheduling_params())