84 changes: 48 additions & 36 deletions README.md
@@ -1,62 +1,74 @@
# aiter
![image](https://github.com/user-attachments/assets/9457804f-77cd-44b0-a088-992e4b9971c6)



AITER is AMD's centralized repository of high-performance AI operators for accelerating AI workloads. It provides a unified home for customer operator-level requests: developers can focus on the operators, while customers integrate this op collection into their own private or public frameworks.

A summary of the features:
* C++ level API
* Python level API
* The underlying kernels can come from Triton, CK, or assembly (asm)
* Not only inference kernels, but also training kernels and GEMM+communication kernels, allowing workarounds for any kernel/framework combination and any architecture limitation



## Installation

```bash
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
```

If you forgot `--recursive` during clone:
```bash
git submodule sync && git submodule update --init --recursive
```

### Development Mode (JIT)

Kernels are compiled on first use — fastest to get started:
```bash
python3 setup.py develop
```

### Install with Precompiled Kernels

Precompile kernels at install time so there is no JIT overhead at runtime:
```bash
PREBUILD_KERNELS=2 GPU_ARCHS="gfx942" python3 setup.py install
```

| Variable | Description |
|---|---|
| `GPU_ARCHS` | Target GPU architecture(s), semicolon-separated. Use `"native"` to auto-detect. Common values: `gfx942` (MI300X), `gfx950` (MI350X), `gfx90a` (MI250X). Multi-target example: `"gfx942;gfx950"` |
| `PREBUILD_KERNELS` | `0` — no precompilation (JIT only, default). `1` — precompile core kernels (excludes tuning and most MHA variants). `2` — precompile inference kernels (excludes backward and tuning). `3` — precompile MHA kernels only (minimal build). |
| `MAX_JOBS` | Max parallel compilation threads (auto-calculated from CPU cores and memory if not set) |
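
For example, to precompile the inference kernels for both MI300X and MI350X with a capped job count (a sketch built from the variables in the table above; `MAX_JOBS=32` is only illustrative):
```bash
PREBUILD_KERNELS=2 GPU_ARCHS="gfx942;gfx950" MAX_JOBS=32 python3 setup.py install
```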

### Triton Communication Support (Optional)

For multi-GPU communication primitives (reduce-scatter, all-gather) using the [Iris](https://github.com/ROCm/iris) library:
```bash
# Install AITER with Triton communication dependencies
pip install -e .
pip install -r requirements-triton-comms.txt
```


See the [Triton Comms Guide](docs/triton_comms.md) for usage details.

## Supported Operators

| **Operator** | **Description** | **Guide** |
|---|---|---|
| Attention (MHA, PA) | Multi-Head Attention, Paged Attention (decode & prefill), Unified Attention, chunked prefill, GQA/MQA | [Attention Guide](docs/attention_variants_guide.md) |
| MLA | Multi-head Latent Attention — standard decode, persistent decode, prefill, sparse MLA, fused ops | [MLA Guide](docs/mla_kernel_support_report.md) |
| Fused MOE | Mixture of Experts — A8W8, A16W8, FP8 block-scale, MXFP4, 2-stage MOE, topK routing | [MOE Guide](docs/moe_variants_guide.md) |
| GEMM | Matrix multiply (A8W8, A16W16, A4W4, batched), DeepGEMM, Triton FFN fusions, CSV-based tuning | [GEMM Guide](docs/gemm_variants_guide.md) |
| Quantization | BF16/FP16 to FP8/MXFP4/INT4, per-tensor/token/block strategies, fused quant ops, SmoothQuant | [Quantization Guide](docs/quantization_guide.md) |
| Normalization (RMSNorm, LayerNorm) | RMSNorm, LayerNorm, GroupNorm — fused add/quant variants, SmoothQuant, distributed fusion | [Normalization Guide](docs/normalization_guide.md) |
| RoPE | Rotary Position Embedding — SBHD/THD/2D/3D formats, NeoX & GPT-J styles, scaling methods | [RoPE Guide](docs/rope_guide.md) |
| KV-Cache | Paged/flash/MLA cache layouts, quantized cache (FP8/INT8), fused RoPE + cache write | [KV-Cache Guide](docs/kv_cache_guide.md) |
| Elementwise & Activations | SiLU/GELU/sigmoid/tanh, SwiGLU gates, fused activation + quantize, binary arithmetic (+−×÷) | [Elementwise Guide](docs/elementwise_activation_guide.md) |
| Communication (AllReduce) | GPU-initiated reduce-scatter and all-gather via [Iris](https://github.com/ROCm/iris) | [Triton Comms](docs/triton_comms.md) |

Each guide covers available variants, backend support (ASM / CK / Triton), Python API examples, and performance tuning advice.

Run operator tests with: `python3 op_tests/<test_file>.py` (e.g. `python3 op_tests/test_pa.py`)
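
For a quick smoke test of the Python API, here is a minimal sketch using `silu_and_mul` (assumes a ROCm GPU; shapes follow the [Elementwise Guide](docs/elementwise_activation_guide.md)):
```python
import torch
import aiter

x = torch.randn(8, 256, dtype=torch.float16, device="cuda")    # [M, 2*N]
out = torch.empty(8, 128, dtype=torch.float16, device="cuda")  # [M, N]
aiter.silu_and_mul(out, x)  # SwiGLU gate: silu(gate) * value
print(out.shape)            # torch.Size([8, 128])
```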

## Additional Resources
- [Autotuning Pipeline](docs/autotuning_pipeline.md) — CSV-based kernel selection and tuning workflow
- [Container Setup (Non-root)](docs/aiter_container_nonroot_setup.md) — Running AITER in Docker without root
277 changes: 277 additions & 0 deletions docs/elementwise_activation_guide.md
@@ -0,0 +1,277 @@
# AITER Elementwise & Activation Operators Guide

This guide documents all element-wise arithmetic and activation operators available in AITER, including fused gating variants and activation-quantization fusions.

---

## Quick Reference

| Use Case | Recommended Operation | Backend | Why |
|----------|---------------------|---------|-----|
| **SwiGLU gate (LLM FFN)** | `silu_and_mul` | HIP/CK | Standard gated activation for LLaMA/Mistral |
| **SwiGLU + FP8 quantize** | `act_mul_and_fp8_group_quant` | Triton | Fused activation + quantize for inference |
| **SwiGLU + MXFP4 quantize** | `act_mul_and_mxfp4_quant` | Triton | Fused activation + MXFP4 for GFX950 |
| **GELU gate** | `gelu_and_mul` | HIP/CK | GPT-2/BERT-style gated activation |
| **Scaled SiLU (quantized input)** | `scaled_silu_and_mul` | HIP/CK | SiLU with input scale for quantized models |
| **Element-wise arithmetic** | `aiter.add/sub/mul/div` | HIP/CK | Optimized binary ops with broadcasting |
| **Sigmoid / Tanh** | `aiter.sigmoid/tanh` | HIP/CK | AMD-optimized fast math intrinsics |
| **Fused multiply-add** | `fused_mul_add` | Triton | `a * x + b` in one kernel |

---

## 1. Gated Activation Functions (SwiGLU / GeGLU)

The primary activation pattern in modern LLMs. The input tensor has shape `[M, 2*N]` and is split in half: one half is the gate (activation applied), the other is the value. The output is `activation(gate) * value` with shape `[M, N]`.
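
As a point of reference, the split-and-gate semantics can be written in plain PyTorch (a sketch for checking outputs, not the optimized kernel; it assumes the common convention that the first half is the gate):
```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference SwiGLU gate: [M, 2*N] -> [M, N]."""
    gate, value = x.chunk(2, dim=-1)  # split the last dimension in half
    return F.silu(gate) * value       # activation(gate) * value
```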

### Backend Support

| Activation | HIP/CK | Triton | Fused + FP8 Quant | Fused + MXFP4 Quant |
|-----------|:---:|:---:|:---:|:---:|
| **SiLU (SwiGLU)** | Yes | Yes | Yes | Yes |
| **GELU** | Yes | Yes | Yes | Yes |
| **GELU Tanh** | Yes | Yes | Yes | Yes |
| **Scaled SiLU** | Yes | - | - | - |

### Key API Functions

```python
import torch
import aiter

# SiLU-and-Mul (SwiGLU gate) — most common for LLaMA/Mistral/DeepSeek
out = torch.empty(M, N, dtype=dtype, device="cuda")
aiter.silu_and_mul(out, input) # input shape: [M, 2*N]

# Scaled SiLU-and-Mul (for quantized inference with input scale)
aiter.scaled_silu_and_mul(out, input, scale)

# GELU-and-Mul (GeGLU gate)
aiter.gelu_and_mul(out, input)

# GELU-Tanh-and-Mul (approximate GELU gate)
aiter.gelu_tanh_and_mul(out, input)
```

### Fused Activation + Quantization (Triton)

These fuse the gated activation with quantization in a single kernel, avoiding an extra memory round-trip:

```python
from aiter.ops.triton.activation import (
    act_mul_and_fp8_group_quant,  # Activation + FP8 group quantize
    act_mul_and_mxfp4_quant,      # Activation + MXFP4 block-scale quantize
)

# SiLU gate + FP8 group quantization
out, scales = act_mul_and_fp8_group_quant(
    input,              # [M, 2*N]
    activation="silu",  # "silu", "gelu", or "gelu_tanh"
    group_size=128,
    dtype_quant=torch.float8_e4m3fnuz,
)

# SiLU gate + MXFP4 block-scale quantization
out, scales = act_mul_and_mxfp4_quant(
    input,
    activation="silu",
    scaling_mode="even",  # Scale computation mode
    shuffle=True,         # Shuffle output layout
)
```

---

## 2. Unary Activations

### Sigmoid

Uses AMD fast math intrinsics (`__builtin_amdgcn_exp2f`, `__builtin_amdgcn_rcpf`) for optimized computation.

```python
import aiter

output = aiter.sigmoid(input)
```

### Tanh

```python
output = aiter.tanh(input)
```

### Supported Data Types

| Data Type | Sigmoid | Tanh |
|-----------|:---:|:---:|
| FP16 | Yes | Yes |
| BF16 | Yes | Yes |
| FP32 | Yes | Yes |

---

## 3. Element-wise Binary Arithmetic

Optimized binary operations with full broadcasting support. Uses JIT-compiled kernels that are specialized for the input dtype combination.

### Operations

```python
import aiter

# Out-of-place (return new tensor)
c = aiter.add(a, b) # c = a + b
c = aiter.sub(a, b) # c = a - b
c = aiter.mul(a, b) # c = a * b
c = aiter.div(a, b) # c = a / b

# In-place (modify first tensor)
aiter.add_(a, b) # a += b
aiter.sub_(a, b) # a -= b
aiter.mul_(a, b) # a *= b
aiter.div_(a, b) # a /= b
```

### Broadcasting

All binary ops support NumPy-style broadcasting:

```python
# Scalar broadcast
c = aiter.add(tensor_2d, scalar_tensor)

# Dimension broadcast
a = torch.randn(4, 1, device="cuda")
b = torch.randn(1, 8, device="cuda")
c = aiter.mul(a, b) # → shape [4, 8]
```

### Auto Type Promotion

Output dtype is automatically promoted via `torch.promote_types()`:

```python
a = torch.randn(4, 4, dtype=torch.float16, device="cuda")
b = torch.randn(4, 4, dtype=torch.float32, device="cuda")
c = aiter.add(a, b) # c.dtype == torch.float32
```

---

## 4. Fused Multiply-Add

Single-kernel element-wise `out = a * x + b` where `a` and `b` can be scalars or tensors:

```python
from aiter.ops.triton.fusions.fused_mul_add import fused_mul_add

# Tensor * tensor + tensor
out = torch.empty_like(x)
fused_mul_add(x, a_tensor, b_tensor, out)

# Scalar * tensor + scalar
fused_mul_add(x, 2.0, 1.0, out) # out = 2*x + 1
```

---

## 5. GEMM-Fused Activations

Activations can be fused directly into GEMM (matrix multiply) operations, eliminating the need for a separate activation kernel:

### GEMM + Activation

```python
from aiter.ops.triton.gemm.basic.gemm_a16w16 import gemm_a16w16

# Matrix multiply with post-GEMM activation
y = gemm_a16w16(x, weight, bias=None, dtype=torch.bfloat16,
                activation="silu")  # Applied after matmul
```

### GEMM + Gated Activation (SwiGLU FFN)

```python
from aiter.ops.triton.gemm.basic.gemm_a16w16 import gemm_a16w16_gated

# Gated matmul: splits output, applies activation to gate half
y = gemm_a16w16_gated(x, weight, dtype=torch.bfloat16,
                      activation="silu")
```

### Feed-Forward Blocks

Complete FFN blocks with gating built in:

```python
from aiter.ops.triton.gemm.feed_forward.ff_a16w16 import (
    ff_a16w16_gated,   # SwiGLU: x → up_proj → gate*value → down_proj
    ff_a16w16_nogate,  # Standard: x → up_proj → activation → down_proj
)
```

---

## 6. Triton Activation Kernels

Available activation functions in Triton kernels (used internally by fused operations):

| Function | Formula | Usage |
|----------|---------|-------|
| `_silu(x)` | `x * sigmoid(x)` | SwiGLU gates |
| `_gelu(x)` | `0.5 * x * (1 + erf(x / sqrt(2)))` | Standard GELU |
| `_gelu_tanh(x)` | `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` | Approximate GELU |
| `_tanh(x)` | `2 * sigmoid(2x) - 1` | Tanh |
| `_relu(x)` | `max(0, x)` | ReLU |
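
For validation, the formulas above map directly to eager PyTorch (a reference sketch, not the Triton kernels themselves):
```python
import math
import torch

def silu_ref(x):       # x * sigmoid(x)
    return x * torch.sigmoid(x)

def gelu_ref(x):       # 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh_ref(x):  # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
```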

---

## 7. Decision Tree

```
Need activation?
├── Gated FFN (SwiGLU/GeGLU)?
│ ├── Standard inference → aiter.silu_and_mul()
│ ├── With input scaling → aiter.scaled_silu_and_mul()
│ ├── Need FP8 output → act_mul_and_fp8_group_quant()
│ ├── Need MXFP4 output → act_mul_and_mxfp4_quant()
│ └── Fused with GEMM → gemm_a16w16_gated() or ff_a16w16_gated()
├── Standalone activation?
│ ├── Sigmoid → aiter.sigmoid()
│ └── Tanh → aiter.tanh()
├── Element-wise arithmetic?
│ ├── Standard ops → aiter.add/sub/mul/div()
│ ├── In-place → aiter.add_/sub_/mul_/div_()
│ └── Fused a*x+b → fused_mul_add()
└── GELU variant?
├── Standard → aiter.gelu_and_mul()
└── Tanh approx → aiter.gelu_tanh_and_mul()
```

---

## 8. Source Files

| Component | Path |
|-----------|------|
| Gated activation API | `aiter/ops/activation.py` |
| Binary/unary ops API | `aiter/ops/aiter_operator.py` |
| Triton activation wrappers | `aiter/ops/triton/activation.py` |
| Triton activation kernels | `aiter/ops/triton/_triton_kernels/activation.py` |
| Triton fused mul-add | `aiter/ops/triton/fusions/fused_mul_add.py` |
| Triton GEMM + activation | `aiter/ops/triton/gemm/basic/gemm_a16w16.py` |
| Triton gated FFN | `aiter/ops/triton/gemm/feed_forward/ff_a16w16.py` |
| HIP activation kernels | `csrc/kernels/activation_kernels.cu` |
| HIP unary operators | `csrc/kernels/unary_operator.cu` |
| HIP binary operators | `csrc/kernels/binary_operator.cu` |

---

## 9. Test Files

| Test | Path |
|------|------|
| Activation (SiLU, scaled) | `op_tests/test_activation.py` |
| Triton activation + quant | `op_tests/triton_tests/test_activation.py` |
| Add | `op_tests/test_aiter_add.py` |
| Add in-place | `op_tests/test_aiter_addInp.py` |
| Sigmoid | `op_tests/test_aiter_sigmoid.py` |
| Fused mul-add | `op_tests/triton_tests/fusions/test_fused_mul_add.py` |