Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 170 additions & 27 deletions aiter/ops/triton/gluon/pa_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits(
context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
>= 0,
mask=(
(
context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
>= 0
)
& (
context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
)
),
)

context_idx = split_context_start + split_context_length - ChunkK
Expand All @@ -310,8 +319,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits(
ptr=OutLogits_buffer,
offsets=(pid_batch * next_n + pid_next_n) * stride_out_batch
+ (context_idx + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))),
mask=context_idx + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
>= 0,
mask=(
(
context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
>= 0
)
& (
context_idx
+ gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
)
),
)


Expand Down Expand Up @@ -595,9 +614,16 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start
)
& (
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)

for context_idx in range(
Expand Down Expand Up @@ -663,10 +689,22 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
>= split_context_start
)
& (
context_idx
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

# =======================================================================================
Expand Down Expand Up @@ -735,6 +773,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx
+ ChunkK
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

context_idx = split_context_start + split_context_length - ChunkK
Expand Down Expand Up @@ -769,10 +815,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start
)
& (
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)
else:
context_idx = split_context_start
Expand Down Expand Up @@ -925,6 +979,11 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)

for context_idx_ in range(
Expand Down Expand Up @@ -1000,6 +1059,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx_
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

# =======================================================================================
Expand Down Expand Up @@ -1074,6 +1141,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx_
+ ChunkK
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)
context_idx = context_idx_ + ChunkK

Expand Down Expand Up @@ -1107,6 +1182,12 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)


Expand Down Expand Up @@ -1407,9 +1488,16 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start
)
& (
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)

for context_idx in range(
Expand Down Expand Up @@ -1475,10 +1563,22 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
>= split_context_start
)
& (
context_idx
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

# =======================================================================================
Expand Down Expand Up @@ -1547,6 +1647,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx
+ ChunkK
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

context_idx = split_context_start + split_context_length - ChunkK
Expand Down Expand Up @@ -1581,10 +1689,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start,
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
>= split_context_start
)
& (
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)
else:
context_idx = split_context_start
Expand Down Expand Up @@ -1737,6 +1853,11 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=(
context_idx
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)

for context_idx_ in range(
Expand Down Expand Up @@ -1812,6 +1933,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx_
+ ChunkKPerStage
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)

# =======================================================================================
Expand Down Expand Up @@ -1886,6 +2015,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
),
mask=(
context_idx_
+ ChunkK
+ gl.arange(
0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
)
< max_model_len
),
)
context_idx = context_idx_ + ChunkK

Expand Down Expand Up @@ -1919,4 +2056,10 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx(
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
),
mask=(
context_idx
+ ChunkKPerStage
+ gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
< max_model_len
),
)
Loading