Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions .github/workflows/flash_attention_integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ on:
branches: [main]
paths:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
- 'aiter/ops/triton/__init__.py'
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
- 'setup.py'
- 'aiter/__init__.py'
- '.github/workflows/flash_attention_integration.yaml'
pull_request:
branches: [main]
paths:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
- 'aiter/ops/triton/__init__.py'
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
- 'setup.py'
- 'aiter/__init__.py'
- '.github/workflows/flash_attention_integration.yaml'
workflow_dispatch:

Expand All @@ -26,9 +30,8 @@ concurrency:
cancel-in-progress: ${{ github.event_name != 'push' }}

env:
# TODO: Switch to Dao-AILab/flash-attention main
FA_BRANCH: micmelesse/aiter_migration
FA_REPOSITORY_URL: https://github.com/ROCm/flash-attention.git
FA_BRANCH: main
FA_REPOSITORY_URL: https://github.com/Dao-AILab/flash-attention.git
BASE_IMAGE: rocm/pytorch:latest@sha256:683765a52c61341e1674fe730ab3be861a444a45a36c0a8caae7653a08a0e208
AITER_SUBMODULE_PATH: third_party/aiter

Expand Down Expand Up @@ -62,9 +65,11 @@ jobs:
filters: |
shared:
- 'setup.py'
- 'aiter/__init__.py'
- '.github/workflows/flash_attention_integration.yaml'
triton:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
- 'aiter/ops/triton/__init__.py'
ck:
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
Expand Down Expand Up @@ -100,6 +105,8 @@ jobs:
label: MI355
- runner: aiter-gfx1100
label: RDNA3
optional: true
continue-on-error: ${{ matrix.optional || false }}

steps:
- name: Checkout aiter repo
Expand Down Expand Up @@ -210,6 +217,59 @@ jobs:
docker rm -f fa_triton_test || true
docker rmi fa_triton_test:ci || true

# =============================================================================
# Windows Smoke Test
# =============================================================================
windows_smoke:
Comment thread
micmelesse marked this conversation as resolved.
if: ${{ needs.prechecks.outputs.run_triton == 'true' }}
name: Windows Smoke Test
needs: [prechecks]
runs-on: windows-latest
steps:
- name: Checkout aiter repo
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install --upgrade pip setuptools wheel
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install triton-windows
shell: bash

- name: Install aiter
run: pip install -e . --no-build-isolation
shell: bash

- name: Install flash-attention (Triton backend)
run: |
git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} "$RUNNER_TEMP/flash-attention"
rm -rf "$RUNNER_TEMP/flash-attention/${{ env.AITER_SUBMODULE_PATH }}"
cp -a . "$RUNNER_TEMP/flash-attention/${{ env.AITER_SUBMODULE_PATH }}"
rm -rf "$RUNNER_TEMP/flash-attention/.git"
cd "$RUNNER_TEMP/flash-attention"
BUILD_TARGET=rocm FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE pip install --no-build-isolation --no-deps -e .
shell: bash

- name: Verify FA imports aiter Triton backend
run: |
cd "$RUNNER_TEMP/flash-attention"
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python -c "from flash_attn.flash_attn_interface import flash_attn_func; print('flash_attn_func loaded:', flash_attn_func)"

# Simulate partial ROCm install (fake hipconfig + ROCM_HOME) to catch OSError from cpp_extension.py
mkdir -p "$RUNNER_TEMP/fake_rocm/bin"
printf '#!/bin/bash\necho 6.4.0\n' > "$RUNNER_TEMP/fake_rocm/bin/hipconfig"
chmod +x "$RUNNER_TEMP/fake_rocm/bin/hipconfig"
ROCM_HOME="$RUNNER_TEMP/fake_rocm" PATH="$RUNNER_TEMP/fake_rocm/bin:$PATH" FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
python -c "from flash_attn.flash_attn_interface import flash_attn_func; print('flash_attn_func loaded (fake ROCm):', flash_attn_func)"
shell: bash

# =============================================================================
# CK Backend
# =============================================================================
Expand Down
92 changes: 52 additions & 40 deletions aiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import os
import sys
import logging

logger = logging.getLogger("aiter")
Expand Down Expand Up @@ -57,44 +58,55 @@ def getLogger():

logger = getLogger()

