diff --git a/aiter/ops/triton/gluon/pa_mqa_logits.py b/aiter/ops/triton/gluon/pa_mqa_logits.py index a1421a6f70..f354ae2091 100644 --- a/aiter/ops/triton/gluon/pa_mqa_logits.py +++ b/aiter/ops/triton/gluon/pa_mqa_logits.py @@ -282,9 +282,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits( context_idx + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) ), - mask=context_idx - + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) - >= 0, + mask=( + ( + context_idx + + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) + >= 0 + ) + & ( + context_idx + + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ) + ), ) context_idx = split_context_start + split_context_length - ChunkK @@ -310,8 +319,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits( ptr=OutLogits_buffer, offsets=(pid_batch * next_n + pid_next_n) * stride_out_batch + (context_idx + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout))), - mask=context_idx + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) - >= 0, + mask=( + ( + context_idx + + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) + >= 0 + ) + & ( + context_idx + + gl.arange(0, ChunkK, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ) + ), ) @@ -595,9 +614,16 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( context_idx + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), - mask=context_idx - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + >= split_context_start + ) + & ( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) for context_idx in range( @@ -663,10 +689,22 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), - mask=context_idx - + ChunkKPerStage - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + >= split_context_start + ) + & ( + context_idx + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) # ======================================================================================= @@ -735,6 +773,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx + + ChunkK + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) context_idx = split_context_start + split_context_length - ChunkK @@ -769,10 +815,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( + ChunkKPerStage + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), - mask=context_idx - + ChunkKPerStage - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + >= split_context_start + ) + & ( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) else: context_idx = split_context_start @@ -925,6 +979,11 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( context_idx + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), + mask=( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) for context_idx_ in range( @@ -1000,6 +1059,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx_ + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) # ======================================================================================= @@ -1074,6 +1141,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx_ + + ChunkK + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) context_idx = context_idx_ + ChunkK @@ -1107,6 +1182,12 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle( + ChunkKPerStage + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), + mask=( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) @@ -1407,9 +1488,16 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( context_idx + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), - mask=context_idx - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + >= split_context_start + ) + & ( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) for context_idx in range( @@ -1475,10 +1563,22 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), - mask=context_idx - + ChunkKPerStage - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + >= split_context_start + ) + & ( + context_idx + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) # ======================================================================================= @@ -1547,6 +1647,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx + + ChunkK + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) context_idx = split_context_start + split_context_length - ChunkK @@ -1581,10 +1689,18 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( + ChunkKPerStage + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), - mask=context_idx - + ChunkKPerStage - + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) - >= split_context_start, + mask=( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + >= split_context_start + ) + & ( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) else: context_idx = split_context_start @@ -1737,6 +1853,11 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( context_idx + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), + mask=( + context_idx + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), ) for context_idx_ in range( @@ -1812,6 +1933,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx_ + + ChunkKPerStage + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) # ======================================================================================= @@ -1886,6 +2015,14 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) ) ), + mask=( + context_idx_ + + ChunkK + + gl.arange( + 0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout) + ) + < max_model_len + ), ) context_idx = context_idx_ + ChunkK @@ -1919,4 +2056,10 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx( + ChunkKPerStage + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) ), + mask=( + context_idx + + ChunkKPerStage + + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)) + < max_model_len + ), )