Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
43b3a0b
optimize attention autotune and test
sustcsonglin Jun 21, 2025
07b4405
optimize test and autotune
sustcsonglin Jun 21, 2025
2439891
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 22, 2025
b4f8505
Merge branch 'main' into fast_test_autotune
zhiyuan1i Jun 23, 2025
bd01a5a
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 23, 2025
acc38de
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 23, 2025
542a4bd
Use type hints for params in test files
yzhangcs Jun 23, 2025
b29880a
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 24, 2025
c093386
Update attn test cases
yzhangcs Jun 24, 2025
e389ef0
Fix block size to ensure pow of 2
yzhangcs Jun 24, 2025
ff63caf
Fix compiling bugs
yzhangcs Jun 24, 2025
7a1cc65
Update cumsum.py
yzhangcs Jun 24, 2025
4e77985
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 24, 2025
cb4afaf
[GLA] Refactor tests
yzhangcs Jun 24, 2025
bc39bf9
Improve modeling tests
yzhangcs Jun 24, 2025
a5bc8fe
Refactor modeling tests (#482)
yzhangcs Jun 24, 2025
d436cbd
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 24, 2025
548b5ec
[CI] Refractor Triton ci (#484)
zhiyuan1i Jun 25, 2025
7b1f9ba
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 25, 2025
ab2ad98
fixup! [CI] Refractor Triton ci (#484)
zhiyuan1i Jun 25, 2025
cc611c0
refractor gpu ci
zhiyuan1i Jun 25, 2025
21fbe0a
fixup
zhiyuan1i Jun 25, 2025
c658bd7
fixup
zhiyuan1i Jun 25, 2025
0e0c5a8
fixup conda paths
zhiyuan1i Jun 25, 2025
60b4c10
fix setup conda1
zhiyuan1i Jun 25, 2025
9d33a1a
install nvcc
zhiyuan1i Jun 25, 2025
5043ea3
f
zhiyuan1i Jun 25, 2025
a609b21
fix conda activate
zhiyuan1i Jun 25, 2025
41a9554
f
zhiyuan1i Jun 25, 2025
3484282
f
zhiyuan1i Jun 25, 2025
2586d7d
f
zhiyuan1i Jun 25, 2025
c1eb9d4
[GLA] Fix dgv bugs & refactor tests (#486)
yzhangcs Jun 25, 2025
e2fbc5d
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 25, 2025
92ef7d9
revert CI
zhiyuan1i Jun 25, 2025
c26aab2
[CI] Refract GPU CIs (#487)
zhiyuan1i Jun 25, 2025
9ef74db
Merge branch 'main' into fast_test_autotune
zhiyuan1i Jun 25, 2025
bbe05cf
[CI] Refract GPU CIs (#487)
zhiyuan1i Jun 25, 2025
fc55e13
only test h100 on main branch
zhiyuan1i Jun 25, 2025
2e81812
Merge branch 'main' into fast_test_autotune
zhiyuan1i Jun 25, 2025
a334781
reduce rwkv7 tests
zhiyuan1i Jun 25, 2025
c187b29
reduce rwkv6 tests
zhiyuan1i Jun 25, 2025
61cd16e
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 26, 2025
3797d79
split op test and model test
zhiyuan1i Jun 26, 2025
7020b26
Merge non-varlen and varlen tests
zhiyuan1i Jun 26, 2025
d98c1e2
reduce dplr tests
zhiyuan1i Jun 26, 2025
51a11dc
fix lint
zhiyuan1i Jun 26, 2025
4c23679
add a new pt2.7 env
zhiyuan1i Jun 26, 2025
1ac179c
fix lint
zhiyuan1i Jun 26, 2025
63b33fd
more general ci
zhiyuan1i Jun 26, 2025
431d7dc
fix
zhiyuan1i Jun 26, 2025
12c4b48
fix test-models
zhiyuan1i Jun 26, 2025
37488f7
Refactor GLA and GSA tests to use random tensors for gradients and im…
yzhangcs Jun 26, 2025
1edb77a
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 26, 2025
145ed19
support fused_addcmul in python3.10
zhiyuan1i Jun 26, 2025
26b95d8
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 26, 2025
90918b9
Merge branch 'main' into fast_test_autotune
zhiyuan1i Jun 26, 2025
0a8cd55
Refactor test assertions to remove leading spaces in assert_close calls
yzhangcs Jun 26, 2025
de0d3ef
Refactor test parametrization for retention ops
yzhangcs Jun 26, 2025
0d54774
Delete tests for BC
yzhangcs Jun 26, 2025
e510836
Refactor test parametrization for linear attention and simple GLA tests
yzhangcs Jun 26, 2025
1d0ce78
Refactor test parametrization for Comba and Path Attention tests
yzhangcs Jun 26, 2025
4c2dbfb
Refactor test parametrization for delta and path attention tests
yzhangcs Jun 26, 2025
bfb9a53
Update tests
yzhangcs Jun 26, 2025
4f26f23
Refactor test parametrization across multiple test files to streamlin…
yzhangcs Jun 26, 2025
c306b32
Add skip conditions for tests in test_titans.py and test_ttt.py to ha…
yzhangcs Jun 26, 2025
ba05dba
Refactor test parametrization in test_rwkv6
yzhangcs Jun 26, 2025
80b93b1
Add skip condition for test_parallel_varlen if flash-attn is not inst…
yzhangcs Jun 26, 2025
ebcd751
Merge branch 'main' into fast_test_autotune
yzhangcs Jun 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions fla/ops/attn/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_stages in [2, 3, 4, 5]
],
key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'],
Expand Down Expand Up @@ -177,7 +177,7 @@ def parallel_attn_bwd_kernel_preprocess(
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_stages in [2, 3, 4, 5]
],
key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'],
Expand Down Expand Up @@ -319,7 +319,7 @@ def parallel_attn_bwd_kernel_dq(
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_warps in [2, 4] + ([8] if check_shared_mem('hopper') else [])
for num_stages in [2, 3, 4, 5]
],
key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'],
Expand Down
8 changes: 3 additions & 5 deletions fla/ops/common/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
})
@triton.autotune(
configs=[
triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
for BK in BKV_LIST
for BV in BKV_LIST
for num_warps in NUM_WARPS
for num_stages in [2, 3, 4]
triton.Config({'BK': 128, 'BV': 128}, num_warps=8, num_stages=3),
triton.Config({'BK': 64, 'BV': 64}, num_warps=4, num_stages=3),
triton.Config({'BK': 32, 'BV': 32}, num_warps=2, num_stages=3),
],
key=['H', 'K', 'V', 'BT'],
)
Expand Down
12 changes: 6 additions & 6 deletions fla/ops/common/fused_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
],
key=['BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=['B', 'T'])
def fused_recurrent_fwd_kernel(
q,
k,
Expand All @@ -38,8 +38,8 @@ def fused_recurrent_fwd_kernel(
ht,
cu_seqlens,
scale,
B,
T,
B: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand Down Expand Up @@ -137,7 +137,7 @@ def fused_recurrent_fwd_kernel(
],
key=['BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=['B', 'T'])
def fused_recurrent_bwd_kernel(
q,
k,
Expand All @@ -156,8 +156,8 @@ def fused_recurrent_bwd_kernel(
dh0,
cu_seqlens,
scale,
B,
T,
B: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand Down Expand Up @@ -320,7 +320,7 @@ def fused_recurrent_fwd(
):
B, T, H, K, V = *k.shape, v.shape[-1]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = min(K, 64), min(V, 64)
BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

h0 = initial_state
Expand Down Expand Up @@ -377,7 +377,7 @@ def fused_recurrent_bwd(
B, T, H, K, V = *k.shape, v.shape[-1]
N = B if cu_seqlens is None else len(cu_seqlens) - 1

BK, BV = min(K, 64), min(V, 64)
BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

h0 = initial_state
Expand Down
15 changes: 3 additions & 12 deletions fla/ops/utils/cumsum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional

import torch
Expand Down Expand Up @@ -165,7 +164,7 @@ def chunk_global_cumsum_scalar_kernel(
b_z = tl.zeros([], dtype=tl.float32)
NT = tl.cdiv(T, BT)
for i_c in range(NT):
i_t = NT-1-i_c if REVERSE else i_c
i_t = NT - 1 - i_c if REVERSE else i_c
if HEAD_FIRST:
p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
Expand Down Expand Up @@ -232,7 +231,7 @@ def chunk_global_cumsum_vector_kernel(
b_z = tl.zeros([BS], dtype=tl.float32)
NT = tl.cdiv(T, BT)
for i_c in range(NT):
i_t = NT-1-i_c if REVERSE else i_c
i_t = NT - 1 - i_c if REVERSE else i_c
if HEAD_FIRST:
p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
Expand All @@ -245,8 +244,7 @@ def chunk_global_cumsum_vector_kernel(
if HAS_SCALE:
b_c *= scale
tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1))
if i_c >= 0:
b_z += tl.sum(b_s, 0)
b_z += tl.sum(b_s, 0)


def chunk_local_cumsum_scalar(
Expand Down Expand Up @@ -437,13 +435,6 @@ def chunk_local_cumsum(
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when `head_first=False` was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
)
if cu_seqlens is not None:
assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
if len(g.shape) == 3:
Expand Down
52 changes: 37 additions & 15 deletions tests/models/test_modeling_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,47 @@
# ===================================================================================
# Test for Modeling (Forward/Backward Pass)
# ===================================================================================
@pytest.mark.parametrize("L", [4])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [1024])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_l2warp", [True, False])
def test_modeling(L, B, T, H, D, dtype, use_l2warp):
@pytest.mark.parametrize(
    # The name order here must match the positional order of every case tuple
    # below (and of the id template): the sixth entry is the `use_l2warp` flag
    # and the seventh is the dtype. Listing 'dtype' before 'use_l2warp' would
    # hand a bool to `dtype` and a torch dtype to `use_l2warp`.
    ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'],
    [
        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test))
        for test in [
            (4, 4, 1024, 4, 64, True, torch.bfloat16),
            (4, 4, 1024, 4, 64, False, torch.bfloat16),
            (4, 4, 1024, 4, 128, False, torch.bfloat16),
        ]
    ]
)
def test_modeling(
    L: int,
    B: int,
    T: int,
    H: int,
    D: int,
    dtype: torch.dtype,
    use_l2warp: bool,
):
    """Run the shared forward/backward model check for ``ABCConfig``.

    Each parametrized case supplies ``L``, ``B``, ``T``, ``H``, ``D``, the
    ``use_l2warp`` flag and the ``dtype``; pytest binds them by name, so the
    order of the parameters in this signature is independent of the tuple
    order above. All values are forwarded unchanged to
    ``run_test_model_forward_backward``.
    """
    run_test_model_forward_backward(L, B, T, H, D, ABCConfig, dtype, use_l2warp)


# ===================================================================================
# Test for Generation
# ===================================================================================
@pytest.mark.parametrize("L", [2])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [4000])
@pytest.mark.parametrize("H", [8])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_generation(L, B, T, H, D, dtype):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-{}".format(*case))
        for case in (
            (2, 4, 2000, 8, 64, torch.float16),
        )
    ],
)
def test_generation(L: int, B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Run the shared generation check for ``ABCConfig`` on each parametrized case.

    All parameters are forwarded unchanged to ``run_test_generation``.
    """
    run_test_generation(L, B, T, H, D, ABCConfig, dtype)
52 changes: 37 additions & 15 deletions tests/models/test_modeling_bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,47 @@
# ===================================================================================
# Test for Modeling (Forward/Backward Pass)
# ===================================================================================
@pytest.mark.parametrize("L", [4])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [1024])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_l2warp", [True, False])
def test_modeling(L, B, T, H, D, dtype, use_l2warp):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*case))
        for case in (
            (4, 4, 1024, 4, 64, True, torch.bfloat16),
            (4, 4, 1024, 4, 64, False, torch.bfloat16),
            (4, 4, 1024, 4, 128, False, torch.bfloat16),
        )
    ],
)
def test_modeling(
    L: int, B: int, T: int, H: int, D: int,
    dtype: torch.dtype,
    use_l2warp: bool,
):
    """Run the shared forward/backward model check for ``BitNetConfig``.

    pytest binds the parametrized values by name; they are forwarded unchanged
    to ``run_test_model_forward_backward``.
    """
    run_test_model_forward_backward(L, B, T, H, D, BitNetConfig, dtype, use_l2warp)


# ===================================================================================
# Test for Generation
# ===================================================================================
@pytest.mark.parametrize("L", [2])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [4000])
@pytest.mark.parametrize("H", [8])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_generation(L, B, T, H, D, dtype):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-{}".format(*case))
        for case in (
            (2, 4, 2000, 8, 64, torch.float16),
        )
    ],
)
def test_generation(L: int, B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Run the shared generation check for ``BitNetConfig`` on each parametrized case.

    All parameters are forwarded unchanged to ``run_test_generation``.
    """
    run_test_generation(L, B, T, H, D, BitNetConfig, dtype)
52 changes: 37 additions & 15 deletions tests/models/test_modeling_comba.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,47 @@
# ===================================================================================
# Test for Modeling (Forward/Backward Pass)
# ===================================================================================
@pytest.mark.parametrize("L", [4])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [1024])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_l2warp", [True, False])
def test_modeling(L, B, T, H, D, dtype, use_l2warp):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*case))
        for case in (
            (4, 4, 1024, 4, 64, True, torch.bfloat16),
            (4, 4, 1024, 4, 64, False, torch.bfloat16),
            (4, 4, 1024, 4, 128, False, torch.bfloat16),
        )
    ],
)
def test_modeling(
    L: int, B: int, T: int, H: int, D: int,
    dtype: torch.dtype,
    use_l2warp: bool,
):
    """Run the shared forward/backward model check for ``CombaConfig``.

    pytest binds the parametrized values by name; they are forwarded unchanged
    to ``run_test_model_forward_backward``.
    """
    run_test_model_forward_backward(L, B, T, H, D, CombaConfig, dtype, use_l2warp)


# ===================================================================================
# Test for Generation
# ===================================================================================
@pytest.mark.parametrize("L", [2])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [4000])
@pytest.mark.parametrize("H", [8])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_generation(L, B, T, H, D, dtype):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-{}".format(*case))
        for case in (
            (2, 4, 2000, 8, 64, torch.float16),
        )
    ],
)
def test_generation(L: int, B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Run the shared generation check for ``CombaConfig`` on each parametrized case.

    All parameters are forwarded unchanged to ``run_test_generation``.
    """
    run_test_generation(L, B, T, H, D, CombaConfig, dtype)
52 changes: 37 additions & 15 deletions tests/models/test_modeling_deltanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,47 @@
# ===================================================================================
# Test for Modeling (Forward/Backward Pass)
# ===================================================================================
@pytest.mark.parametrize("L", [4])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [1024])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_l2warp", [True, False])
def test_modeling(L, B, T, H, D, dtype, use_l2warp):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*case))
        for case in (
            (4, 4, 1024, 4, 64, True, torch.bfloat16),
            (4, 4, 1024, 4, 64, False, torch.bfloat16),
            (4, 4, 1024, 4, 128, False, torch.bfloat16),
        )
    ],
)
def test_modeling(
    L: int, B: int, T: int, H: int, D: int,
    dtype: torch.dtype,
    use_l2warp: bool,
):
    """Run the shared forward/backward model check for ``DeltaNetConfig``.

    pytest binds the parametrized values by name; they are forwarded unchanged
    to ``run_test_model_forward_backward``.
    """
    run_test_model_forward_backward(L, B, T, H, D, DeltaNetConfig, dtype, use_l2warp)


# ===================================================================================
# Test for Generation
# ===================================================================================
@pytest.mark.parametrize("L", [2])
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("T", [4000])
@pytest.mark.parametrize("H", [8])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_generation(L, B, T, H, D, dtype):
@pytest.mark.parametrize(
    ('L', 'B', 'T', 'H', 'D', 'dtype'),
    [
        # One pytest case per tuple; the id is rendered from the tuple itself.
        pytest.param(*case, id="L{}-B{}-T{}-H{}-D{}-{}".format(*case))
        for case in (
            (2, 4, 2000, 8, 64, torch.float16),
        )
    ],
)
def test_generation(L: int, B: int, T: int, H: int, D: int, dtype: torch.dtype):
    """Run the shared generation check for ``DeltaNetConfig`` on each parametrized case.

    All parameters are forwarded unchanged to ``run_test_generation``.
    """
    run_test_generation(L, B, T, H, D, DeltaNetConfig, dtype)
Loading
Loading