from .jit import core as core # noqa: E402
from .utility import dtypes as dtypes # noqa: E402
from .ops.enum import * # noqa: F403,E402
from .ops.norm import * # noqa: F403,E402
from .ops.quant import * # noqa: F403,E402
from .ops.gemm_op_a8w8 import * # noqa: F403,E402
from .ops.gemm_op_a16w16 import * # noqa: F403,E402
from .ops.gemm_op_a4w4 import * # noqa: F403,E402
from .ops.batched_gemm_op_a8w8 import * # noqa: F403,E402
from .ops.batched_gemm_op_bf16 import * # noqa: F403,E402
from .ops.deepgemm import * # noqa: F403,E402
from .ops.aiter_operator import * # noqa: F403,E402
from .ops.activation import * # noqa: F403,E402
from .ops.attention import * # noqa: F403,E402
from .ops.custom import * # noqa: F403,E402
from .ops.custom_all_reduce import * # noqa: F403,E402
from .ops.quick_all_reduce import * # noqa: F403,E402
from .ops.moe_op import * # noqa: F403,E402
from .ops.moe_sorting import * # noqa: F403,E402
from .ops.moe_sorting_opus import * # noqa: F403,E402
from .ops.pos_encoding import * # noqa: F403,E402
from .ops.cache import * # noqa: F403,E402
from .ops.rmsnorm import * # noqa: F403,E402
from .ops.communication import * # noqa: F403,E402
from .ops.rope import * # noqa: F403,E402
from .ops.topk import * # noqa: F403,E402
from .ops.topk_plain import topk_plain # noqa: F403,F401,E402
from .ops.mha import * # noqa: F403,E402
from .ops.gradlib import * # noqa: F403,E402
from .ops.trans_ragged_layout import * # noqa: F403,E402
from .ops.sample import * # noqa: F403,E402
from .ops.fused_qk_norm_mrope_cache_quant import * # noqa: F403,E402
from .ops.fused_qk_norm_rope_cache_quant import * # noqa: F403,E402
from .ops.groupnorm import * # noqa: F403,E402
from .ops.mhc import * # noqa: F403,E402
from .ops.causal_conv1d import * # noqa: F403,E402
from .ops.fused_split_gdr_update import * # noqa: F403,E402
from . import mla # noqa: F403,F401,E402

if sys.platform == "win32":
logger.info("Windows: CK and HIP ops are not available. Triton ops only.")
else:
try:
from .jit import core as core # noqa: E402
from .utility import dtypes as dtypes # noqa: E402
from .ops.enum import * # noqa: F403,E402
from .ops.norm import * # noqa: F403,E402
from .ops.quant import * # noqa: F403,E402
from .ops.gemm_op_a8w8 import * # noqa: F403,E402
from .ops.gemm_op_a16w16 import * # noqa: F403,E402
from .ops.gemm_op_a4w4 import * # noqa: F403,E402
from .ops.batched_gemm_op_a8w8 import * # noqa: F403,E402
from .ops.batched_gemm_op_bf16 import * # noqa: F403,E402
from .ops.deepgemm import * # noqa: F403,E402
from .ops.aiter_operator import * # noqa: F403,E402
from .ops.activation import * # noqa: F403,E402
from .ops.attention import * # noqa: F403,E402
from .ops.custom import * # noqa: F403,E402
from .ops.custom_all_reduce import * # noqa: F403,E402
from .ops.quick_all_reduce import * # noqa: F403,E402
from .ops.moe_op import * # noqa: F403,E402
from .ops.moe_sorting import * # noqa: F403,E402
from .ops.moe_sorting_opus import * # noqa: F403,E402
from .ops.pos_encoding import * # noqa: F403,E402
from .ops.cache import * # noqa: F403,E402
from .ops.rmsnorm import * # noqa: F403,E402
from .ops.communication import * # noqa: F403,E402
from .ops.rope import * # noqa: F403,E402
from .ops.topk import * # noqa: F403,E402
from .ops.topk_plain import topk_plain # noqa: F403,F401,E402
from .ops.mha import * # noqa: F403,E402
from .ops.gradlib import * # noqa: F403,E402
from .ops.trans_ragged_layout import * # noqa: F403,E402
from .ops.sample import * # noqa: F403,E402
from .ops.fused_qk_norm_mrope_cache_quant import * # noqa: F403,E402
from .ops.fused_qk_norm_rope_cache_quant import * # noqa: F403,E402
from .ops.groupnorm import * # noqa: F403,E402
from .ops.mhc import * # noqa: F403,E402
from .ops.causal_conv1d import * # noqa: F403,E402
from .ops.fused_split_gdr_update import * # noqa: F403,E402
from . import mla # noqa: F403,F401,E402
except (ImportError, RuntimeError, OSError) as e:
logger.warning(
"ROCm/HIP JIT runtime not available: %s. "
"CK and HIP ops are disabled. Triton ops remain available.",
e,
)

