Commit b9d5d8d

[Dev] Update linear attention examples to enhance performance on Hopper GPUs (#621)
* Tune linear attention examples on H100
* Add retnet fwd kernel
* fix lint
1 parent 3e27fb0 commit b9d5d8d

File tree: 4 files changed (+164, -222 lines)

examples/linear_attention/example_linear_attn_bwd.py

Lines changed: 22 additions & 16 deletions
@@ -10,7 +10,12 @@
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
 
 
-@tl.jit(out_idx=[4, 5, 6])
+@tl.jit(
+    out_idx=[4, 5, 6],
+    pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def chunk_linear_attn_bwd_kernel(
     B,
     S,
@@ -26,21 +31,21 @@ def chunk_linear_attn_bwd_kernel(
     accum_dtype = 'float'
 
     chunk_size = 64
-    BK = BV = 64
+    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
     NK = tl.cdiv(DK, BK)
     NV = tl.cdiv(DV, BV)
     NT = tl.cdiv(S, chunk_size)
 
     @T.prim_func
-    def main(
-            Q: T.Tensor([B, S, H, DK], dtype),
-            K: T.Tensor([B, S, H, DK], dtype),
-            V: T.Tensor([B, S, H, DV], dtype),
-            dO: T.Tensor([B, S, H, DV], dtype),
-            dQ: T.Tensor([NV, B, S, H, DK], dtype),
-            dK: T.Tensor([NV, B, S, H, DK], dtype),
-            dV: T.Tensor([NK, B, S, H, DV], dtype),
+    def chunk_linear_attn_bwd(
+            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            dO: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            dQ: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
+            dK: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
+            dV: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
     ):
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
@@ -71,6 +76,7 @@ def main(
                 h_shared: tl.layout.make_swizzled_layout(h_shared),
                 dh_shared: tl.layout.make_swizzled_layout(dh_shared)
             })
+            T.use_swizzle(10)
 
             # Calculate dQ
             for i in T.Pipelined(0, NT, num_stages=1):
@@ -107,7 +113,6 @@ def main(
                 T.copy(
                     dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                        i_v * BV:(i_v + 1) * BV], do)
-                T.copy(dh, dh_shared)
 
                 # Calculate dk
                 T.gemm(
@@ -116,6 +121,7 @@ def main(
                 for row, col in T.Parallel(chunk_size, chunk_size):
                     ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
                 T.gemm(ds_shared, q, dk, clear_accum=True)
+                T.copy(dh, dh_shared)
                 T.gemm(v, dh_shared, dk, transpose_B=True)
 
                 # Calculate dv
@@ -135,7 +141,7 @@ def main(
                     dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                            i_v * BV:(i_v + 1) * BV])
 
-    return main
+    return chunk_linear_attn_bwd
 
 
 def postprocess(dQ, dK, dV):
@@ -148,8 +154,8 @@ def postprocess(dQ, dK, dV):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=2048, help='Seq len')
-    parser.add_argument('--H', type=int, default=64, help='Num heads')
+    parser.add_argument('--S', type=int, default=4096, help='Seq len')
+    parser.add_argument('--H', type=int, default=32, help='Num heads')
     parser.add_argument('--D', type=int, default=256, help='Head dim')
     args = parser.parse_args()
     B, S, H, D = args.B, args.S, args.H, args.D
@@ -161,15 +167,15 @@ def main():
 
     kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
     dq, dk, dv = postprocess(*kernel(q, k, v, do))
-    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
     o_ref.backward(do, retain_graph=True)
     if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
         print('Passed all tests!✅')
     else:
         print('Failed some tests!❌')
     t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)
     q.grad = k.grad = v.grad = None
-    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
     t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
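For reference, the gradients this example validates can also be reproduced with a naive O(S^2) causal linear attention written in plain PyTorch and differentiated by autograd. The following is a minimal sketch for orientation only; the helper name naive_linear_attn, the fp32 upcast, and the small test shapes are illustrative and not part of this commit:

import torch

def naive_linear_attn(q, k, v, scale=None):
    # q, k: [B, S, H, DK], v: [B, S, H, DV]; non-normalized causal linear attention
    if scale is None:
        scale = q.shape[-1]**-0.5
    s = torch.einsum('bihd,bjhd->bhij', q.float() * scale, k.float())
    causal = torch.tril(torch.ones(q.shape[1], q.shape[1], device=q.device, dtype=torch.bool))
    s = s.masked_fill(~causal, 0)  # keep only keys at or before each query position
    return torch.einsum('bhij,bjhd->bihd', s, v.float())

# Autograd on the naive reference gives the gradients that dQ/dK/dV are compared against
B, S, H, D = 1, 128, 2, 64
q, k, v = (torch.randn(B, S, H, D, device='cuda', dtype=torch.float16, requires_grad=True)
           for _ in range(3))
do = torch.randn(B, S, H, D, device='cuda', dtype=torch.float16)
naive_linear_attn(q, k, v).backward(do.float())
# q.grad, k.grad, v.grad ≈ dq, dk, dv from postprocess(*kernel(q, k, v, do)), up to fp16 error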

examples/linear_attention/example_linear_attn_fwd.py

Lines changed: 20 additions & 12 deletions
@@ -10,7 +10,12 @@
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
 
 
-@tl.jit(out_idx=[3, 4])
+@tl.jit(
+    out_idx=[3, 4],
+    pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def chunk_linear_attn_fwd_kernel(
     B,
     S,
@@ -26,16 +31,19 @@ def chunk_linear_attn_fwd_kernel(
     accum_dtype = 'float'
 
     chunk_size = 64
-    BK = BV = 64
+    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
     NK = tl.cdiv(DK, BK)
     NV = tl.cdiv(DV, BV)
     NT = tl.cdiv(S, chunk_size)
 
     @T.prim_func
-    def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
-             V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype),
-             final_state: T.Tensor([B, H, DK, DV], accum_dtype)):
+    def chunk_linear_attn_fwd(
+            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            O: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
+            final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
             i_h = i_bh % H
@@ -57,9 +65,9 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
                 h_shared: tl.layout.make_swizzled_layout(h_shared),
                 s_shared: tl.layout.make_swizzled_layout(s_shared),
             })
-            T.use_swizzle(8)
+            T.use_swizzle(10)
 
-            for i in T.Pipelined(0, NT, num_stages=1):
+            for i in T.Pipelined(0, NT, num_stages=2):
                 for row, col in T.Parallel(chunk_size, BK):
                     q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                 T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
@@ -71,16 +79,16 @@ def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
 
                 T.gemm(s_shared, v, o, clear_accum=True)
                 T.copy(h, h_shared)
-                T.gemm(q, h_shared, o)
                 T.gemm(k, v, h, transpose_A=True)
+                T.gemm(q, h_shared, o)
                 T.copy(
                     o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
                          i_v * BV:(i_v + 1) * BV])
 
             # Output final state
             T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
 
-    return main
+    return chunk_linear_attn_fwd
 
 
 def postprocess(o, h):
@@ -91,8 +99,8 @@ def postprocess(o, h):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=2048, help='Seq len')
-    parser.add_argument('--H', type=int, default=64, help='Num heads')
+    parser.add_argument('--S', type=int, default=4096, help='Seq len')
+    parser.add_argument('--H', type=int, default=32, help='Num heads')
     parser.add_argument('--D', type=int, default=256, help='Head dim')
     args = parser.parse_args()
     B, S, H, D = args.B, args.S, args.H, args.D
@@ -114,7 +122,7 @@ def main():
         lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0],
         warmup=25,
         rep=100)
-    t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100)
+    t2 = do_bench(lambda: postprocess(*kernel(q, k, v)), warmup=25, rep=100)
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1/t2:.3f}x')
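The forward kernel computes chunkwise, non-normalized causal linear attention: inside each 64-token chunk a causally masked Q·K^T·V product, plus a cross-chunk term through a running state that accumulates K^T V. The gemm reordering in this commit preserves the result because the state is copied into h_shared before it is updated. Below is a minimal PyTorch sketch of that recurrence, for orientation only; chunked_linear_attn_ref is a hypothetical helper, not part of the commit:

import torch

def chunked_linear_attn_ref(q, k, v, chunk_size=64, scale=None):
    # q, k: [B, S, H, DK], v: [B, S, H, DV]; mirrors the chunked recurrence of the kernel
    B, S, H, DK = q.shape
    DV = v.shape[-1]
    if scale is None:
        scale = DK**-0.5
    q, k, v = q.float() * scale, k.float(), v.float()
    o = torch.zeros(B, S, H, DV, device=q.device)
    state = torch.zeros(B, H, DK, DV, device=q.device)  # running sum of K^T V over past chunks
    causal = torch.tril(torch.ones(chunk_size, chunk_size, device=q.device, dtype=torch.bool))
    for start in range(0, S, chunk_size):
        qc = q[:, start:start + chunk_size]
        kc = k[:, start:start + chunk_size]
        vc = v[:, start:start + chunk_size]
        s = torch.einsum('bihd,bjhd->bhij', qc, kc).masked_fill(~causal, 0)
        intra = torch.einsum('bhij,bjhd->bihd', s, vc)      # within-chunk causal part
        inter = torch.einsum('bihd,bhde->bihe', qc, state)  # contribution of earlier chunks via the state
        o[:, start:start + chunk_size] = intra + inter
        state = state + torch.einsum('bihd,bihe->bhde', kc, vc)  # like T.gemm(k, v, h, transpose_A=True)
    return o, state  # o ≈ kernel output after postprocess; state ≈ final_state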
New file: RetNet forward kernel example

Lines changed: 122 additions & 0 deletions
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import torch
import tilelang as tl
import tilelang.language as T
from tilelang.profiler import do_bench

import argparse


@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def chunk_retention_fwd_kernel(
    B,
    S,
    H,
    DK,
    DV,
    dtype: str = 'float16',
    scale: float = None,
) -> torch.Tensor:

    if scale is None:
        scale = DK**-0.5
    accum_dtype = 'float'

    chunk_size = 64
    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
    NK = tl.cdiv(DK, BK)
    NV = tl.cdiv(DV, BV)
    NT = tl.cdiv(S, chunk_size)

    @T.prim_func
    def chunk_retention_fwd(
            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
            O: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
    ):
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
            i_h = i_bh % H
            log_decay = T.alloc_var('float32')
            log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h))  # Head-specific log decay

            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
            v = T.alloc_shared([chunk_size, BV], dtype)
            h = T.alloc_fragment([BK, BV], accum_dtype)
            h_shared = T.alloc_shared([BK, BV], dtype)
            s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            o = T.alloc_fragment([chunk_size, BV], accum_dtype)
            T.clear(h)

            T.annotate_layout({
                q: tl.layout.make_swizzled_layout(q),
                k: tl.layout.make_swizzled_layout(k),
                v: tl.layout.make_swizzled_layout(v),
                h_shared: tl.layout.make_swizzled_layout(h_shared),
                s_shared: tl.layout.make_swizzled_layout(s_shared),
            })
            T.use_swizzle(10)

            for i in T.Pipelined(0, NT):
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)

                T.gemm(q, k, s, clear_accum=True, transpose_B=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2(
                        (row - col) * log_decay), 0)

                T.copy(h, h_shared)
                T.gemm(q, h_shared, o, clear_accum=True)
                for row, col in T.Parallel(chunk_size, BV):
                    o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col]
                T.gemm(s_shared, v, o)

                for row, col in T.Parallel(chunk_size, BV):
                    v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
                for row, col in T.Parallel(BK, BV):
                    h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
                T.copy(
                    o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
                         i_v * BV:(i_v + 1) * BV])
                T.gemm(k, v, h, transpose_A=True)

    return chunk_retention_fwd


def postprocess(o):
    return o if o.size(0) == 1 else o.sum(0)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=4096, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    args = parser.parse_args()
    B, S, H, D = args.B, args.S, args.H, args.D
    total_flops = 2.0 * B * S * S * H * D  # causal

    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)

    kernel = chunk_retention_fwd_kernel(B, S, H, D, D)

    t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
    print(f'Tilelang latency: {t:.3f} ms')
    print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}')


if __name__ == '__main__':
    main()
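For context, retention is causal attention with a head-specific exponential decay gamma_h = 1 - 2^(-5 - h); the kernel carries it in log2 space as log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)). A naive O(S^2) PyTorch sketch of the same computation, for orientation only (naive_retention_ref is a hypothetical helper, not part of the commit):

import torch

def naive_retention_ref(q, k, v, scale=None):
    # q, k: [B, S, H, DK], v: [B, S, H, DV]; decay-weighted causal attention (RetNet retention)
    B, S, H, D = q.shape
    if scale is None:
        scale = D**-0.5
    gamma = 1.0 - 2.0**(-5.0 - torch.arange(H, device=q.device, dtype=torch.float32))  # [H]
    t = torch.arange(S, device=q.device, dtype=torch.float32)
    rel = t[:, None] - t[None, :]                    # query position minus key position
    decay = gamma.view(H, 1, 1)**rel.clamp(min=0)    # gamma_h^(t - s) for t >= s
    decay = decay * (rel >= 0)                       # zero out future (non-causal) positions
    s = torch.einsum('bihd,bjhd->bhij', q.float() * scale, k.float())
    return torch.einsum('bhij,bjhd->bihd', s * decay, v.float())

# The result should approximately match postprocess(kernel(q, k, v)) from the example above,
# up to fp16 rounding in the tiled kernel.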
