Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions iree/turbine/kernel/wave/templates/paged_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,15 @@ def get_paged_decode_attention_kernels(
input_dtype: torch.dtype = torch.float16,
output_dtype: torch.dtype = torch.float16,
layer_scaling: Optional[float] = None,
mha: bool = False,
logit_cap: float = 0.0,
):
if mha:
assert shape.num_query_heads == shape.num_kv_heads
"""
Supports multi-head attention (MHA), multi-query attention (MQA), and
grouped-query attention (GQA). The variant is selected by comparing the
number of query heads with the number of key-value heads: when the two
are equal, the MHA code path is used.
"""

multi_head_attention = shape.num_query_heads == shape.num_kv_heads

wave_input_dtype = torch_dtype_to_wave(input_dtype)
wave_output_dtype = torch_dtype_to_wave(output_dtype)
Expand All @@ -76,7 +80,7 @@ def get_paged_decode_attention_kernels(
SPLIT_LEN = tkl.sym.SPLIT_LEN
SPLITS_ACTIVE = tkl.sym.SPLITS_ACTIVE
U = tkl.sym.U # Num splits
if mha:
if multi_head_attention:
BH = B
else:
BH = tkl.sym.BH
Expand All @@ -100,7 +104,7 @@ class Phase(Enum):
PHASE_1_BLOCK_N = 16
head_ratio = shape.num_query_heads // shape.num_kv_heads
MMA_VEC_SIZE = 16 # TODO: Actual value depends on the MMA type
if mha:
if multi_head_attention:
B_WAVES = 1
else:
B_WAVES = clamp(head_ratio // MMA_VEC_SIZE, 1, 4)
Expand Down Expand Up @@ -213,7 +217,7 @@ def phase_1_constraints() -> list[tkw.Constraint]:

def get_constraints(phase: Phase) -> list[tkw.Constraint]:
if phase == Phase.PHASE_0:
if mha:
if multi_head_attention:
return phase_0_constraints_mha()
else:
return phase_0_constraints()
Expand Down Expand Up @@ -413,7 +417,7 @@ def repeat(
elements_per_thread=1, # TODO: cannot remove this yet as vector shapes are inferred incorrectly
)

if mha:
if multi_head_attention:
symbols_0 = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
BLOCK_B: 1,
Expand Down
1 change: 0 additions & 1 deletion tests/kernel/wave/attention/paged_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ def testPagedFlashDecodingMHA(
num_kv_splits,
input_dtype=dtype,
output_dtype=dtype,
mha=True,
)
hyperparams_0.update(get_default_scheduling_params())
hyperparams_1.update(get_default_scheduling_params())
Expand Down
Loading