# Import Triton-based communication primitives from ops.triton.comms (optional, only if Iris is available)
try:
Expand All @@ -106,6 +118,6 @@ def getLogger():
reduce_scatter_rmsnorm_quant_all_gather, # noqa: F401
IRIS_COMM_AVAILABLE, # noqa: F401
)
except ImportError:
# Iris not available, skip import
except (ImportError, AttributeError):
# Iris or triton not available, skip import
IRIS_COMM_AVAILABLE = False
4 changes: 0 additions & 4 deletions aiter/jit/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ def _join_rocm_home(*paths) -> str:
"ROCM_HOME environment variable is not set. "
"Please set it to your ROCm install root."
)
elif IS_WINDOWS:
raise OSError(
"Building PyTorch extensions using " "ROCm and Windows is not supported."
)
return os.path.join(ROCM_HOME, *paths)


Expand Down
10 changes: 8 additions & 2 deletions aiter/ops/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import sys
from types import SimpleNamespace

from . import quant
# Try to import quant module
try:
from . import quant
except (ImportError, AttributeError):
quant = None

# Try to import comms module (requires iris)
try:
Expand All @@ -27,7 +31,9 @@
IRIS_COMM_AVAILABLE = False
comms = None

__all__ = ["quant"]
__all__ = []
if quant is not None:
__all__.append("quant")

if _COMMS_AVAILABLE:
__all__.extend(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,11 @@ def is_hip() -> bool:
@functools.cache
def get_arch() -> GpuArch:
"""Get the current GPU architecture."""
name: str = triton.runtime.driver.active.get_current_target().arch
try:
name: str = triton.runtime.driver.active.get_current_target().arch
except RuntimeError:
# No GPU available (e.g. import-only on Windows/CPU)
return GpuArch(name="unknown")
if name in CDNA_ARCHS:
return GpuArch(name=name, family="cdna")
elif name in RDNA_ARCHS:
Expand Down
34 changes: 23 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
PREBUILD_KERNELS = int(os.environ.get("PREBUILD_KERNELS", 0))
ENABLE_CK = int(os.environ.get("ENABLE_CK", "1"))
IS_WINDOWS = sys.platform == "win32"
if IS_WINDOWS:
ENABLE_CK = False
PREBUILD_KERNELS = False


def getMaxJobs():
Expand Down Expand Up @@ -63,7 +67,10 @@ def prepare_packaging():
shutil.copytree("3rdparty", "aiter_meta/3rdparty")
else:
os.makedirs("aiter_meta/3rdparty", exist_ok=True)
shutil.copytree("hsa", "aiter_meta/hsa")
if not IS_WINDOWS:
shutil.copytree("hsa", "aiter_meta/hsa")
else:
os.makedirs("aiter_meta/hsa", exist_ok=True)
shutil.copytree("gradlib", "aiter_meta/gradlib")
shutil.copytree("csrc", "aiter_meta/csrc")
open("aiter_meta/__init__.py", "w").close()
Expand Down Expand Up @@ -98,7 +105,7 @@ def _is_metadata_only():


# Defer heavy imports until build time
if not _is_metadata_only():
if not _is_metadata_only() and not IS_WINDOWS:
import json
from concurrent.futures import ThreadPoolExecutor

Expand Down Expand Up @@ -306,6 +313,19 @@ def has_ext_modules(self):
return True


if IS_WINDOWS:
install_requires = ["einops", "packaging", "psutil"]
else:
install_requires = [
"pybind11>=3.0.1",
"ninja",
"pandas",
"einops",
"psutil",
"packaging",
"flydsl==0.1.1.dev409",
]

setup(
name=PACKAGE_NAME,
use_scm_version=True,
Expand All @@ -321,15 +341,7 @@ def has_ext_modules(self):
],
cmdclass={"build_ext": NinjaBuildExtension},
python_requires=">=3.8",
install_requires=[
"pybind11>=3.0.1",
"ninja",
"pandas",
"einops",
"psutil",
"packaging",
"flydsl==0.1.1.dev409",
],
install_requires=install_requires,
extras_require={
# Triton-based communication using Iris
# Note: Iris is not available on PyPI and must be installed separately
Expand Down
Loading