diff --git a/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py b/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py index 88cb53a4dc4d..151999fa6612 100644 --- a/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py +++ b/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py @@ -342,10 +342,14 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr, loop_ub = ttgl.cdiv(K, BLOCK_K) epilogue_lb = loop_ub - (NUM_BUFFERS - 1) + + pred = 0 - epilogue_lb + pred = (pred >> 31) & 1 + producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B, + pred=pred) + ttgl.assume(loop_ub > 0) for i in range(0, loop_ub): - pred = i - epilogue_lb - pred = (pred >> 31) & 1 # SubIteration0 # LDS load SubIteration1 a1, b1 = lds_subtile_load(consumer, SUBTILE_LEN, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, @@ -354,11 +358,6 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr, accumulator = ttgl.amd.gfx1250.wmma(a0, b0, accumulator) # SubIteration1 - # TDM load for next tile - # If we are in epilogue, we have already issued our tile loads - producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B, - pred=pred) - # We prefetch distance - 1 iterations ahead because producer is already incremented by 1 issue_l2_prefetches(L2_PREFETCH_DISTANCE - 1, producer, a_desc, b_desc, 0, 0, BLOCK_K, TRANSPOSE_B) @@ -378,6 +377,12 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr, # SubIteration3 consumer += 1 ttgl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2) * 2) + # TDM load for next tile + # If we are in epilogue, we have already issued our tile loads + pred = (i + 1) - epilogue_lb + pred = (pred >> 31) & 1 + producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B, + pred=pred) # LDS load SubIteration0 for next tile a0, b0 = lds_subtile_load(consumer, 0, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS, TRANSPOSE_B, SUBTILE_LEN)