
Commit f07f31c

Alex4210987, xinxyxiao, and LeiWang1999 authored
[AMD] Fix AMD TIR & add examples (#784)
* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)
  - Enhanced buffer index handling to address precision issues by removing redundant operations.
  - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
  - Updated related documentation to reflect changes in buffer management practices.
* Remove obsolete test script for AMD example, streamlining the examples directory.
* Remove unused dtype_size variable in AMD example script to streamline code.
* Add input configuration file and update AMD example script for enhanced flexibility
  - Introduced a new input.txt file for configurable parameters.
  - Modified the example_amd_flash_attn_fwd.py script to allow a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
  - Streamlined the main function for better clarity and organization.
  - Added a new test script to facilitate running the example with specified parameters.
* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations
  - Deleted input.txt and test.sh, as they are no longer needed.
  - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
  - Reintroduced swizzle usage in the kernel for better performance.
* Refactor AMD example script for FlashAttention-2
  - Renamed `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn` for clarity.
  - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
  - Removed outdated comments and improved code organization for better readability.
* Refactor formatting in AMD FlashAttention example script
  - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
  - Streamlined the `main` function parameter formatting for consistency.
  - Removed unnecessary blank lines.
* Update example_amd_flash_attn_fwd.py
* Enhance AMD example script and update CI workflows
  - Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization.
  - Added new CI workflows for AMD and documentation publishing.
  - Updated various requirements files to include necessary dependencies.
  - Introduced new test cases and examples for better coverage and functionality.
  - Refactored existing code for improved readability and maintainability.
* Remove redundant tool cache cleanup step in AMD CI workflow
* Remove `torch` dependency from `requirements-rocm.txt` to streamline requirements.
* Add new AMD FlashAttention example and test script
  - Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang.
  - Added a `test.sh` script to facilitate running the new example with specified parameters.
  - Enhanced the overall structure and organization of the example for better clarity and usability.
* Update configurations in `example_amd_flash_attn_fwd.py` for the autotuner
  - Reduced the number of `threads` and `num_split_q` options for improved performance.
  - Adjusted `panel_size` options to streamline configuration settings.
* Update submodule 'tvm' to commit 6ccc74f622c7ec4ac25d430d0f6546e7b9edb217
* Update submodule 'tvm' to commit 14ff70ab142b9e5a31bbf9c7923c8a697d41e86c
* Add example for AMD Flash Attention backward pass implementation
  - Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang.
  - Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps.
  - Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters.
  - Included a reference implementation for validation against PyTorch's attention mechanism.
* Enhance AMD Flash Attention example with additional testing capabilities
  - Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation.
  - Improved the main function to allow better parameter configuration and benchmarking.
  - Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example.
* Update submodule TVM to commit a64a5926a6e59f5417ef2501f9d88b467337cf6a
* Refactor HIP intrinsic rules to CUDA
  - Renamed `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules.
  - Adjusted include paths for better organization and clarity in the code structure.
* Update AMD CI workflow to uninstall specific PyTorch packages before installation
  - Removed the installation of `flash_attn==2.5.8` to streamline the CI process.
  - Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.
* Remove unused shared memory allocations in AMD Flash Attention backward example
  - Eliminated the shared memory allocations for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to reduce unnecessary memory overhead in the backward pass.
* Remove unnecessary pip uninstall command from AMD CI workflow
  - Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio`, as it is no longer required for the installation of pre-release versions.
* Refactor DispatchHIPWarpActiveMask function in HIP intrinsic rules
  - Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity.
  - Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.
* Refactor formatting of HIP intrinsic rule registrations
  - Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining; no functional changes.
* Update file name and documentation for HIP intrinsic rules
  - Renamed the file from `intrin_rule_cuda.cc` back to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules.
  - Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.
* Enhance DispatchHIPShuffle function with clang-analyzer comments
  - Added NOLINTBEGIN and NOLINTEND comments to the DispatchHIPShuffle function to suppress clang-analyzer warnings related to inner pointer usage.
* lint fix
* fix

---------

Co-authored-by: xinxyxiao <[email protected]>
Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: LeiWang1999 <[email protected]>
1 parent 3cfefc8 commit f07f31c

File tree

5 files changed: +665, -5 lines

examples/amd/example_amd_flash_attn_bwd.py

Lines changed: 363 additions & 0 deletions
@@ -0,0 +1,363 @@
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse


@tilelang.jit(out_idx=[3, 4])
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_fwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
            K_shared = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_N, dim_v], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            loop_range = (
                T.ceildiv(
                    (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                     -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, dim_v):
                    acc_o[i, j] *= scores_scale[i]
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.copy(acc_s, acc_s_cast)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            for i, j in T.Parallel(block_M, dim_v):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])

    return flash_fwd


@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_v]
    blk = 32

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(shape, dtype),  # type: ignore
            dO: T.Tensor(shape, dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
            do = T.alloc_fragment([blk, blk], dtype)
            acc = T.alloc_fragment([blk, blk], accum_dtype)
            delta = T.alloc_fragment([blk], accum_dtype)
            T.clear(acc)
            for k in range(T.ceildiv(dim_v, blk)):
                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
                T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
                for i, j in T.Parallel(blk, blk):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])

    return flash_bwd_prep


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
    dtype = "float16"
    accum_dtype = "float"
    shape = [batch, seq_len, heads, dim_qk]
    blk = 64

    @T.prim_func
    def flash_bwd_post(
            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.copy(
                dQ[bz, bx * blk:(bx + 1) * blk, by, :],
                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
            )

    return flash_bwd_post


@tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)

            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=1):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                   0)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                for i, j in T.Parallel(block_N, dim_qk):
                    if k * block_N + i < seq_len:
                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

            for i, j in T.Parallel(block_M, dim_v):
                T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
            for i, j in T.Parallel(block_M, dim_qk):
                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])

    return flash_bwd


@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        D_HEAD_V = v.shape[-1]
        block_M = 128
        block_N = 64
        mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
        groups = H // HEAD_KV

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 64
        block_N = 32
        mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
        mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
        delta = mod_prep(o, do)
        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                               groups)
        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
        dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
        dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
        dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
        kernel(q, k, v, do, lse, delta, dq, dk, dv)
        dq = mod_post(dq)
        return dq, dk, dv, None, None


attention = _attention.apply


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D_QK]
    # K: [B, T, HK, D_QK]
    # V: [B, T, HV, D_V]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim_qk = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def main(BATCH: int = 1,
         H: int = 32,
         N_CTX: int = 256,
         D_HEAD_QK: int = 192,
         D_HEAD_V: int = 128,
         groups: int = 16,
         causal: bool = False):
    flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
    flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
    total_flops = 3 * flops_per_qk + 2 * flops_per_v
    if causal:
        total_flops *= 0.5
    Q = (
        torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
                    device="cuda").normal_().requires_grad_())

    head_kv = H // groups
    K = (
        torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
                    device="cuda").normal_().requires_grad_())
    V = (
        torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
                    device="cuda").normal_().requires_grad_())
    dO = (
        torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
                    device="cuda").normal_().requires_grad_())
    O = attention(Q, K, V, causal, groups)
    O.backward(dO, retain_graph=True)
    dQ, Q.grad = Q.grad.clone(), None
    dK, K.grad = K.grad.clone(), None
    dV, V.grad = V.grad.clone(), None

    O_ref = ref_program(Q, K, V, causal, groups)
    O_ref.backward(dO, retain_graph=True)
    dQ_ref, Q.grad = Q.grad.clone(), None
    dK_ref, K.grad = K.grad.clone(), None
    dV_ref, V.grad = V.grad.clone(), None

    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

    def run():
        O_ref.backward(dO, retain_graph=True)

    def run1():
        O.backward(dO, retain_graph=True)

    from tilelang.profiler import do_bench

    latency = do_bench(run, warmup=500)
    print("torch: {:.2f} ms".format(latency))
    print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = do_bench(run1, warmup=500)
    print("tilelang: {:.2f} ms".format(latency))
    print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
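
For readers tracing what the tiled backward kernel above accumulates, the math reduces to the standard attention gradients. The sketch below is a dense PyTorch reference (assuming a single head, no causal mask, and no GQA grouping); it is illustrative only and not part of the committed example:

import torch

def dense_attention_grads(q, k, v, do, sm_scale):
    # q, k: [seq, dim_qk]; v, do: [seq, dim_v]; single head, no mask (assumed for brevity).
    s = (q @ k.T) * sm_scale                # scores, computed tile by tile in flash_fwd
    p = torch.softmax(s, dim=-1)            # rebuilt in flash_bwd via exp2(qkT * scale - lse)
    o = p @ v
    dv = p.T @ do                           # dV = P^T dO            (the `dv` accumulator)
    dp = do @ v.T                           # dP = dO V^T            (the `dsT` gemm, transposed)
    delta = (do * o).sum(-1, keepdim=True)  # Delta from flashattn_bwd_preprocess
    ds = p * (dp - delta)                   # dS = P * (dP - Delta)
    dq = ds @ k * sm_scale                  # accumulated into dQ with atomic_add
    dk = ds.T @ q * sm_scale                # the `dk` accumulator
    return dq, dk, dv

Here `delta` is exactly what `flashattn_bwd_preprocess` materializes (the row-wise sum of `dO * O`), and storing `lse` in the forward pass lets `flash_bwd` recover `p` per tile without re-running the softmax.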

examples/amd/example_amd_flash_attn_fwd.py

Lines changed: 3 additions & 3 deletions
@@ -32,12 +32,12 @@ def get_configs():
     """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
     block_M = [32, 64, 128, 256]
     block_N = [32, 64, 128, 256]
-    threads = [64, 128, 192, 256, 512, 1024]
-    num_split_q = [32, 64, 128, 256, 256]
+    threads = [128, 256, 512]
+    num_split_q = [64, 128, 256]
     num_stages = [0]
     enable_rasterization = [True]
     k_pack = [2]
-    panel_size = [7, 8, 9, 10]
+    panel_size = [7, 8]
     qk_coalesced_width = [8]
     v_coalesced_width = [4]
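
These lists are the autotuner's search space for the forward example; only the lists are visible in this hunk, not how `get_configs` consumes them. As a rough sketch (illustrative, not the committed code), such per-parameter lists are typically expanded into candidate configuration dictionaries with a Cartesian product:

import itertools

def expand_configs():
    # Search-space lists as they stand after this change (values copied from the diff).
    space = {
        "block_M": [32, 64, 128, 256],
        "block_N": [32, 64, 128, 256],
        "threads": [128, 256, 512],
        "num_split_q": [64, 128, 256],
        "num_stages": [0],
        "enable_rasterization": [True],
        "k_pack": [2],
        "panel_size": [7, 8],
        "qk_coalesced_width": [8],
        "v_coalesced_width": [4],
    }
    keys = list(space)
    # One dict per point in the Cartesian product of all lists.
    return [dict(zip(keys, combo)) for combo in itertools.product(*space.values())]

Assuming a full Cartesian product, narrowing `threads`, `num_split_q`, and `panel_size` shrinks the space from 1920 candidates to 288, which is the practical effect of this change.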

examples/amd/test.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
/root/miniconda3/envs/py312/bin/python3 examples/amd/example_amd_flash_attn_fwd.py \
    --batch 2 \
    --heads 16 \
    --seq_len 4096 \
    --dim 128 \
    --is_causal \
    --groups 2

/root/composable_kernel/build/bin/tile_example_fmha_fwd \
    -b=2 -h=16 -s=4096 -d=128 -mask=t -v=1 -warmup=5 -repeat=20
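
To compare the two benchmark outputs above on equal terms, a rough forward-pass FLOP count for the `test.sh` shapes can be computed with the convention the examples use (2·B·H·S²·D per matmul, halved by the causal mask); this is illustrative arithmetic, not part of the commit:

# Shapes from test.sh: batch=2, heads=16, seq_len=4096, dim=128, causal mask on.
B, H, S, D = 2, 16, 4096, 128
flops_qk = 2.0 * B * H * S * S * D   # Q @ K^T
flops_pv = 2.0 * B * H * S * S * D   # P @ V
total = 0.5 * (flops_qk + flops_pv)  # causal mask skips roughly half the tiles
print(f"{total / 1e12:.3f} TFLOP per forward pass")  # ~0.137 TFLOP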

0 commit comments
