diff --git a/.github/scripts/build_aiter_triton.sh b/.github/scripts/build_aiter_triton.sh
index 66ce88939d..3965c3a44d 100755
--- a/.github/scripts/build_aiter_triton.sh
+++ b/.github/scripts/build_aiter_triton.sh
@@ -15,6 +15,7 @@ pip install --upgrade "pybind11>=3.0.1"
 pip install --upgrade "ninja>=1.11.1"
 pip install tabulate
 pip install -e .
+./.github/scripts/install_triton.sh
 
 # Read BUILD_TRITON env var, default to 1. If 1, install Triton; if 0, skip installation.
 BUILD_TRITON=${BUILD_TRITON:-1}
diff --git a/.github/scripts/install_triton.sh b/.github/scripts/install_triton.sh
new file mode 100755
index 0000000000..035650bb44
--- /dev/null
+++ b/.github/scripts/install_triton.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+set -e
+
+pip uninstall -y triton pytorch-triton pytorch-triton-rocm triton-rocm amd-triton || true
+
+TRITON_INDEX_URL="https://pypi.amd.com/triton/rocm-7.0.0/simple/"
+ROCM_VERSION=$(dpkg -l rocm-core 2>/dev/null | awk '/^ii/{print $3}')
+if [[ -n "$ROCM_VERSION" ]]; then
+  ROCM_MAJOR_MINOR=$(echo "$ROCM_VERSION" | cut -d. -f1,2)
+  TRITON_INDEX_URL="https://pypi.amd.com/triton/rocm-${ROCM_MAJOR_MINOR}.0/simple/"
+fi
+
+echo "Installing amd-triton from $TRITON_INDEX_URL"
+pip install --extra-index-url "$TRITON_INDEX_URL" amd-triton
diff --git a/.github/scripts/split_tests.sh b/.github/scripts/split_tests.sh
index 1d51962896..05e58b9626 100755
--- a/.github/scripts/split_tests.sh
+++ b/.github/scripts/split_tests.sh
@@ -170,6 +170,16 @@ elif [[ "$TEST_TYPE" == "triton" ]]; then
     FILE_TIMES[op_tests/triton_tests/gemm/basic/test_gemm_a8w8_per_token_scale.py]=17
     FILE_TIMES[op_tests/triton_tests/quant/test_fused_fp8_quant.py]=17
     FILE_TIMES[op_tests/triton_tests/test_gather_kv_b_proj.py]=16
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_gemm_a16w16.py]=19
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_activation.py]=11
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_moe_routing.py]=10
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_rope.py]=9
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_softmax.py]=8
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_fused_mul_add.py]=7
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_quant_per_tensor.py]=7
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_quant_per_token.py]=7
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_rmsnorm.py]=7
+    FILE_TIMES[op_tests/triton_tests/torch_compile/test_compile_topk.py]=5
     FILE_TIMES[op_tests/triton_tests/attention/test_extend_attention.py]=7
     FILE_TIMES[op_tests/triton_tests/fusions/test_fused_qk_concat.py]=7
     FILE_TIMES[op_tests/triton_tests/gemm/basic/test_gemm_a8w8_blockscale.py]=7
diff --git a/.github/workflows/aiter-test.yaml b/.github/workflows/aiter-test.yaml
index 23c9d4b26c..0b88d47c14 100644
--- a/.github/workflows/aiter-test.yaml
+++ b/.github/workflows/aiter-test.yaml
@@ -380,6 +380,11 @@ jobs:
             pip show amd-aiter
           "
 
+      - name: Install amd-triton
+        run: |
+          docker exec -w /workspace aiter_test \
+            bash -c "./.github/scripts/install_triton.sh && pip show amd-triton"
+
       - name: Show Aiter version
         run: |
           set -ex
@@ -564,6 +569,11 @@ jobs:
             pip show amd-aiter
           "
 
+      - name: Install amd-triton
+        run: |
+          docker exec -w /workspace aiter_test \
+            bash -c "./.github/scripts/install_triton.sh && pip show amd-triton"
+
       - name: Show Aiter version
         run: |
           set -ex
diff --git a/.github/workflows/atom-test.yaml b/.github/workflows/atom-test.yaml
index ae554c8abe..af5494b8d6 100644
--- a/.github/workflows/atom-test.yaml
+++ b/.github/workflows/atom-test.yaml
@@ -112,7 +112,9 @@ jobs:
               cd /app/aiter-test && \\
               git checkout ${{ env.GITHUB_COMMIT_SHA }} && \\
               git submodule sync && git submodule update --init --recursive && \\
