diff --git a/.github/workflows/aiter-test.yaml b/.github/workflows/aiter-test.yaml index a674437d3d..7004e960e8 100644 --- a/.github/workflows/aiter-test.yaml +++ b/.github/workflows/aiter-test.yaml @@ -36,15 +36,16 @@ jobs: steps: - name: Define whether runs on MI35X + env: + PR_TITLE: ${{ github.event.pull_request.title }} id: machines run: | set -euo pipefail - pr_title="${{ github.event.pull_request.title }}" if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then echo "It's main branch, running tests on MI325 and MI35X..." echo 'standard_runners=["aiter-mi355-1gpu"]' >> "$GITHUB_OUTPUT" echo 'multigpu_runners=["aiter-mi355-8gpu"]' >> "$GITHUB_OUTPUT" - elif echo "$pr_title" | grep -qi "mi35x"; then + elif echo "${PR_TITLE}" | grep -qi "mi35x"; then echo "PR title contains 'MI35X', running tests on MI325 and MI35X..." echo 'standard_runners=["aiter-mi355-1gpu"]' >> "$GITHUB_OUTPUT" echo 'multigpu_runners=["aiter-mi355-8gpu"]' >> "$GITHUB_OUTPUT" diff --git a/aiter/ops/triton/mha_v3.py b/aiter/ops/triton/mha_v3.py index 459a28fcdd..c2fa24c769 100644 --- a/aiter/ops/triton/mha_v3.py +++ b/aiter/ops/triton/mha_v3.py @@ -6,6 +6,7 @@ import torch from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_3 +from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype class _FlashAttnV3Func(torch.autograd.Function): @@ -718,7 +719,7 @@ def forward( _, _, num_kv_heads, _ = k.shape # Quantize inputs to FP8 - fp8_dtype = torch.float8_e4m3fnuz + fp8_dtype = get_fp8_e4m3_dtype() # For GQA/MQA: quantize query with grouped scaling group_size = ( @@ -1002,7 +1003,7 @@ def forward( num_kv_heads = k.shape[1] # Quantize inputs to FP8 using _quantize_thd for varlen tensors - fp8_dtype = torch.float8_e4m3fnuz + fp8_dtype = get_fp8_e4m3_dtype() # For GQA/MQA: quantize query with grouped scaling group_size = (