diff --git a/python/examples/gluon/01-attention-forward.py b/python/examples/gluon/01-attention-forward.py index 88a2e5500e2d..f492605ba129 100644 --- a/python/examples/gluon/01-attention-forward.py +++ b/python/examples/gluon/01-attention-forward.py @@ -163,7 +163,7 @@ def get_desc_channel(desc, num_buffers: gl.constexpr, num_consumers: gl.constexp @gluon.jit def issue_async_tma_load(smem, bar, desc, offset): mbarrier.expect(bar, desc.block_type.nbytes) - tma.async_copy_global_to_shared(desc, [offset, 0], bar, smem) + tma.async_load(desc, [offset, 0], bar, smem) # ===-----------------------------------------------------------------------===# diff --git a/python/examples/gluon/02-convolution.py b/python/examples/gluon/02-convolution.py index 046e78e82248..fa82a69cde62 100644 --- a/python/examples/gluon/02-convolution.py +++ b/python/examples/gluon/02-convolution.py @@ -177,7 +177,7 @@ def load_partition(p): mbarrier.expect(ready_bar, p.in_desc.block_type.nbytes + p.weight_desc.block_type.nbytes) # Input tile via TMA im2col: channel offset iter_ci*BLOCK_K (OOB channels zero-filled) - tma.async_copy_global_to_shared_im2col( + tma.async_load_im2col( p.in_desc, [ batch_id, out_y * config.stride_h - config.pad_h, out_x * config.stride_w - config.pad_w, @@ -193,7 +193,7 @@ def load_partition(p): # ci block bleeds into the next (r,s) group, those weight elements are multiplied # by zero-filled input channels, so the result is still correct. 
k_offset = (iter_r * config.S + iter_s) * config.Ci + iter_ci * BLOCK_K - tma.async_copy_global_to_shared( + tma.async_load( p.weight_desc, [prog.pid_n * BLOCK_N, k_offset], ready_bar, diff --git a/python/examples/gluon/03-matmul-multicta.py b/python/examples/gluon/03-matmul-multicta.py index 623fcb5948c0..72fb0ac55da2 100644 --- a/python/examples/gluon/03-matmul-multicta.py +++ b/python/examples/gluon/03-matmul-multicta.py @@ -383,8 +383,8 @@ def matmul_load_partition(p): mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred) bar = p.load_ready_bars.index(state.index) mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True) - tma.async_copy_global_to_shared(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True) + tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True) + tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True) state = state.next() scheduler = scheduler.step(i) i += 1 diff --git a/python/examples/gluon/04-2cta-block-scale-matmul.py b/python/examples/gluon/04-2cta-block-scale-matmul.py index a46d400a87f3..176aadcd5920 100644 --- a/python/examples/gluon/04-2cta-block-scale-matmul.py +++ b/python/examples/gluon/04-2cta-block-scale-matmul.py @@ -298,12 +298,11 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta, pred) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) - tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, - a_scale_bufs.index(index), pred) - tma.async_copy_global_to_shared(b_scale_desc, [0, 
off_n_b_scale, off_k_b_scale, 0, 0], bar, - b_scale_bufs.index(index), pred, multicast=multicast_b_scale) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred) + tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred, + multicast=multicast_b_scale) return producer.next(pred) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index bf3064ad8f2c..8d93e14a8c1b 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -183,7 +183,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + tma.async_load(input_desc, [0, 0], bar, smem) mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[smem]) val = smem.load(blocked_layout) mbarrier.wait(bar, 0, pred=FAILURE, deps=[smem]) @@ -231,7 +231,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True) + tma.async_load(input_desc, [0, 0], bar, smem, multicast=True) mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[smem]) val = smem.load(blocked_layout) mbarrier.wait(bar, 0, pred=FAILURE, deps=[smem]) @@ -313,7 +313,7 @@ def kernel(input_desc, out): for phase in ttgl.static_range(2): mbarrier.expect(bar, input_desc.nbytes_per_cta) ttgl.barrier(cluster=True) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True) + tma.async_load(input_desc, [0, 0], bar, smem, multicast=True) mbarrier.wait(bar, phase % 2, deps=[smem]) 
ttgl.barrier(cluster=True) val = smem.load(blocked_layout) @@ -357,7 +357,7 @@ def kernel(input_desc, out): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True) + tma.async_load(input_desc, [0, 0], bar, smem, multicast=True) ttgl.barrier(cluster=True) smem.store(ttgl.full([XBLOCK, XBLOCK], 1, ttgl.float16, blocked_layout)) mbarrier.wait(bar, 0, deps=[smem]) @@ -405,8 +405,8 @@ def kernel(a_desc, b_desc, out, FAILURE: ttgl.constexpr): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem) + tma.async_load(a_desc, [0, 0], bar, a_smem) + tma.async_load(b_desc, [0, 0], bar, b_smem) mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[a_smem, b_smem]) val = a_smem.load(blocked_layout) val = val + b_smem.load(blocked_layout) @@ -453,7 +453,7 @@ def kernel(input_desc, out, EXPECT_DELTA: ttgl.constexpr): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta + EXPECT_DELTA) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + tma.async_load(input_desc, [0, 0], bar, smem) mbarrier.wait(bar, 0, deps=[smem]) val = smem.load(blocked_layout) mbarrier.invalidate(bar) @@ -499,8 +499,8 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): mbarrier.init(bar.index(1), count=1) mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta) mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0)) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1)) + tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0)) + tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1)) 
mbarrier.wait(bar.index(0), 0, deps=[smem.index(0)]) if not FAILURE: @@ -557,9 +557,9 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): mbarrier.init(bar.index(1), count=1) mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0)) + tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0)) mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1)) + tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1)) mbarrier.wait(bar.index(1), 0) @@ -742,7 +742,7 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt if MEM_ACCESS_KIND == "tma_cp": mbarrier.expect(tma_bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA) + tma.async_load(input_desc, [0, 0], tma_bar, smemA) mbarrier.wait(tma_bar, 0) mbarrier.invalidate(tma_bar) elif MEM_ACCESS_KIND == "local_store": @@ -1052,18 +1052,18 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): # ins_id = 0 mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) + tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) + tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) ins_id = inc_mod(ins_id, num_buffers) # ins_id = 1 mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) + tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(b_desc, [0, 0], 
barLoadB.index(ins_id), smemB.index(ins_id)) + tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) ins_id = inc_mod(ins_id, num_buffers) mbarrier.wait(barLoadA.index(ext_id), phase) @@ -1078,10 +1078,10 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): for i in range(ub): if i < ub - 2: mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) + tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) + tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) ins_id = inc_mod(ins_id, num_buffers) if i < ub - 1: @@ -1169,8 +1169,8 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): for k in range(num_k_tiles): offs_k = k * XBLOCK mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, offs_k], tma_bar, smemA, multicast=True) - tma.async_copy_global_to_shared(b_desc, [offs_k, 0], tma_bar, smemB, multicast=True) + tma.async_load(a_desc, [0, offs_k], tma_bar, smemA, multicast=True) + tma.async_load(b_desc, [offs_k, 0], tma_bar, smemB, multicast=True) if not FAILURE: mbarrier.wait(tma_bar, phase_tma, deps=[smemA, smemB]) blackwell.tcgen05_mma(smemA, smemB, acc, use_acc=k != 0, multicast=True, mbarriers=[mma_bar]) @@ -1239,10 +1239,10 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): # ins_id = 0 mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) + tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), 
smemB.index(ins_id)) + tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) ins_id = inc_mod(ins_id, num_buffers) # ins_id = 1 @@ -1250,10 +1250,10 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): for i in range(ub): if i < ub - 1: mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) + tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id)) mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) + tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id)) ins_id = inc_mod(ins_id, num_buffers) mbarrier.wait(barLoadA.index(ext_id), phase) @@ -2158,13 +2158,13 @@ def test_deadlock_exempt_when_tma_signals(device, run_wrapper, monkeypatch, num_ @gluon.jit def ws_default(input_desc, smem, bar): mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0)) + tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0)) mbarrier.wait(bar.index(0), phase=0) @gluon.jit def ws_1(input_desc, smem, bar): mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1)) + tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1)) mbarrier.wait(bar.index(1), phase=0) @gluon.jit diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 8e3fbca582ea..2363b2024511 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -120,7 +120,7 @@ def tma_im2col_kernel(in_desc, out_desc): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, in_desc.block_type.nbytes) - tma.async_copy_global_to_shared_im2col(in_desc, [0, 0, 0, 0], [0, 0], bar, smem) + tma.async_load_im2col(in_desc, [0, 0, 
0, 0], [0, 0], bar, smem) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) tma.async_copy_shared_to_global(out_desc, [0, 0], smem) @@ -171,7 +171,7 @@ def tma_round_f32_to_tf32_kernel(in_desc, out_desc): bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, in_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem) + tma.async_load(in_desc, [0, 0], bar, smem) mbarrier.wait(bar, phase=0, deps=[smem]) mbarrier.invalidate(bar) tma.async_copy_shared_to_global(out_desc, [0, 0], smem) @@ -220,7 +220,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc): mbarrier.init(bar, count=1) mbarrier.expect(bar, in_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem, multicast=True) + tma.async_load(in_desc, [0, 0], bar, smem, multicast=True) mbarrier.wait(bar, phase=0, deps=[smem]) tma.async_copy_shared_to_global(out_desc, [0, 0], smem) @@ -356,8 +356,8 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl. 
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True)) mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], tma_bar, smem_a, multicast=True) - tma.async_copy_global_to_shared(b_desc, [0, 0], tma_bar, smem_b, multicast=True) + tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True) + tma.async_load(b_desc, [0, 0], tma_bar, smem_b, multicast=True) mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b]) mbarrier.invalidate(tma_bar) @@ -489,8 +489,8 @@ def tcgen05_mma_scaled_direct_multicast_kernel(a_desc, b_desc, out_ptr, BLOCK_M: for k in range(NUM_K_TILES): offs_k = k * BLOCK_K mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, offs_k], tma_bar, smem_a, multicast=True) - tma.async_copy_global_to_shared(b_desc, [offs_k, 0], tma_bar, smem_b, multicast=True) + tma.async_load(a_desc, [0, offs_k], tma_bar, smem_a, multicast=True) + tma.async_load(b_desc, [offs_k, 0], tma_bar, smem_b, multicast=True) mbarrier.wait(tma_bar, phase_tma, deps=[smem_a, smem_b]) tcgen05_mma_scaled(smem_a, smem_b, acc, a_scale, b_scale, "e5m2", "e5m2", use_acc=k != 0, multicast=True, mbarriers=[mma_bar]) @@ -586,7 +586,7 @@ def tma_device_load_kernel(input_ptr, output_ptr, XBLOCK: ttgl.constexpr, smem_l mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + tma.async_load(input_desc, [0, 0], bar, smem) mbarrier.wait(bar, 0) mbarrier.invalidate(bar) @@ -774,8 +774,8 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_p if use_gather_scatter: blackwell_tma.async_gather(a_desc, gather_offsets, k * BLOCK_K, tma_bar, smem_a, multicast=multicast) else: - tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast) - tma.async_copy_global_to_shared(b_desc, [k * BLOCK_K, 0], tma_bar, 
smem_b, multicast=multicast) + tma.async_load(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast) + tma.async_load(b_desc, [k * BLOCK_K, 0], tma_bar, smem_b, multicast=multicast) mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b]) phase_tma ^= 1 @@ -1758,7 +1758,7 @@ def kernel(in_desc, out_desc): mbarrier.init(bar, count=1) mbarrier.expect(bar, in_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem_slice1) + tma.async_load(in_desc, [0, 0], bar, smem_slice1) mbarrier.wait(bar, phase=0) blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0]) @@ -3830,12 +3830,12 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale EXPECTED_BYTES: ttgl.constexpr = (a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta) mbarrier.expect(tma_bar, EXPECTED_BYTES) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], tma_bar, a_smem, multicast=multicast) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], tma_bar, b_smem, multicast=multicast) - tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], tma_bar, a_scale_smem, - multicast=multicast) - tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], tma_bar, b_scale_smem, - multicast=multicast) + tma.async_load(a_desc, [off_m, off_k_a], tma_bar, a_smem, multicast=multicast) + tma.async_load(b_desc, [off_n, off_k_b], tma_bar, b_smem, multicast=multicast) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], tma_bar, a_scale_smem, + multicast=multicast) + tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], tma_bar, b_scale_smem, + multicast=multicast) mbarrier.wait(tma_bar, phase_tma, deps=[a_smem, b_smem, a_scale_smem, b_scale_smem]) phase_tma ^= 1 diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index fbca966c9c93..9a6d4d90e8d9 100644 
--- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -802,7 +802,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + tma.async_load(input_desc, [0, 0], bar, smem) ttgl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2) mbarrier.expect(bar, input_desc.block_type.nbytes) mbarrier.wait(bar, 0) @@ -3849,7 +3849,7 @@ def nv_tma_descriptor_load_kernel(input_ptr): bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + tma.async_load(input_desc, [0, 0], bar, smem) ptr = MockTensor(ttgl.float32) module = run_parser(nv_tma_descriptor_load_kernel, *make_args(ptr), target) diff --git a/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py b/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py index b6752402bfda..33b8168d7502 100644 --- a/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py +++ b/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py @@ -4,6 +4,7 @@ __all__ = [ "async_copy_global_to_shared", + "async_load", "mbarrier_arrive", "commit_group", "wait_group", @@ -11,10 +12,9 @@ @builtin -def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False, - _semantic=None): +def async_load(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False, _semantic=None): """ - Asynchronously copy elements from global memory to shared memory. + Asynchronously load elements from global memory to shared memory. Args: smem (shared_memory_descriptor): Destination shared memory descriptor. 
@@ -39,6 +39,10 @@ def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", evi cache_modifier, eviction_policy, volatile) +# Backward-compatible alias +async_copy_global_to_shared = async_load + + @builtin def mbarrier_arrive(mbarrier, increment_count=True, _semantic=None): """ diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 42705f9e6f66..b761a295ec42 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -9,6 +9,9 @@ async_atomic_xor, async_copy_global_to_shared, async_copy_shared_to_global, + async_load, + async_load_im2col, + async_store, store_wait, tensor_descriptor, tensor_descriptor_type, @@ -27,6 +30,9 @@ "async_atomic_xor", "async_copy_global_to_shared", "async_copy_shared_to_global", + "async_load", + "async_load_im2col", + "async_store", "store_wait", "tensor_descriptor", "tensor_descriptor_type", diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index da6f23c02768..6d2bc3085ceb 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -19,6 +19,9 @@ "async_copy_global_to_shared", "async_copy_global_to_shared_im2col", "async_copy_shared_to_global", + "async_load", + "async_load_im2col", + "async_store", "store_wait", ] @@ -192,9 +195,9 @@ def _convert_im2col_offsets(offsets, _semantic): @builtin -def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, multicast=False, _semantic=None): +def async_load(tensor_desc, coord, barrier, result, pred=True, multicast=False, _semantic=None): """ - Copy data from global memory to shared memory using TMA. + Load data from global memory to shared memory using TMA. 
Args: tensor_desc: Tensor descriptor (tiled) @@ -205,8 +208,7 @@ def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, multicast: Enable multicast """ if _semantic.builder.options.enable_iisan: - _emit_alignment_check(tensor_desc, coord, "async_copy_global_to_shared", "innermost coordinate", - _semantic=_semantic) + _emit_alignment_check(tensor_desc, coord, "async_load", "innermost coordinate", _semantic=_semantic) coord = _semantic._convert_to_ir_values(coord, require_i64=False) pred = _semantic.to_tensor(pred) @@ -224,10 +226,9 @@ def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, @builtin -def async_copy_global_to_shared_im2col(tensor_desc, coord, offsets, barrier, result, pred=True, multicast=False, - _semantic=None): +def async_load_im2col(tensor_desc, coord, offsets, barrier, result, pred=True, multicast=False, _semantic=None): """ - Copy data from global memory to shared memory using TMA in im2col mode. + Load data from global memory to shared memory using TMA in im2col mode. Args: tensor_desc: Tensor descriptor (im2col) @@ -242,8 +243,7 @@ def async_copy_global_to_shared_im2col(tensor_desc, coord, offsets, barrier, res multicast: Enable multicast """ if _semantic.builder.options.enable_iisan: - _emit_alignment_check(tensor_desc, coord, "async_copy_global_to_shared_im2col", "innermost coordinate", - _semantic=_semantic) + _emit_alignment_check(tensor_desc, coord, "async_load_im2col", "innermost coordinate", _semantic=_semantic) coord = _semantic._convert_to_ir_values(coord, require_i64=False) pred = _semantic.to_tensor(pred) @@ -262,9 +262,9 @@ def async_copy_global_to_shared_im2col(tensor_desc, coord, offsets, barrier, res @builtin -def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): +def async_store(tensor_desc, coord, src, _semantic=None): """ - Copy data from shared memory to global memory using TMA. + Store data from shared memory to global memory using TMA.
Args: tensor_desc (tensor_descriptor): Tensor descriptor (tiled). @@ -272,12 +272,17 @@ def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): src (ttgl.shared_memory_descriptor): Source memory descriptor. """ if _semantic.builder.options.enable_iisan: - _emit_alignment_check(tensor_desc, coord, "async_copy_shared_to_global", "innermost coordinate", - _semantic=_semantic) + _emit_alignment_check(tensor_desc, coord, "async_store", "innermost coordinate", _semantic=_semantic) coord = _semantic._convert_to_ir_values(coord, require_i64=False) _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle) +# Backward-compatible aliases +async_copy_global_to_shared = async_load +async_copy_global_to_shared_im2col = async_load_im2col +async_copy_shared_to_global = async_store + + def _async_atomic_shared_to_global(kind, tensor_desc, coord, src, fn_name: str, _semantic=None): if _semantic.builder.options.enable_iisan: _emit_alignment_check(tensor_desc, coord, fn_name, "innermost coordinate", _semantic=_semantic) diff --git a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py index 09245179e5de..ba6f75d5cdef 100644 --- a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py @@ -412,7 +412,7 @@ def tl_load_tensor_descriptor(desc, offsets): bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) mbarrier.expect(bar, desc.block_type.nbytes) - tma.async_copy_global_to_shared(desc, offsets, bar, smem) + tma.async_load(desc, offsets, bar, smem) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) diff --git a/python/tutorials/gluon/03-async-copy.py b/python/tutorials/gluon/03-async-copy.py index 0abe84359d0e..43a685a012c0 100644 --- 
a/python/tutorials/gluon/03-async-copy.py +++ b/python/tutorials/gluon/03-async-copy.py @@ -50,7 +50,7 @@ def memcpy_1d_cpasync_kernel(in_ptr, out_ptr, xnumel, XBLOCK: gl.constexpr): smem = gl.allocate_shared_memory(gl.float32, [XBLOCK], layout=smem_layout) # Issue the async copy. - cp.async_copy_global_to_shared(smem, in_ptr + offsets, mask=mask) + cp.async_load(smem, in_ptr + offsets, mask=mask) # `commit_group` puts all previously issued async copies into a group. cp.commit_group() @@ -165,8 +165,8 @@ def elementwise_add_cpasync_kernel( # mask = (xoffs < xnumel)[:, None] & (yoffs < ynumel)[None, :] # Issue loads for both A and B tiles. - cp.async_copy_global_to_shared(a_smem, a_ptrs + ystride_a * yoffs[None, :], mask=mask) - cp.async_copy_global_to_shared(b_smem, b_ptrs + ystride_b * yoffs[None, :], mask=mask) + cp.async_load(a_smem, a_ptrs + ystride_a * yoffs[None, :], mask=mask) + cp.async_load(b_smem, b_ptrs + ystride_b * yoffs[None, :], mask=mask) # Commit both loads to the same group. cp.commit_group() # Wait until both loads are complete! @@ -261,10 +261,10 @@ def issue_loads(copy_idx, a_smem, b_smem, a_ptrs, ystride_a, b_ptrs, xmask, ynum # are fewer blocks to copy than `num_buffers-1`. 
yoffs = copy_idx * YBLOCK + y_idx mask = xmask & (yoffs < ynumel)[None, :] - cp.async_copy_global_to_shared(a_smem.index(copy_idx % num_buffers), # - a_ptrs + ystride_a * yoffs[None, :], mask) - cp.async_copy_global_to_shared(b_smem.index(copy_idx % num_buffers), # - b_ptrs + ystride_b * yoffs[None, :], mask) + cp.async_load(a_smem.index(copy_idx % num_buffers), # + a_ptrs + ystride_a * yoffs[None, :], mask) + cp.async_load(b_smem.index(copy_idx % num_buffers), # + b_ptrs + ystride_b * yoffs[None, :], mask) cp.commit_group() return copy_idx + 1 diff --git a/python/tutorials/gluon/04-tma.py b/python/tutorials/gluon/04-tma.py index e610f55a650e..c90e2f00bbb2 100644 --- a/python/tutorials/gluon/04-tma.py +++ b/python/tutorials/gluon/04-tma.py @@ -93,7 +93,7 @@ def memcpy_1d_tma_kernel(in_desc, out_desc, XBLOCK: gl.constexpr): # decrement the number of outstanding bytes as transactions complete. When # it reaches 0, the mbarrier is arrived on once. mbarrier.expect(bar, in_desc.block_type.nbytes) - tma.async_copy_global_to_shared(in_desc, [pid * XBLOCK], bar, smem) + tma.async_load(in_desc, [pid * XBLOCK], bar, smem) # Wait for completion of the read. We query the completion state of the # mbarrier using the parity of the phase, i.e. either 0 or 1. mbarriers are @@ -158,7 +158,7 @@ def test_memcpy_1d_tma(XBLOCK, xnumel): # ```python # value = smem.load() # fence_async_shared() -# tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) +# tma.async_load(desc, [0, 0], bar, smem) # ``` # # Without the fence, async_copy_global_to_shared can start copying into `smem` @@ -177,7 +177,7 @@ def test_memcpy_1d_tma(XBLOCK, xnumel): # do not require a fence. 
For example, waiting on the result of a TMA load: # # ```python -# tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) +# tma.async_load(desc, [0, 0], bar, smem) # mbarrier.wait(bar, phase=0) # value = smem.load() # ``` @@ -203,8 +203,8 @@ def issue_loads(copy_index, a_desc, b_desc, a_smem, b_smem, bars, xoff, YBLOCK: yoff = copy_index * YBLOCK bar = bars.index(copy_index % num_buffers) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [xoff, yoff], bar, a_smem.index(copy_index % num_buffers)) - tma.async_copy_global_to_shared(b_desc, [xoff, yoff], bar, b_smem.index(copy_index % num_buffers)) + tma.async_load(a_desc, [xoff, yoff], bar, a_smem.index(copy_index % num_buffers)) + tma.async_load(b_desc, [xoff, yoff], bar, b_smem.index(copy_index % num_buffers)) return copy_index + 1 diff --git a/python/tutorials/gluon/05-wgmma.py b/python/tutorials/gluon/05-wgmma.py index bc0f97681872..ed31690ddeb1 100644 --- a/python/tutorials/gluon/05-wgmma.py +++ b/python/tutorials/gluon/05-wgmma.py @@ -116,7 +116,7 @@ def is_hopper(): # a = a_smem.load(dot_operand_layout) # d = warpgroup_mma(a, b_smem, c, is_async=True) # d = warpgroup_mma_wait(num_outstanding=0, deps=(d, )) -# tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem) +# tma.async_load(a_desc, [0, 0], bar, a_smem) # ``` # %% @@ -138,9 +138,9 @@ def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, # c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_type.shape, c_desc.layout) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + c_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem) - tma.async_copy_global_to_shared(c_desc, [0, 0], bar, c_smem) + tma.async_load(a_desc, [0, 0], bar, a_smem) + tma.async_load(b_desc, [0, 0], bar, b_smem) + tma.async_load(c_desc, [0, 0], bar, c_smem) mbarrier.wait(bar, phase=0) 
mbarrier.invalidate(bar) @@ -351,11 +351,11 @@ def blocked_matmul_kernel(a_desc, b_desc, c_desc, # for k in range(0, K, BLOCK_K): # Load tiles of A and B. mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a_smem) + tma.async_load(a_desc, [off_m, k], bar, a_smem) if TRANSPOSE_B: - tma.async_copy_global_to_shared(b_desc, [off_n, k], bar, b_smem) + tma.async_load(b_desc, [off_n, k], bar, b_smem) else: - tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b_smem) + tma.async_load(b_desc, [k, off_n], bar, b_smem) mbarrier.wait(bar, phase=phase) phase ^= 1 # toggle the parity phase between 0 and 1 @@ -552,8 +552,8 @@ def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.conste b = b_smem.index(index) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a) - tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b) + tma.async_load(a_desc, [off_m, k], bar, a) + tma.async_load(b_desc, [k, off_n], bar, b) mbarrier.wait(bar, phase=phase) phase ^= 1 diff --git a/python/tutorials/gluon/06-tcgen05.py b/python/tutorials/gluon/06-tcgen05.py index 430abb2679c4..3c06b0ab02ac 100644 --- a/python/tutorials/gluon/06-tcgen05.py +++ b/python/tutorials/gluon/06-tcgen05.py @@ -158,9 +158,9 @@ def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, tmem_block: gl.constexpr, c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_type.shape, c_desc.layout) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + c_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem) - tma.async_copy_global_to_shared(c_desc, [0, 0], bar, c_smem) + tma.async_load(a_desc, [0, 0], bar, a_smem) + tma.async_load(b_desc, [0, 0], bar, b_smem) + tma.async_load(c_desc, [0, 0], bar, c_smem) mbarrier.wait(bar, 
phase=0) # Re-using an mbarrier for TMAs and tcgen05_mma can lead to undefined @@ -323,8 +323,8 @@ def blocked_matmul_kernel(a_desc, b_desc, c_desc, TRANSPOSE_B: gl.constexpr, num use_acc = False for k in range(0, K, BLOCK_K): mbarrier.expect(tma_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], tma_bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [off_n, k] if TRANSPOSE_B else [k, off_n], tma_bar, b_smem) + tma.async_load(a_desc, [off_m, k], tma_bar, a_smem) + tma.async_load(b_desc, [off_n, k] if TRANSPOSE_B else [k, off_n], tma_bar, b_smem) mbarrier.wait(tma_bar, phase=phase) # We can transpose B by creating a transposed view over tile of B in @@ -515,24 +515,24 @@ def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.conste load_index, load_phase, load_counter = get_and_increment(load_counter) load_ub_bar = load_ub_bars.index(load_index) mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) - tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) + tma.async_load(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) + tma.async_load(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) # V1 load_v_bar = load_v_bars.index(load_index) mbarrier.expect(load_v_bar, a_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) + tma.async_load(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) k += BLOCK_K # U2, B2 load_index, load_phase, load_counter = get_and_increment(load_counter) load_ub_bar = load_ub_bars.index(load_index) mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) - 
tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) + tma.async_load(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) + tma.async_load(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) # V2 load_v_bar = load_v_bars.index(load_index) mbarrier.expect(load_v_bar, a_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) + tma.async_load(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) k += BLOCK_K for _ in range(gl.cdiv(K, BLOCK_K) - 2): @@ -553,14 +553,14 @@ def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.conste mbarrier.wait(mma_ub_bars.index(mma_index), mma_phase) load_ub_bar = load_ub_bars.index(load_index) mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) + tma.async_load(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index)) # wait VBi, B(i+2), V(i+2) mbarrier.wait(mma_vb_bars.index(mma_index), mma_phase) - tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) + tma.async_load(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index)) load_v_bar = load_v_bars.index(load_index) mbarrier.expect(load_v_bar, a_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) + tma.async_load(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) k += BLOCK_K mma_index, mma_phase, mma_counter = get_and_increment(mma_counter) diff --git a/python/tutorials/gluon/07-persistence.py b/python/tutorials/gluon/07-persistence.py index 75b5aa9d6f8b..a1bc50514e1e 100644 --- a/python/tutorials/gluon/07-persistence.py +++ b/python/tutorials/gluon/07-persistence.py @@ -180,8 +180,8 @@ def issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, producer += 
1 bar = bars.index(index) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes, pred=pred) - tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a_bufs.index(index), pred) - tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b_bufs.index(index), pred) + tma.async_load(a_desc, [off_m, k], bar, a_bufs.index(index), pred) + tma.async_load(b_desc, [k, off_n], bar, b_bufs.index(index), pred) return producer @@ -597,8 +597,8 @@ def issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, producer += 1 bar = bars.index(index) mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes, pred=pred) - tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a_bufs.index(index), pred) - tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b_bufs.index(b_index), pred) + tma.async_load(a_desc, [off_m, k], bar, a_bufs.index(index), pred) + tma.async_load(b_desc, [k, off_n], bar, b_bufs.index(b_index), pred) return producer diff --git a/python/tutorials/gluon/08-warp-specialization.py b/python/tutorials/gluon/08-warp-specialization.py index ddc851aaeac9..13b4c3cee2a0 100644 --- a/python/tutorials/gluon/08-warp-specialization.py +++ b/python/tutorials/gluon/08-warp-specialization.py @@ -115,7 +115,7 @@ def is_blackwell(): # # mbarrier.wait(bar, phase=0) # fence_async_shared() -# tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) +# tma.async_load(desc, [0, 0], bar, smem) # ``` # # A fence is needed somewhere between the shared memory load and the TMA load. @@ -149,8 +149,8 @@ def load_partition(descs, barriers, buffers, xoff, numel, YBLOCK: gl.constexpr): # signal the operand buffers as ready when they complete. 
yoff = i * YBLOCK mbarrier.expect(load_ready_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [xoff, yoff], load_ready_bar, a_buf) - tma.async_copy_global_to_shared(b_desc, [xoff, yoff], load_ready_bar, b_buf) + tma.async_load(a_desc, [xoff, yoff], load_ready_bar, a_buf) + tma.async_load(b_desc, [xoff, yoff], load_ready_bar, b_buf) @gluon.jit @@ -444,8 +444,8 @@ def matmul_load_partition(p, SchedulerImpl: gl.constexpr): bar = ready_bars.index(state.index) mbarrier.wait(empty_bars.index(state.index), state.phase) mbarrier.expect(bar, p.a_desc.block_type.nbytes + p.b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index)) - tma.async_copy_global_to_shared(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index)) + tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index)) + tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index)) state = state.next() diff --git a/python/tutorials/gluon/09-tma-gather-scatter.py b/python/tutorials/gluon/09-tma-gather-scatter.py index 601278beb25d..2479df2ee9f4 100644 --- a/python/tutorials/gluon/09-tma-gather-scatter.py +++ b/python/tutorials/gluon/09-tma-gather-scatter.py @@ -471,7 +471,7 @@ def issue_loads(producer, X_desc, W_desc, X_gather_indx_ptr, off_m, off_n, k, ba - # The W tensor tile is loaded using a regular `async_copy_global_to_shared`. + # The W tensor tile is loaded using a regular `async_load`. 
mbarrier.expect(bar, W_desc.block_type.nbytes + BLOCK_M * X_desc.block_type.nbytes) tma.async_gather(X_desc, offs_x_m, k, bar, x_bufs.index(index), pred) - tma.async_copy_global_to_shared(W_desc, [k, off_n], bar, w_bufs.index(index), pred) + tma.async_load(W_desc, [k, off_n], bar, w_bufs.index(index), pred) return producer diff --git a/python/tutorials/gluon/10-tcgen05-copy.py b/python/tutorials/gluon/10-tcgen05-copy.py index 2dd4aec7b53c..f246ad2a5b4d 100644 --- a/python/tutorials/gluon/10-tcgen05-copy.py +++ b/python/tutorials/gluon/10-tcgen05-copy.py @@ -255,15 +255,15 @@ def matmul_accumulate_load_partition(p): # Issue the async TMA load for the C tile. mbarrier.wait(p.c_empty_bar, c_phase) mbarrier.expect(p.c_ready_bar, p.c_desc.block_type.nbytes) - tma.async_copy_global_to_shared(p.c_desc, [off_m, off_n], p.c_ready_bar, p.c_buf) + tma.async_load(p.c_desc, [off_m, off_n], p.c_ready_bar, p.c_buf) c_phase ^= 1 # Inner loop loads. for k in range(0, K, BLOCK_K): bar = p.load_ready_bars.index(state.index) mbarrier.wait(p.load_empty_bars.index(state.index), state.phase) mbarrier.expect(bar, p.a_desc.block_type.nbytes + p.b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index)) - tma.async_copy_global_to_shared(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index)) + tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index)) + tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index)) state = state.next() @@ -418,8 +418,8 @@ def test_matmul_accumulate(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, dtyp # which can be implicitly pipelined with the `tcgen05_mma_scaled` instruction: # # ```python -# tma.async_copy_global_to_shared(a_scale_desc, ..., bar, a_scale_buf) -# tma.async_copy_global_to_shared(b_scale_desc, ..., bar, b_scale_buf) +# tma.async_load(a_scale_desc, ..., bar, a_scale_buf) +# tma.async_load(b_scale_desc, ..., bar, b_scale_buf) # mbarrier.wait(bar, phase) # # 
tcgen05_copy(a_scale_buf, a_scale_tmem) diff --git a/python/tutorials/gluon/11-tcgen05-mma-scaled.py b/python/tutorials/gluon/11-tcgen05-mma-scaled.py index 573412470470..7c206c02d7be 100644 --- a/python/tutorials/gluon/11-tcgen05-mma-scaled.py +++ b/python/tutorials/gluon/11-tcgen05-mma-scaled.py @@ -174,8 +174,8 @@ def simple_mma_scaled_kernel(a_desc, b_desc, c_desc, a_scale_ptr, a_scale_stride # Load the A and B tiles. mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_smem) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_smem) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_smem) mbarrier.wait(bar, phase) # Load the scales. We must always feed `b_scales` into `tcgen05_mma_scaled` @@ -481,8 +481,8 @@ def mma_scaled_contig_kernel(a_desc, b_desc, c_desc, a_scale_ptr, b_scale_ptr, V off_k_b = k // B_ELEM_PER_BYTE mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_smem) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_smem) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_smem) mbarrier.wait(bar, phase) # ======= End unchanged code from `simple_mma_scaled_kernel` ======= @@ -725,10 +725,10 @@ def mma_scaled_packed_block_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + b_scale_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_smem) - tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_smem) - tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 
0, 0], bar, b_scale_smem) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_smem) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_smem) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_smem) + tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_smem) mbarrier.wait(bar, phase) # We know the destination 2D layout of the scales required to store them @@ -1005,10 +1005,10 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + b_scale_desc.block_type.nbytes) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_smem) - tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_smem) - tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_smem) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_smem) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_smem) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_smem) + tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_smem) mbarrier.wait(bar, phase) # ======= End unchanged code from `mma_scaled_packed_block_kernel` ======= @@ -1186,12 +1186,10 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + b_scale_desc.block_type.nbytes, pred) - tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) - tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) - tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, - a_scale_bufs.index(index), pred) - 
tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, - b_scale_bufs.index(index), pred) + tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred) + tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred) + tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred) + tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred) return producer.next(pred) diff --git a/python/tutorials/gluon/13-conv-im2col.py b/python/tutorials/gluon/13-conv-im2col.py index c61b49e53954..7ef9b0ad1077 100644 --- a/python/tutorials/gluon/13-conv-im2col.py +++ b/python/tutorials/gluon/13-conv-im2col.py @@ -246,7 +246,7 @@ def tma_im2col_kernel(in_desc, out_desc, coord_n: int, coord_h: int, coord_w: in mbarrier.expect(bar, in_desc.block_type.nbytes) # TMA im2col load with specified coordinates and offsets - tma.async_copy_global_to_shared_im2col( + tma.async_load_im2col( in_desc, [coord_n, coord_h, coord_w, coord_c], [offset_h, offset_w], @@ -884,12 +884,12 @@ def test_tma_im2col_multi_batch_padded(): # for ci_block in range(Ci // BLOCK_K): # # Input tile via TMA im2col: # # input[batch, out_y*stride+r-pad, out_x*stride+s-pad, ci_block*BLOCK_K:...] -# tma.async_copy_global_to_shared_im2col(...) +# tma.async_load_im2col(...) # A_tile = a_smem # [BLOCK_M, BLOCK_K] # # # Weight tile via standard TMA: # # weight[co_start:..., r, s, ci_block*BLOCK_K:...] -# tma.async_copy_global_to_shared(...) +# tma.async_load(...) 
# B_tile = b_smem # [BLOCK_N, BLOCK_K] # # # Matrix multiply-accumulate (details omitted here) @@ -1004,7 +1004,7 @@ def conv2d_im2col_kernel( # TMA applies offsets to the spatial coords, so start is: # [batch_id, out_y*stride_h-pad_h+r, out_x*stride_w-pad_w+s, ci_block*BLOCK_K] mbarrier.expect(tma_bar, in_desc.block_type.nbytes + weight_desc.block_type.nbytes) - tma.async_copy_global_to_shared_im2col( + tma.async_load_im2col( in_desc, [batch_id, out_y * stride_h - pad_h, out_x * stride_w - pad_w, ci_block * BLOCK_K], [r.to(tl.int16), s.to(tl.int16)], @@ -1017,7 +1017,7 @@ def conv2d_im2col_kernel( # group, those weight elements are multiplied by zero-filled input channels (TMA # zero-fills input channels past Ci), so the result is still correct. k_offset = r * S * Ci + s * Ci + ci_block * BLOCK_K - tma.async_copy_global_to_shared(weight_desc, [pid_n * BLOCK_N, k_offset], tma_bar, b_smem) + tma.async_load(weight_desc, [pid_n * BLOCK_N, k_offset], tma_bar, b_smem) mbarrier.wait(tma_bar, phase=phase) # acc += A @ B^T diff --git a/python/tutorials/gluon/14-multicta.py b/python/tutorials/gluon/14-multicta.py index 2e4761024acd..b164611b12bf 100644 --- a/python/tutorials/gluon/14-multicta.py +++ b/python/tutorials/gluon/14-multicta.py @@ -367,8 +367,8 @@ def two_cta_tcgen05_kernel(a_desc, b_desc, c_desc): mbarrier.init(mma_bar, count=1) mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, 0], tma_bar, smem_a) - tma.async_copy_global_to_shared(b_desc, [0, 0], tma_bar, smem_b) + tma.async_load(a_desc, [0, 0], tma_bar, smem_a) + tma.async_load(b_desc, [0, 0], tma_bar, smem_b) mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b]) mbarrier.invalidate(tma_bar) @@ -495,7 +495,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc): mbarrier.init(bar, count=1) mbarrier.expect(bar, in_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem, multicast=True) + tma.async_load(in_desc, [0, 0], 
bar, smem, multicast=True) mbarrier.wait(bar, phase=0, deps=[smem]) tma.async_copy_shared_to_global(out_desc, [0, 0], smem) @@ -559,8 +559,8 @@ def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_ for k in range(NUM_K_TILES): mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(a_desc, [0, k * block_k], tma_bar, smem_a, multicast=True) - tma.async_copy_global_to_shared(b_desc, [k * block_k, 0], tma_bar, smem_b, multicast=True) + tma.async_load(a_desc, [0, k * block_k], tma_bar, smem_a, multicast=True) + tma.async_load(b_desc, [k * block_k, 0], tma_bar, smem_b, multicast=True) mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b]) phase_tma ^= 1 @@ -867,8 +867,8 @@ def matmul_load_partition(p): mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred) bar = p.load_ready_bars.index(state.index) mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True) - tma.async_copy_global_to_shared(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True) + tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True) + tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True) state = state.next() scheduler = scheduler.step(i) i += 1 diff --git a/third_party/proton/tutorials/intra_kernel/example_dsl.py b/third_party/proton/tutorials/intra_kernel/example_dsl.py index 1db8a5eb5b00..3a85d1c98ff6 100644 --- a/third_party/proton/tutorials/intra_kernel/example_dsl.py +++ b/third_party/proton/tutorials/intra_kernel/example_dsl.py @@ -234,8 +234,8 @@ def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.conste mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) with pl.scope("tma_loads_issue"): - tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a) - 
tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b) + tma.async_load(a_desc, [off_m, k], bar, a) + tma.async_load(b_desc, [k, off_n], bar, b) with pl.scope("tma_loads_wait"): mbarrier.wait(bar, phase=phase)