-              MAX_JOBS=64 PREBUILD_KERNELS=0 GPU_ARCHS=gfx950 pip install -e .
+              MAX_JOBS=64 PREBUILD_KERNELS=0 GPU_ARCHS=gfx950 pip install -e . && \\
+              ./.github/scripts/install_triton.sh
+          RUN echo "=== amd-triton version ===" && pip show amd-triton || true
           RUN echo "=== Aiter version AFTER installation ===" && pip show amd-aiter || true
           EOF
diff --git a/.github/workflows/sglang_downstream.yaml b/.github/workflows/sglang_downstream.yaml
index 3fffb91d7d..b275b2d842 100644
--- a/.github/workflows/sglang_downstream.yaml
+++ b/.github/workflows/sglang_downstream.yaml
@@ -196,6 +196,8 @@ jobs:
             git submodule sync --recursive
             git submodule update --init --recursive
             pip install -e .
+            ./.github/scripts/install_triton.sh
+            pip show amd-triton || true
             pip show amd-aiter || pip show aiter
           "
 
diff --git a/.github/workflows/triton-test.yaml b/.github/workflows/triton-test.yaml
index f66899ca2b..b678ad232f 100644
--- a/.github/workflows/triton-test.yaml
+++ b/.github/workflows/triton-test.yaml
@@ -56,82 +56,12 @@ jobs:
           path: triton_shard_*.list
           retention-days: 7
 
