Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/amd/example_amd_flash_attn_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(
bx_loop_var = T.alloc_var(T.int32)
bx_loop_var = b_split

with T.While(bx_loop_var < num_q_blocks):
while bx_loop_var < num_q_blocks:
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
Expand Down
2 changes: 1 addition & 1 deletion examples/amd/example_amd_flash_attn_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main(
bx = T.alloc_var(T.int32)
bx = b_split

with T.While(bx < num_q_blocks):
while bx < num_q_blocks:
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
Expand Down
2 changes: 1 addition & 1 deletion examples/deepseek_v32/topk_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def tl_topk_kernel(
# stage 2: tail pass
for round in T.serial(4):
if l_new_topk <= 0:
T.loop_break()
break

r_idx = round % 2
l_start_pos = topk - l_new_topk
Expand Down
15 changes: 7 additions & 8 deletions examples/gdn/example_chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,15 @@ def kernel(
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
A_fragment[i_s1, i_s2] = T.if_then_else(
G_diff_local[i_s1, i_s2] <= 0,
A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]),
0,
)

for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
if i_s1 < i_s2:
A_fragment[i_s1, i_s2] = 0

T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
T.copy(A_fragment, A_shared)
Expand Down
15 changes: 7 additions & 8 deletions examples/gdn/example_chunk_scaled_dot_kkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,15 @@ def kernel(
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
A_fragment[i_s1, i_s2] = T.if_then_else(
G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2,
A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]),
0,
)
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
if i_s1 <= i_s2:
A_fragment[i_s1, i_s2] = 0

T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
Expand Down
25 changes: 12 additions & 13 deletions examples/gdn/example_wy_fast_bwd_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,26 +345,25 @@ def kernel(
T.copy(dA_shared, dA_fragment)

for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
dA_fragment[i_s1, i_s2] = 0
if i_s1 <= i_s2:
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True)
T.copy(dA_fragment, dA_shared)
T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2):
with T.Then():
dA_fragment[i_s1, i_s2] = 0
with T.Else():
dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2]
dA_fragment[i_s1, i_s2] = T.if_then_else(
i_s1 <= i_s2,
0,
-dA_fragment[i_s1, i_s2],
)

for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
with T.Else():
dA_fragment[i_s1, i_s2] = 0
dA_fragment[i_s1, i_s2] = T.if_then_else(
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0,
dA_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]),
0,
)
T.copy(dA_fragment, dA_shared)

# acceptable dA diff
Expand Down
Loading