From dbf707db30cdce9618720e89b617780f60db460c Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 19 Jan 2026 15:09:54 +0800 Subject: [PATCH 1/2] use python-side control flow keywords --- examples/amd/example_amd_flash_attn_bwd.py | 2 +- examples/amd/example_amd_flash_attn_fwd.py | 2 +- examples/deepseek_v32/topk_selector.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index 788aec367..27986ce78 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -112,7 +112,7 @@ def main( bx_loop_var = T.alloc_var(T.int32) bx_loop_var = b_split - with T.While(bx_loop_var < num_q_blocks): + while bx_loop_var < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index ca9c361ff..581619220 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -124,7 +124,7 @@ def main( bx = T.alloc_var(T.int32) bx = b_split - with T.While(bx < num_q_blocks): + while bx < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 82d59bc78..efdfb3e08 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -115,7 +115,7 @@ def tl_topk_kernel( # stage 2: tail pass for round in T.serial(4): if l_new_topk <= 0: - T.loop_break() + break r_idx = round % 2 l_start_pos = topk - l_new_topk From 55b4d46600f8559e9e36c1c920892cfced467c3e Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 19 Jan 2026 15:16:37 +0800 Subject: [PATCH 2/2] remove T.{if, then, else} --- examples/gdn/example_chunk_o.py | 15 ++++++------ examples/gdn/example_chunk_scaled_dot_kkt.py | 15 ++++++------ examples/gdn/example_wy_fast_bwd_split.py | 25 ++++++++++---------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index a4d7281f5..bb95f555f 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -127,16 +127,15 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 < i_s2: + A_fragment[i_s1, i_s2] = 0 T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index 8c7a4d573..c16374fe8 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -111,16 +111,15 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index de8afc2b7..822f745f2 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -345,26 +345,25 @@ def kernel( T.copy(dA_shared, dA_fragment) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True) T.copy(dA_fragment, dA_shared) T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 - with T.Else(): - dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2] + dA_fragment[i_s1, i_s2] = T.if_then_else( + i_s1 <= i_s2, + 0, + -dA_fragment[i_s1, i_s2], + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) - with T.Else(): - dA_fragment[i_s1, i_s2] = 0 + dA_fragment[i_s1, i_s2] = T.if_then_else( + G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0, + dA_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]), + 0, + ) T.copy(dA_fragment, dA_shared) # acceptable dA diff