5 changes: 5 additions & 0 deletions README.md
@@ -66,6 +66,9 @@ See the [Triton Comms Guide](docs/triton_comms.md) for usage details.
| Sampling | Greedy, random, mixed, top-k, top-p token sampling for LLM generation | [Sampling Guide](docs/sampling_guide.md) |
| Top-K | Top-k selection — MOE routing (grouped, biased), radix/bitonic sort, fused softmax+topk | [Top-K Guide](docs/topk_guide.md) |
| Communication (AllReduce) | Custom all-reduce, quick all-reduce, Iris reduce-scatter/all-gather | [Distributed Guide](docs/distributed_guide.md) |
| Causal Conv1D | Causal convolution for Mamba/SSM models — prefill, decode, fused QKV split, speculative decoding | [Causal Conv1D Guide](docs/causal_conv1d_guide.md) |
| Gated Delta Net | Gated delta rule recurrence — fused recurrent, chunk-based, sigmoid gating, GVA support | [GDN Guide](docs/gated_delta_net_guide.md) |
| Grouped GEMM | GMM (Triton) and DeepGEMM (CK) — MoE expert routing, variable-length grouped GEMM | [Grouped GEMM Guide](docs/grouped_gemm_guide.md) |

Each guide covers available variants, backend support (ASM / CK / Triton), Python API examples, and performance tuning advice.

@@ -78,6 +81,8 @@ Run operator tests with: `python3 op_tests/<test_file>.py` (e.g. `python3 op_tes
| [JIT Compilation System](docs/jit_system_guide.md) | `@compile_ops` decorator, module config, build flow, cache, GPU detection |
| [GEMM Tuning & Gradlib](docs/gemm_tuning_guide.md) | CSV-based kernel dispatch, hipBLASLt/ASM tuning, gradlib framework |
| [Distributed Infrastructure](docs/distributed_guide.md) | Tensor parallelism, custom/quick all-reduce, Iris comms, shared memory broadcast |
| [Weight Shuffle & Preshuffle](docs/weight_shuffle_guide.md) | Weight layout transforms for CK/ASM/Triton GEMM, FP8/FP4 preshuffle |
| [BERT Padding & Variable-Length](docs/bert_padding_guide.md) | Pad/unpad utilities, variable-length attention, cumulative sequence lengths |

## Additional Resources
- [Triton-based Communication (Iris)](docs/triton_comms.md) — GPU-initiated reduce-scatter and all-gather
186 changes: 186 additions & 0 deletions docs/bert_padding_guide.md
@@ -0,0 +1,186 @@
# AITER BERT Padding & Variable-Length Sequence Guide

This guide documents the padding/unpadding utilities and variable-length sequence handling in AITER, enabling efficient attention computation on batches with different sequence lengths.

---

## Quick Reference

| Use Case | Function | Description |
|----------|---------|-------------|
| **Remove padding** | `unpad_input(hidden_states, attention_mask)` | Batch → packed sequences |
| **Restore padding** | `pad_input(hidden_states, indices, batch, seqlen)` | Packed → batch format |
| **Concatenated sequences** | `unpad_input_for_concatenated_sequences(...)` | SFT-style concatenated samples |
| **Variable-length attention** | `flash_attn_varlen_func(...)` | Flash attention on packed inputs |

---

## 1. Padding / Unpadding Utilities

### `unpad_input`

Removes padding tokens from batch-formatted tensors to create packed sequences:

```python
from aiter.bert_padding import unpad_input

hidden_states_unpad, indices, cu_seqlens, max_seqlen, seqused = unpad_input(
    hidden_states,      # (batch, seqlen, ...) — padded input
    attention_mask,     # (batch, seqlen) — 1=valid, 0=padding
    unused_mask=None,   # (batch, seqlen) — 1=allocated but unused (optional)
)
# hidden_states_unpad: (total_nnz, ...) — packed valid tokens
# indices: (total_nnz,) — indices into flattened input
# cu_seqlens: (batch+1,) — cumulative sequence lengths
# max_seqlen: int — longest sequence
# seqused: (batch,) — tokens selected per batch element
```

### `pad_input`

Restores packed sequences to padded batch format (inverse of `unpad_input`):

```python
from aiter.bert_padding import pad_input

hidden_states = pad_input(
    hidden_states_unpad,  # (total_nnz, ...) — packed tokens
    indices,              # (total_nnz,) — from unpad_input
    batch,                # int — batch size
    seqlen,               # int — max sequence length
)
# hidden_states: (batch, seqlen, ...) — zero-padded output
```
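
The round-trip semantics of `unpad_input`/`pad_input` can be sketched in a few lines of pure Python (plain nested lists standing in for tensors; `unpad`/`pad` are hypothetical names, not the AITER API):

```python
def unpad(batch_rows, mask):
    """Pack valid rows; return packed values, flat indices, and cu_seqlens."""
    seqlen = len(mask[0])
    packed, indices, cu_seqlens = [], [], [0]
    for b, row_mask in enumerate(mask):
        count = 0
        for t, valid in enumerate(row_mask):
            if valid:
                packed.append(batch_rows[b][t])
                indices.append(b * seqlen + t)  # index into flattened input
                count += 1
        cu_seqlens.append(cu_seqlens[-1] + count)
    return packed, indices, cu_seqlens

def pad(packed, indices, batch, seqlen):
    """Inverse of unpad: scatter packed values back into a zero-padded batch."""
    out = [[0] * seqlen for _ in range(batch)]
    for value, flat in zip(packed, indices):
        out[flat // seqlen][flat % seqlen] = value
    return out
```

`pad` reproduces the original batch wherever the mask is 1 and zeros elsewhere, which is exactly the contract of the tensor versions.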

### `unpad_input_for_concatenated_sequences`

For supervised fine-tuning where multiple short samples are concatenated:

```python
from aiter.bert_padding import unpad_input_for_concatenated_sequences

hidden_states_unpad, indices, cu_seqlens, max_seqlen = unpad_input_for_concatenated_sequences(
    hidden_states,             # (batch, seqlen, ...)
    attention_mask_in_length,  # (batch, seqlen) — nonzero entries give the lengths of the concatenated sequences
)
```

---

## 2. Variable-Length Flash Attention

After unpadding, use variable-length flash attention:

```python
import aiter

out_unpad = aiter.flash_attn_varlen_func(
    q_unpad,        # (total_q, nheads, headdim)
    k_unpad,        # (total_k, nheads_k, headdim)
    v_unpad,        # (total_k, nheads_k, headdim_v)
    cu_seqlens_q,   # (batch+1,) int32
    cu_seqlens_k,   # (batch+1,) int32
    max_seqlen_q,   # int
    max_seqlen_k,   # int
    causal=True,
    softmax_scale=None,
    # Optional physical padding support:
    cu_seqlens_q_padded=None,
    cu_seqlens_k_padded=None,
)
```

**Backends:** CK (primary), FMHA v3 (bf16, hdim 128/192), Triton
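
`cu_seqlens` delimits each sequence inside the packed token dimension: sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i+1]` of the packed buffer. A minimal pure-Python sketch of that bookkeeping (hypothetical helper, plain lists in place of tensors):

```python
def split_packed(packed, cu_seqlens):
    """Recover per-sequence chunks from a packed (total_tokens, ...) buffer."""
    return [packed[cu_seqlens[i]:cu_seqlens[i + 1]]
            for i in range(len(cu_seqlens) - 1)]
```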

---

## 3. End-to-End Example

```python
import torch
import aiter
from aiter.bert_padding import unpad_input, pad_input

batch_size, seqlen_q, seqlen_k = 4, 512, 512
nheads, d = 32, 128

# Padded inputs
q = torch.randn(batch_size, seqlen_q, nheads, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen_k, nheads, d, device="cuda", dtype=torch.bfloat16)

# Padding masks (actual sequence lengths vary per batch element)
query_padding_mask = torch.ones(batch_size, seqlen_q, device="cuda", dtype=torch.bool)
key_padding_mask = torch.ones(batch_size, seqlen_k, device="cuda", dtype=torch.bool)
query_padding_mask[0, 300:] = False # First sequence is 300 tokens
key_padding_mask[0, 300:] = False

# Unpad
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
k_unpad, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)

# Run attention on packed sequences (no wasted computation on padding)
out_unpad = aiter.flash_attn_varlen_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
)

# Restore to padded format
out = pad_input(out_unpad, indices_q, batch_size, seqlen_q)
```

---

## 4. Implementation Details

### Custom Autograd Functions

The module uses optimized `torch.gather`/`torch.scatter` instead of boolean indexing:

| Function | Forward | Backward |
|----------|---------|----------|
| `IndexFirstAxis` | `torch.gather` | `torch.scatter` |
| `IndexPutFirstAxis` | `torch.scatter` | `torch.gather` |
| `IndexFirstAxisResidual` | Index + residual pass-through | scatter-add |
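
The gather/scatter pair avoids re-evaluating a boolean mask on every call: indices are computed once, then the forward gathers and the backward scatters. In pure Python (hypothetical names, lists standing in for tensors):

```python
def index_first_axis(rows, indices):
    """Forward of the gather: select rows along the first axis."""
    return [rows[i] for i in indices]

def index_put_first_axis(values, indices, total):
    """Inverse operation: scatter values into a zero-initialized buffer."""
    out = [0] * total
    for v, i in zip(values, indices):
        out[i] = v
    return out
```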

### Tensor Formats

| Format | Shape | Description |
|--------|-------|-------------|
| **Padded** | `(batch, seqlen, nheads, hdim)` | Standard batch format |
| **Unpadded** | `(total_tokens, nheads, hdim)` | Packed valid tokens |
| **cu_seqlens** | `(batch+1,)` int32 | Cumulative lengths, starts at 0 |

---

## 5. Supported Data Types

| dtype | Notes |
|-------|-------|
| float16 | Fully supported |
| bfloat16 | Preferred for FMHA v3 |
| float8_e4m3fn/fnuz | With descaling parameters |

---

## 6. Source Files

| Component | Path |
|---|---|
| Padding utilities | `aiter/bert_padding.py` |
| Variable-length attention API | `aiter/ops/mha.py` |
| Test utilities (generate_qkv) | `aiter/test_mha_common.py` |
| CK attention kernels | `csrc/include/torch/mha_varlen_fwd.h` |

---

## 7. Test Files

| Test | Path |
|------|------|
| Variable-length MHA | `op_tests/test_mha_varlen.py` |
| Variable-length MHA (FP8) | `op_tests/test_mha_varlen_fp8.py` |
182 changes: 182 additions & 0 deletions docs/causal_conv1d_guide.md
@@ -0,0 +1,182 @@
# AITER Causal Conv1D Operators Guide

This guide documents the Causal Conv1D operators in AITER, used in Mamba-style state-space models and gated architectures for sequence modeling with causal (left-only) convolution.

---

## Quick Reference

| Use Case | Recommended Operation | Backend | Why |
|----------|---------------------|---------|-----|
| **Prefill (variable-length)** | `causal_conv1d_fn` | Triton | Continuous batching, ragged sequences |
| **Decode (single/multi-token)** | `causal_conv1d_update` | Triton | State caching, speculative decoding |
| **Fused conv + QKV split (prefill)** | `causal_conv1d_fn_split_qkv` | Triton | Avoids intermediate allocation |
| **Fused conv + QKV split (decode)** | `causal_conv1d_update_split_qkv` | Triton | Gluon-optimized variants |

---

## 1. Core Operations

### `causal_conv1d_fn` (Forward / Prefill)

Variable-length forward pass with continuous batching support:

```python
from aiter.ops.triton.causal_conv1d import causal_conv1d_fn

out = causal_conv1d_fn(
    x,                                    # (dim, cu_seqlen) — concatenated tokens, channel-last
    weight,                               # (dim, width) — convolution weights
    bias,                                 # (dim,) or None
    conv_states=conv_states,              # (num_cache_lines, dim, width-1) — state cache
    query_start_loc=query_start_loc,      # (batch+1,) int32 — cumulative seq lengths
    seq_lens_cpu=seq_lens_cpu,            # List[int] — per-sequence lengths (CPU)
    cache_indices=cache_indices,          # (batch,) int32 — maps seq to cache slot
    has_initial_state=has_initial_state,  # (batch,) bool — load existing state
    activation="silu",                    # "silu", "swish", or None
    pad_slot_id=-1,                       # sentinel for padded sequences
)
```

**Features:**
- Continuous batching via `cache_indices` for cache slot indirection
- Variable-length sequences via `query_start_loc` (ragged batching)
- Automatic state save/restore for stateful processing
- Kernel widths: 2, 3, 4, 5
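
Per channel, the operator computes `out[t] = act(bias + sum_j weight[j] * x[t - width + 1 + j])`, with the cached state supplying the `width-1` tokens to the left of each sequence start. A single-channel pure-Python reference (hypothetical sketch, activation omitted; not the Triton kernel):

```python
def causal_conv1d_ref(x, weight, bias=0.0, init_state=None):
    """Single-channel causal conv: out[t] depends only on x[:t+1]."""
    width = len(weight)
    state = list(init_state) if init_state is not None else [0.0] * (width - 1)
    padded = state + list(x)                 # left-pad with cached state
    out = [bias + sum(weight[j] * padded[t + j] for j in range(width))
           for t in range(len(x))]
    final_state = padded[-(width - 1):]      # last width-1 inputs, cached for the next call
    return out, final_state
```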

### `causal_conv1d_update` (Decode)

Single or multi-token update for autoregressive generation:

```python
from aiter.ops.triton.causal_conv1d import causal_conv1d_update

out = causal_conv1d_update(
    x,                                  # (batch, dim, seqlen) or (batch, dim)
    conv_state,                         # (num_cache_lines, dim, state_len)
    weight,                             # (dim, width)
    bias=None,
    activation="silu",
    conv_state_indices=indices,         # (batch,) int32 — cache line mapping
    num_accepted_tokens=accepted,       # (batch,) — for speculative decoding
    intermediate_conv_window=window,    # saves intermediate states
    pad_slot_id=-1,
)
```

**Features:**
- Single token decode and multi-token speculative decoding
- `num_accepted_tokens` for token rollback in speculative decoding
- `intermediate_conv_window` for saving per-step window states
- Kernel widths: 2, 3, 4
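
Conceptually, one decode step shifts the new token into a rolling window of the last `width` inputs and applies the convolution once. A single-channel sketch (hypothetical, activation omitted):

```python
def conv1d_update_ref(x_t, conv_state, weight, bias=0.0):
    """One decode step: conv_state holds the previous width-1 inputs."""
    window = conv_state + [x_t]   # length == width
    y = bias + sum(w * v for w, v in zip(weight, window))
    return y, window[1:]          # drop the oldest input for the next step
```

Feeding tokens one at a time through this update, starting from the prefill's final cached state, matches running the full causal convolution over the whole sequence.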

---

## 2. Fused QKV Split Variants

For Gated Delta Net models, fused operations perform causal conv1d **and** split output into Q/K/V, avoiding intermediate tensor allocation.

### Prefill

```python
from aiter.ops.triton._triton_kernels.gated_delta_rule.prefill.causal_conv1d_fwd_split_qkv import (
    causal_conv1d_fn_split_qkv
)

q, k, v = causal_conv1d_fn_split_qkv(
    x,                  # (dim, cu_seqlen) where dim = 2*k_dim + v_dim
    weight, bias, conv_states,
    query_start_loc, seq_lens_cpu,
    k_dim=128,          # query and key dimension
    v_dim=128,          # value dimension
    activation="silu",
)
# q: (cu_seqlen, k_dim), k: (cu_seqlen, k_dim), v: (cu_seqlen, v_dim)
```

### Decode

```python
from aiter.ops.triton._triton_kernels.gated_delta_rule.decode.causal_conv1d_split_qkv import (
    causal_conv1d_update_split_qkv
)

q, k, v = causal_conv1d_update_split_qkv(
    x,                  # (batch, dim, seqlen) where dim = 2*key_dim + value_dim
    conv_state, weight,
    key_dim=128, value_dim=128,
    activation="silu",
    use_gluon=True,     # Gluon kernel (default)
    use_gluon_v2=False, # Optimized Gluon v2
)
# q, k: (batch, key_dim, seqlen), v: (batch, value_dim, seqlen)
```

**Decode kernel variants:**
1. Standard Triton
2. Optimized v2 (multi-CU utilization)
3. Gluon (experimental Triton frontend)
4. Gluon v2 (eliminates tuple operations for better register allocation)
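
The QKV split itself is just a fixed partition of the conv output channels into `[k_dim | k_dim | v_dim]`; the fusion's benefit is skipping the combined intermediate tensor. As a sketch (hypothetical helper, a flat list standing in for the channel dimension):

```python
def split_qkv(channels, k_dim, v_dim):
    """Partition conv output channels into Q, K, V slices."""
    assert len(channels) == 2 * k_dim + v_dim
    return (channels[:k_dim],
            channels[k_dim:2 * k_dim],
            channels[2 * k_dim:])
```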

---

## 3. Layout Requirements

All operators require **channel-last** input layout (`x.stride(0) == 1`) for coalesced memory access.

| Tensor | Shape | Notes |
|--------|-------|-------|
| `x` (prefill) | `(dim, cu_seqlen)` | Channel-last, concatenated sequences |
| `x` (decode) | `(batch, dim, seqlen)` | `seqlen=1` for single-token decode |
| `weight` | `(dim, width)` | Shared across all sequences |
| `conv_states` | `(num_cache_lines, dim, width-1)` | Indexed via `cache_indices` |

---

## 4. Supported Data Types

| Input | Tolerance |
|-------|-----------|
| float32 | rtol=3e-4, atol=1e-3 |
| bfloat16 | rtol=1e-2, atol=5e-2 |

---

## 5. Backend Support

All implementations are **Triton-only** — no CUDA/HIP/CK/ASM kernels.

| Operator | Triton | Gluon |
|----------|:------:|:-----:|
| `causal_conv1d_fn` | Yes | — |
| `causal_conv1d_update` | Yes | — |
| `causal_conv1d_fn_split_qkv` | Yes | — |
| `causal_conv1d_update_split_qkv` | Yes | Yes (v1, v2) |

---

## 6. Performance Notes

- **Block sizes:** BLOCK_M=8 (tokens), BLOCK_N=256 (features), 2-stage software pipelining
- **State management:** Uses `.ca` cache modifier for prior token loads (L2 cache hints)
- **Gluon v2:** Eliminates tuple operations in hot loops for better register allocation and multi-CU utilization
- **In-place output:** Update kernel writes directly to input tensor `x`

---

## 7. Source Files

| Component | Path |
|---|---|
| Core Python API | `aiter/ops/triton/causal_conv1d.py` |
| Triton kernels | `aiter/ops/triton/_triton_kernels/causal_conv1d.py` |
| Fused prefill split QKV | `aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/causal_conv1d_fwd_split_qkv.py` |
| Fused decode split QKV | `aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/causal_conv1d_split_qkv.py` |

---

## 8. Test Files

| Test | Path |
|------|------|
| Core causal conv1d | `op_tests/triton_tests/test_causal_conv1d.py` |