Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/aiter-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ jobs:

steps:
- name: Define whether runs on MI35X
env:
PR_TITLE: ${{ github.event.pull_request.title }}
id: machines
run: |
set -euo pipefail
pr_title="${{ github.event.pull_request.title }}"
if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then
echo "It's main branch, running tests on MI325 and MI35X..."
echo 'standard_runners=["aiter-mi355-1gpu"]' >> "$GITHUB_OUTPUT"
echo 'multigpu_runners=["aiter-mi355-8gpu"]' >> "$GITHUB_OUTPUT"
elif echo "$pr_title" | grep -qi "mi35x"; then
elif echo "${PR_TITLE}" | grep -qi "mi35x"; then
echo "PR title contains 'MI35X', running tests on MI325 and MI35X..."
echo 'standard_runners=["aiter-mi355-1gpu"]' >> "$GITHUB_OUTPUT"
echo 'multigpu_runners=["aiter-mi355-8gpu"]' >> "$GITHUB_OUTPUT"
Expand Down
5 changes: 3 additions & 2 deletions aiter/ops/triton/mha_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_3
from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype


class _FlashAttnV3Func(torch.autograd.Function):
Expand Down Expand Up @@ -718,7 +719,7 @@ def forward(
_, _, num_kv_heads, _ = k.shape

# Quantize inputs to FP8
fp8_dtype = torch.float8_e4m3fnuz
fp8_dtype = get_fp8_e4m3_dtype()

# For GQA/MQA: quantize query with grouped scaling
group_size = (
Expand Down Expand Up @@ -1002,7 +1003,7 @@ def forward(
num_kv_heads = k.shape[1]

# Quantize inputs to FP8 using _quantize_thd for varlen tensors
fp8_dtype = torch.float8_e4m3fnuz
fp8_dtype = get_fp8_e4m3_dtype()

# For GQA/MQA: quantize query with grouped scaling
group_size = (
Expand Down
Loading