-  # Build Triton wheel once, shared by all shard jobs via artifact
-  build-triton:
-    if: ${{ (!github.event.pull_request || github.event.pull_request.draft == false) && (github.event_name != 'pull_request' || github.event.action != 'labeled' || github.event.label.name == 'ci:triton-300x') }}
-    name: Build Triton Wheel
-    runs-on: linux-aiter-mi35x-1
-    needs: [check-signal]
-    env:
-      DOCKER_IMAGE: "rocm/pytorch:latest"
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 1
-
-      - name: Docker login
-        if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork }}
-        env:
-          DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
-        run: |
-          for attempt in 1 2 3; do
-            if echo "$DOCKER_PASSWORD" | docker login -u rocmshared --password-stdin; then
-              echo "Docker login succeeded on attempt ${attempt}"
-              exit 0
-            fi
-            echo "Docker login attempt ${attempt} failed"
-            if [ "${attempt}" != 3 ]; then
-              sleep 10
-            fi
-          done
-          echo "Docker login failed after 3 attempts, continuing anyway"
-          exit 0
-
-      - name: Build Triton wheel in Docker
-        run: |
-          set -ex
-          mkdir -p triton-wheels
-
-          if [ -f "/etc/podinfo/gha-render-devices" ]; then
-            DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
-          else
-            DEVICE_FLAG="--device /dev/dri"
-          fi
-
-          docker run --rm \
-            --device=/dev/kfd $DEVICE_FLAG \
-            --shm-size=16G \
-            --group-add $(getent group render | cut -d: -f3) \
-            --group-add $(getent group video | cut -d: -f3) \
-            -v "${{ github.workspace }}:/workspace" \
-            -w /workspace \
-            ${{ env.DOCKER_IMAGE }} \
-            bash -c '
-              set -ex
-              pip config set global.default-timeout 60
-              pip config set global.retries 10
-              TRITON_COMMIT=${TRITON_COMMIT:-756afc06}
-              git clone https://github.com/triton-lang/triton
-              cd triton
-              git checkout "$TRITON_COMMIT"
-              pip install -r python/requirements.txt
-              MAX_JOBS=64 pip wheel --no-deps -w /workspace/triton-wheels .
-            '
-
-      - name: Upload Triton wheel
-        uses: actions/upload-artifact@v4
-        with:
-          name: triton-wheel
-          path: triton-wheels/*.whl
-          retention-days: 7
-
   # Step 2: MI35X matrix jobs
   triton:
     if: ${{ (!github.event.pull_request || github.event.pull_request.draft == false) && (github.event_name != 'pull_request' || github.event.action != 'labeled' || github.event.label.name == 'ci:triton-300x') }}
     name: Triton Tests (MI35X) / Shard ${{ matrix.shard }}
     runs-on: linux-aiter-mi35x-1
-    needs: [split_triton_tests, build-triton, check-signal]
+    needs: [split_triton_tests, check-signal]
     strategy:
       fail-fast: false
       matrix:
@@ -152,12 +82,6 @@ jobs:
         with:
          name: triton_shards
 
-      - name: Download Triton wheel
-        uses: actions/download-artifact@v4
-        with:
-          name: triton-wheel
-          path: triton-wheels
-
       - name: List test shard files
         run: |
           ls -l triton_shard_*.list
@@ -218,7 +142,7 @@ jobs:
           set -ex
           echo "Setting up Aiter and Triton..."
           docker exec \
-            -e TRITON_WHEEL_DIR=/workspace/triton-wheels \
+            -e BUILD_TRITON=0 \
             -w /workspace \
             triton_test \
             ./.github/scripts/build_aiter_triton.sh
@@ -279,7 +203,7 @@ jobs:
     if: ${{ (!github.event.pull_request || github.event.pull_request.draft == false) && (github.event_name != 'pull_request' || github.event.action != 'labeled' || github.event.label.name == 'ci:triton-300x') && (github.ref == 'refs/heads/main' || (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'ci:triton-300x'))) }}
     name: Triton Tests (MI300X) / Shard ${{ matrix.shard }}
     runs-on: linux-aiter-mi300x-1
-    needs: [split_triton_tests, build-triton, check-signal]
+    needs: [split_triton_tests, check-signal]
     strategy:
       fail-fast: false
       matrix:
@@ -300,12 +224,6 @@ jobs:
         with:
          name: triton_shards
 
-      - name: Download Triton wheel
-        uses: actions/download-artifact@v4
-        with:
-          name: triton-wheel
-          path: triton-wheels
-
       - name: List test shard files
         run: |
           ls -l triton_shard_*.list
@@ -366,7 +284,7 @@ jobs:
           set -ex
           echo "Setting up Aiter and Triton..."
           docker exec \
-            -e TRITON_WHEEL_DIR=/workspace/triton-wheels \
+            -e BUILD_TRITON=0 \
             -w /workspace \
             triton_test \
             ./.github/scripts/build_aiter_triton.sh
diff --git a/.github/workflows/vllm_benchmark.yaml b/.github/workflows/vllm_benchmark.yaml
index 8237a5fb3c..c8cab294ee 100644
--- a/.github/workflows/vllm_benchmark.yaml
+++ b/.github/workflows/vllm_benchmark.yaml
@@ -96,7 +96,9 @@ jobs:
               cd /aiter && \\
               git checkout ${{ env.GITHUB_COMMIT_SHA }} && \\
               git submodule sync && git submodule update --init --recursive && \\
-              pip install -e .
+              pip install -e . && \\
+              ./.github/scripts/install_triton.sh
+          RUN echo "=== amd-triton version ===" && pip show amd-triton || true
           RUN echo "=== Aiter version AFTER installation ===" && pip show amd-aiter || true
 
diff --git a/README.md b/README.md
index 715b21d48d..d1da5190ab 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,16 @@ Or install all optional dependencies at once:
 pip install -r requirements.txt
 ```
 
+### Triton
+
+AITER includes Triton-based operators that require amd-triton ([ROCm 7.0](https://pypi.amd.com/triton/rocm-7.0.0/simple/), [ROCm 7.1](https://pypi.amd.com/triton/rocm-7.1.0/simple/), [ROCm 7.2](https://pypi.amd.com/triton/rocm-7.2.0/simple/)), with the correct version selected based on your ROCm installation.
+
+If you install with `python3 setup.py develop`, amd-triton is installed automatically. If you use `pip install -e .`, run the install script manually:
+
+```bash
+./.github/scripts/install_triton.sh
+```
+
 ### Opus — Lightweight C++ Template for Kernel Development
 
 [Opus](csrc/include/opus/) is a single-header C++ template library (`opus.hpp`) for writing HIP kernels on AMD GPUs — vectorized load/store, layout abstractions, and MFMA wrappers with a strong focus on **build time optimization** (up to 61x faster than standard torch extension builds). See the [Opus README](csrc/include/opus/README.md) and [`op_tests/opus/`](op_tests/opus/) for details.
diff --git a/op_tests/triton_tests/torch_compile/__init__.py b/op_tests/triton_tests/torch_compile/__init__.py
new file mode 100644
index 0000000000..95fe2526e9
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/__init__.py
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+
+
+def _get_compiled(fn):
+    return torch.compile(
+        fn, backend="inductor", fullgraph=True, options={"max_autotune": True}
+    )
diff --git a/op_tests/triton_tests/torch_compile/test_compile_activation.py b/op_tests/triton_tests/torch_compile/test_compile_activation.py
new file mode 100644
index 0000000000..40f62cf20d
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_activation.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from . import _get_compiled
+
+
+def torch_silu_mul(x):
+    half = x.shape[-1] // 2
+    x1 = x[..., :half]
+    x2 = x[..., half:]
+    return F.silu(x1) * x2
+
+
+def torch_gelu_mul(x):
+    half = x.shape[-1] // 2
+    x1 = x[..., :half]
+    x2 = x[..., half:]
+    return F.gelu(x1) * x2
+
+
+@pytest.mark.parametrize("activation", ["silu", "gelu"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N", [(64, 256), (128, 512), (256, 1024)])
+def test_compile_activation(M, N, dtype, activation):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=dtype)
+
+    act_fn = torch_silu_mul if activation == "silu" else torch_gelu_mul
+    out_eager = act_fn(x)
+
+    compiled_fn = _get_compiled(act_fn)
+    out_compiled = compiled_fn(x)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    tol = (0.1, 0.1) if dtype == torch.bfloat16 else (1e-3, 1e-3)
+    torch.testing.assert_close(out_compiled, out_eager, atol=tol[0], rtol=tol[1])
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_fused_mul_add.py b/op_tests/triton_tests/torch_compile/test_compile_fused_mul_add.py
new file mode 100644
index 0000000000..f4e0f2dd83
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_fused_mul_add.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def torch_fused_mul_add(x, a, b):
+    return x * a + b
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N", [(128, 256), (256, 512), (512, 1024)])
+@pytest.mark.parametrize("scalar_ab", [False, True])
+def test_compile_fused_mul_add(M, N, dtype, scalar_ab):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=dtype)
+
+    if scalar_ab:
+        a, b = 2.0, 0.5
+    else:
+        a = torch.randn(M, N, device="cuda", dtype=dtype)
+        b = torch.randn(M, N, device="cuda", dtype=dtype)
+
+    out_eager = torch_fused_mul_add(x, a, b)
+
+    def fn(x, a, b):
+        return torch_fused_mul_add(x, a, b)
+
+    compiled_fn = _get_compiled(fn)
+    out_compiled = compiled_fn(x, a, b)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    tol = (0.1, 0.1) if dtype == torch.bfloat16 else (1e-3, 1e-3)
+    torch.testing.assert_close(out_compiled, out_eager, atol=tol[0], rtol=tol[1])
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_gemm_a16w16.py b/op_tests/triton_tests/torch_compile/test_compile_gemm_a16w16.py
new file mode 100644
index 0000000000..8983323199
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_gemm_a16w16.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def torch_gemm(x, w, bias=None):
+    out = torch.mm(x, w.t())
+    if bias is not None:
+        out = out + bias
+    return out
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N, K", [(128, 256, 64), (256, 512, 128), (512, 1024, 256)])
+def test_compile_gemm(M, N, K, dtype):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, K, device="cuda", dtype=dtype)
+    w = torch.randn(N, K, device="cuda", dtype=dtype)
+
+    out_eager = torch_gemm(x, w)
+
+    def fn(x, w):
+        return torch_gemm(x, w)
+
+    compiled_fn = _get_compiled(fn)
+    out_compiled = compiled_fn(x, w)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(out_compiled, out_eager, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N, K", [(128, 256, 64), (256, 512, 128)])
+def test_compile_gemm_with_bias(M, N, K, dtype):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, K, device="cuda", dtype=dtype)
+    w = torch.randn(N, K, device="cuda", dtype=dtype)
+    bias = torch.randn(N, device="cuda", dtype=dtype)
+
+    out_eager = torch_gemm(x, w, bias=bias)
+
+    def fn(x, w, bias):
+        return torch_gemm(x, w, bias=bias)
+
+    compiled_fn = _get_compiled(fn)
+    out_compiled = compiled_fn(x, w, bias)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(out_compiled, out_eager, atol=1e-1, rtol=1e-1)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_moe_routing.py b/op_tests/triton_tests/torch_compile/test_compile_moe_routing.py
new file mode 100644
index 0000000000..36a91c6aa8
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_moe_routing.py
@@ -0,0 +1,41 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def torch_routing_sigmoid_top1(x, w, topk=1):
+    logits = x @ w
+    weights = torch.sigmoid(logits)
+    topk_weights, topk_ids = torch.topk(weights, topk, dim=-1)
+    return topk_ids.to(torch.int32), topk_weights.to(torch.float32)
+
+
+@pytest.mark.parametrize("M, K, N", [(64, 128, 8), (128, 256, 16), (256, 512, 32)])
+def test_compile_moe_routing(M, K, N):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    topk = 1
+    x = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    w = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    ids_eager, weights_eager = torch_routing_sigmoid_top1(x, w, topk=topk)
+
+    def fn(x, w):
+        return torch_routing_sigmoid_top1(x, w, topk=topk)
+
+    compiled_fn = _get_compiled(fn)
+    ids_compiled, weights_compiled = compiled_fn(x, w)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(weights_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(ids_compiled, ids_eager)
+    torch.testing.assert_close(weights_compiled, weights_eager, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_quant_per_tensor.py b/op_tests/triton_tests/torch_compile/test_compile_quant_per_tensor.py
new file mode 100644
index 0000000000..28e1e55f73
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_quant_per_tensor.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
+
+
+def torch_dynamic_per_tensor_quant_fp8(x):
+    x_float = x.to(torch.float32)
+    amax = x_float.abs().amax().clamp(min=1e-12)
+    scale = amax / FP8_MAX
+    qx = (x_float / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fnuz)
+    return qx, scale.unsqueeze(0)
+
+
+@pytest.mark.parametrize("M, N", [(64, 128), (128, 256), (256, 512)])
+def test_compile_quant_per_tensor(M, N):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
+
+    qx_eager, scale_eager = torch_dynamic_per_tensor_quant_fp8(x)
+
+    compiled_fn = _get_compiled(torch_dynamic_per_tensor_quant_fp8)
+    qx_compiled, scale_compiled = compiled_fn(x)
+    torch.cuda.synchronize()
+
+    torch.testing.assert_close(scale_compiled, scale_eager, atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(
+        qx_compiled.to(torch.float32),
+        qx_eager.to(torch.float32),
+        atol=1.0,
+        rtol=1e-1,
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_quant_per_token.py b/op_tests/triton_tests/torch_compile/test_compile_quant_per_token.py
new file mode 100644
index 0000000000..fcd6b1ce96
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_quant_per_token.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
+
+
+def torch_dynamic_per_token_quant_fp8(x):
+    x_float = x.to(torch.float32)
+    amax = x_float.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
+    scale = amax / FP8_MAX
+    qx = (x_float / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fnuz)
+    return qx, scale
+
+
+@pytest.mark.parametrize("M, N", [(64, 128), (128, 256), (256, 512)])
+def test_compile_quant_per_token(M, N):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
+
+    qx_eager, scale_eager = torch_dynamic_per_token_quant_fp8(x)
+
+    compiled_fn = _get_compiled(torch_dynamic_per_token_quant_fp8)
+    qx_compiled, scale_compiled = compiled_fn(x)
+    torch.cuda.synchronize()
+
+    torch.testing.assert_close(scale_compiled, scale_eager, atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(
+        qx_compiled.to(torch.float32),
+        qx_eager.to(torch.float32),
+        atol=1.0,
+        rtol=1e-1,
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_rmsnorm.py b/op_tests/triton_tests/torch_compile/test_compile_rmsnorm.py
new file mode 100644
index 0000000000..e0fd6270a8
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_rmsnorm.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def torch_rmsnorm(x, weight, eps):
+    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
+    x_normed = x * torch.rsqrt(variance + eps)
+    return (x_normed * weight).to(x.dtype)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N", [(128, 256), (256, 512), (512, 1024)])
+def test_compile_rmsnorm(M, N, dtype):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    eps = 1e-6
+    x = torch.randn(M, N, device="cuda", dtype=dtype)
+    weight = torch.ones(N, device="cuda", dtype=dtype)
+
+    out_eager = torch_rmsnorm(x, weight, eps)
+
+    def fn(x, weight):
+        return torch_rmsnorm(x, weight, eps)
+
+    compiled_fn = _get_compiled(fn)
+    out_compiled = compiled_fn(x, weight)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(out_compiled, out_eager, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_rope.py b/op_tests/triton_tests/torch_compile/test_compile_rope.py
new file mode 100644
index 0000000000..7921f2a4fa
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_rope.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def generate_cos_sin(seq_len, dim, device, dtype):
+    freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, device=device).float() / dim))
+    t = torch.arange(seq_len, device=device).float()
+    freqs = torch.outer(t, freqs)
+    cos = freqs.cos().to(dtype)
+    sin = freqs.sin().to(dtype)
+    return cos, sin
+
+
+def torch_rope_neox(x, cos, sin):
+    dim = x.shape[-1]
+    half = dim // 2
+    x1 = x[..., :half]
+    x2 = x[..., half:]
+    cos_part = cos[:, :half].unsqueeze(1).unsqueeze(1)
+    sin_part = sin[:, :half].unsqueeze(1).unsqueeze(1)
+    out1 = x1 * cos_part - x2 * sin_part
+    out2 = x2 * cos_part + x1 * sin_part
+    return torch.cat([out1, out2], dim=-1)
+
+
+def torch_rope_gptj(x, cos, sin):
+    dim = x.shape[-1]
+    half = dim // 2
+    x1 = x[..., 0::2]
+    x2 = x[..., 1::2]
+    cos_part = cos[:, :half].unsqueeze(1).unsqueeze(1)
+    sin_part = sin[:, :half].unsqueeze(1).unsqueeze(1)
+    out1 = x1 * cos_part - x2 * sin_part
+    out2 = x1 * sin_part + x2 * cos_part
+    out = torch.stack([out1, out2], dim=-1).flatten(-2)
+    return out
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("rotate_style", ["neox", "gptj"])
+@pytest.mark.parametrize("S, B, H, D", [(32, 2, 8, 64), (64, 4, 16, 128)])
+def test_compile_rope(S, B, H, D, dtype, rotate_style):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    device = "cuda"
+    x = torch.randn(S, B, H, D, device=device, dtype=dtype)
+    cos, sin = generate_cos_sin(S, D, device, dtype)
+
+    rope_fn = torch_rope_neox if rotate_style == "neox" else torch_rope_gptj
+    out_eager = rope_fn(x, cos, sin)
+
+    def fn(x, cos, sin):
+        return rope_fn(x, cos, sin)
+
+    compiled_fn = _get_compiled(fn)
+    out_compiled = compiled_fn(x, cos, sin)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(out_compiled, out_eager, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_softmax.py b/op_tests/triton_tests/torch_compile/test_compile_softmax.py
new file mode 100644
index 0000000000..48f2fca588
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_softmax.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from . import _get_compiled
+
+
+def torch_softmax(x, dim=-1):
+    return F.softmax(x, dim=dim)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M, N", [(128, 64), (256, 512), (1024, 1024)])
+def test_compile_softmax(M, N, dtype):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=dtype)
+
+    out_eager = torch_softmax(x)
+
+    compiled_fn = _get_compiled(torch_softmax)
+    out_compiled = compiled_fn(x)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(out_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(out_compiled, out_eager, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/op_tests/triton_tests/torch_compile/test_compile_topk.py b/op_tests/triton_tests/torch_compile/test_compile_topk.py
new file mode 100644
index 0000000000..2a58d7e794
--- /dev/null
+++ b/op_tests/triton_tests/torch_compile/test_compile_topk.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+
+from . import _get_compiled
+
+
+def torch_topk(x, k):
+    return torch.topk(x, k, dim=-1)
+
+
+@pytest.mark.parametrize("k", [8, 32])
+@pytest.mark.parametrize("M, N", [(64, 256), (128, 512), (256, 1024)])
+def test_compile_topk(M, N, k):
+    torch.manual_seed(42)
+    torch.cuda.empty_cache()
+    torch._dynamo.reset()
+    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
+
+    values_eager, indices_eager = torch_topk(x, k)
+
+    def fn(x):
+        return torch_topk(x, k)
+
+    compiled_fn = _get_compiled(fn)
+    values_compiled, indices_compiled = compiled_fn(x)
+    torch.cuda.synchronize()
+
+    assert not torch.isnan(values_compiled).any(), "torch.compile produced NaN"
+    torch.testing.assert_close(values_compiled, values_eager, atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(indices_compiled, indices_eager)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/setup.py b/setup.py
index 1188c78ac1..a6e9234e7d 100644
--- a/setup.py
+++ b/setup.py
@@ -71,6 +71,14 @@ def is_develop_mode():
         ]
     )
 
+    try:
+        install_triton = os.path.join(
+            this_dir, ".github", "scripts", "install_triton.sh"
+        )
+        subprocess.check_call(["bash", install_triton])
+    except Exception:
+        pass
+
 
 def write_install_mode():
     """Write install_mode so core.py uses aiter_meta/ (install) vs repo root (develop).