Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/examples/gluon/01-attention-forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_desc_channel(desc, num_buffers: gl.constexpr, num_consumers: gl.constexp
@gluon.jit
def issue_async_tma_load(smem, bar, desc, offset):
mbarrier.expect(bar, desc.block_type.nbytes)
tma.async_copy_global_to_shared(desc, [offset, 0], bar, smem)
tma.async_load(desc, [offset, 0], bar, smem)


# ===-----------------------------------------------------------------------===#
Expand Down
4 changes: 2 additions & 2 deletions python/examples/gluon/02-convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def load_partition(p):
mbarrier.expect(ready_bar, p.in_desc.block_type.nbytes + p.weight_desc.block_type.nbytes)

# Input tile via TMA im2col: channel offset iter_ci*BLOCK_K (OOB channels zero-filled)
tma.async_copy_global_to_shared_im2col(
tma.async_load_im2col(
p.in_desc,
[
batch_id, out_y * config.stride_h - config.pad_h, out_x * config.stride_w - config.pad_w,
Expand All @@ -193,7 +193,7 @@ def load_partition(p):
# ci block bleeds into the next (r,s) group, those weight elements are multiplied
# by zero-filled input channels, so the result is still correct.
k_offset = (iter_r * config.S + iter_s) * config.Ci + iter_ci * BLOCK_K
tma.async_copy_global_to_shared(
tma.async_load(
p.weight_desc,
[prog.pid_n * BLOCK_N, k_offset],
ready_bar,
Expand Down
4 changes: 2 additions & 2 deletions python/examples/gluon/03-matmul-multicta.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ def matmul_load_partition(p):
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)
bar = p.load_ready_bars.index(state.index)
mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
tma.async_copy_global_to_shared(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
state = state.next()
scheduler = scheduler.step(i)
i += 1
Expand Down
11 changes: 5 additions & 6 deletions python/examples/gluon/04-2cta-block-scale-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,11 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale
mbarrier.expect(
bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta,
pred)
tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred)
tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred)
tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar,
a_scale_bufs.index(index), pred)
tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar,
b_scale_bufs.index(index), pred, multicast=multicast_b_scale)
tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred)
tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred)
tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred)
tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred,
multicast=multicast_b_scale)
return producer.next(pred)


Expand Down
52 changes: 26 additions & 26 deletions python/test/gluon/test_consan.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
tma.async_load(input_desc, [0, 0], bar, smem)
mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[smem])
val = smem.load(blocked_layout)
mbarrier.wait(bar, 0, pred=FAILURE, deps=[smem])
Expand Down Expand Up @@ -231,7 +231,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
tma.async_load(input_desc, [0, 0], bar, smem, multicast=True)
mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[smem])
val = smem.load(blocked_layout)
mbarrier.wait(bar, 0, pred=FAILURE, deps=[smem])
Expand Down Expand Up @@ -313,7 +313,7 @@ def kernel(input_desc, out):
for phase in ttgl.static_range(2):
mbarrier.expect(bar, input_desc.nbytes_per_cta)
ttgl.barrier(cluster=True)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
tma.async_load(input_desc, [0, 0], bar, smem, multicast=True)
mbarrier.wait(bar, phase % 2, deps=[smem])
ttgl.barrier(cluster=True)
val = smem.load(blocked_layout)
Expand Down Expand Up @@ -357,7 +357,7 @@ def kernel(input_desc, out):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
tma.async_load(input_desc, [0, 0], bar, smem, multicast=True)
ttgl.barrier(cluster=True)
smem.store(ttgl.full([XBLOCK, XBLOCK], 1, ttgl.float16, blocked_layout))
mbarrier.wait(bar, 0, deps=[smem])
Expand Down Expand Up @@ -405,8 +405,8 @@ def kernel(a_desc, b_desc, out, FAILURE: ttgl.constexpr):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem)
tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem)
tma.async_load(a_desc, [0, 0], bar, a_smem)
tma.async_load(b_desc, [0, 0], bar, b_smem)
mbarrier.wait(bar, 0, pred=(not FAILURE), deps=[a_smem, b_smem])
val = a_smem.load(blocked_layout)
val = val + b_smem.load(blocked_layout)
Expand Down Expand Up @@ -453,7 +453,7 @@ def kernel(input_desc, out, EXPECT_DELTA: ttgl.constexpr):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, input_desc.nbytes_per_cta + EXPECT_DELTA)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
tma.async_load(input_desc, [0, 0], bar, smem)
mbarrier.wait(bar, 0, deps=[smem])
val = smem.load(blocked_layout)
mbarrier.invalidate(bar)
Expand Down Expand Up @@ -499,8 +499,8 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr):
mbarrier.init(bar.index(1), count=1)
mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta)
mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0))
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1))
tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0))
tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1))

mbarrier.wait(bar.index(0), 0, deps=[smem.index(0)])
if not FAILURE:
Expand Down Expand Up @@ -557,9 +557,9 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr):
mbarrier.init(bar.index(1), count=1)

mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0))
tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0))
mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1))
tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1))

mbarrier.wait(bar.index(1), 0)

Expand Down Expand Up @@ -742,7 +742,7 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt

if MEM_ACCESS_KIND == "tma_cp":
mbarrier.expect(tma_bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA)
tma.async_load(input_desc, [0, 0], tma_bar, smemA)
mbarrier.wait(tma_bar, 0)
mbarrier.invalidate(tma_bar)
elif MEM_ACCESS_KIND == "local_store":
Expand Down Expand Up @@ -1052,18 +1052,18 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr):

# ins_id = 0
mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))
tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))

mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
ins_id = inc_mod(ins_id, num_buffers)

# ins_id = 1
mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))
tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))

mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
ins_id = inc_mod(ins_id, num_buffers)

mbarrier.wait(barLoadA.index(ext_id), phase)
Expand All @@ -1078,10 +1078,10 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr):
for i in range(ub):
if i < ub - 2:
mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))
tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))

mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
ins_id = inc_mod(ins_id, num_buffers)

if i < ub - 1:
Expand Down Expand Up @@ -1169,8 +1169,8 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr):
for k in range(num_k_tiles):
offs_k = k * XBLOCK
mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, offs_k], tma_bar, smemA, multicast=True)
tma.async_copy_global_to_shared(b_desc, [offs_k, 0], tma_bar, smemB, multicast=True)
tma.async_load(a_desc, [0, offs_k], tma_bar, smemA, multicast=True)
tma.async_load(b_desc, [offs_k, 0], tma_bar, smemB, multicast=True)
if not FAILURE:
mbarrier.wait(tma_bar, phase_tma, deps=[smemA, smemB])
blackwell.tcgen05_mma(smemA, smemB, acc, use_acc=k != 0, multicast=True, mbarriers=[mma_bar])
Expand Down Expand Up @@ -1239,21 +1239,21 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr):

# ins_id = 0
mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))
tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))

mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
ins_id = inc_mod(ins_id, num_buffers)

# ins_id = 1
ub = 10
for i in range(ub):
if i < ub - 1:
mbarrier.expect(barLoadA.index(ins_id), a_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))
tma.async_load(a_desc, [0, 0], barLoadA.index(ins_id), smemA.index(ins_id))

mbarrier.expect(barLoadB.index(ins_id), b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
tma.async_load(b_desc, [0, 0], barLoadB.index(ins_id), smemB.index(ins_id))
ins_id = inc_mod(ins_id, num_buffers)

mbarrier.wait(barLoadA.index(ext_id), phase)
Expand Down Expand Up @@ -2158,13 +2158,13 @@ def test_deadlock_exempt_when_tma_signals(device, run_wrapper, monkeypatch, num_
@gluon.jit
def ws_default(input_desc, smem, bar):
mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(0), smem.index(0))
tma.async_load(input_desc, [0, 0], bar.index(0), smem.index(0))
mbarrier.wait(bar.index(0), phase=0)

@gluon.jit
def ws_1(input_desc, smem, bar):
mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smem.index(1))
tma.async_load(input_desc, [0, 0], bar.index(1), smem.index(1))
mbarrier.wait(bar.index(1), phase=0)

@gluon.jit
Expand Down
34 changes: 17 additions & 17 deletions python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def tma_im2col_kernel(in_desc, out_desc):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, in_desc.block_type.nbytes)
tma.async_copy_global_to_shared_im2col(in_desc, [0, 0, 0, 0], [0, 0], bar, smem)
tma.async_load_im2col(in_desc, [0, 0, 0, 0], [0, 0], bar, smem)
mbarrier.wait(bar, phase=0)
mbarrier.invalidate(bar)
tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
Expand Down Expand Up @@ -171,7 +171,7 @@ def tma_round_f32_to_tf32_kernel(in_desc, out_desc):
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, in_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem)
tma.async_load(in_desc, [0, 0], bar, smem)
mbarrier.wait(bar, phase=0, deps=[smem])
mbarrier.invalidate(bar)
tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
Expand Down Expand Up @@ -220,7 +220,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc):
mbarrier.init(bar, count=1)

mbarrier.expect(bar, in_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem, multicast=True)
tma.async_load(in_desc, [0, 0], bar, smem, multicast=True)
mbarrier.wait(bar, phase=0, deps=[smem])

tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
Expand Down Expand Up @@ -356,8 +356,8 @@ def tcgen05_mma_multicast_commit_kernel(a_desc, b_desc, out_ptrs, BLOCK_M: ttgl.
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], True))

mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, 0], tma_bar, smem_a, multicast=True)
tma.async_copy_global_to_shared(b_desc, [0, 0], tma_bar, smem_b, multicast=True)
tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True)
tma.async_load(b_desc, [0, 0], tma_bar, smem_b, multicast=True)
mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b])
mbarrier.invalidate(tma_bar)

Expand Down Expand Up @@ -489,8 +489,8 @@ def tcgen05_mma_scaled_direct_multicast_kernel(a_desc, b_desc, out_ptr, BLOCK_M:
for k in range(NUM_K_TILES):
offs_k = k * BLOCK_K
mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(a_desc, [0, offs_k], tma_bar, smem_a, multicast=True)
tma.async_copy_global_to_shared(b_desc, [offs_k, 0], tma_bar, smem_b, multicast=True)
tma.async_load(a_desc, [0, offs_k], tma_bar, smem_a, multicast=True)
tma.async_load(b_desc, [offs_k, 0], tma_bar, smem_b, multicast=True)
mbarrier.wait(tma_bar, phase_tma, deps=[smem_a, smem_b])
tcgen05_mma_scaled(smem_a, smem_b, acc, a_scale, b_scale, "e5m2", "e5m2", use_acc=k != 0, multicast=True,
mbarriers=[mma_bar])
Expand Down Expand Up @@ -586,7 +586,7 @@ def tma_device_load_kernel(input_ptr, output_ptr, XBLOCK: ttgl.constexpr, smem_l
mbarrier.init(bar, count=1)

mbarrier.expect(bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
tma.async_load(input_desc, [0, 0], bar, smem)
mbarrier.wait(bar, 0)
mbarrier.invalidate(bar)

Expand Down Expand Up @@ -774,8 +774,8 @@ def tma_mma_shared_inputs_kernel(a_desc, b_desc, out_ptr, out_desc, gather_idx_p
if use_gather_scatter:
blackwell_tma.async_gather(a_desc, gather_offsets, k * BLOCK_K, tma_bar, smem_a, multicast=multicast)
else:
tma.async_copy_global_to_shared(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast)
tma.async_copy_global_to_shared(b_desc, [k * BLOCK_K, 0], tma_bar, smem_b, multicast=multicast)
tma.async_load(a_desc, [0, k * BLOCK_K], tma_bar, smem_a, multicast=multicast)
tma.async_load(b_desc, [k * BLOCK_K, 0], tma_bar, smem_b, multicast=multicast)
mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b])
phase_tma ^= 1

Expand Down Expand Up @@ -1758,7 +1758,7 @@ def kernel(in_desc, out_desc):
mbarrier.init(bar, count=1)

mbarrier.expect(bar, in_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(in_desc, [0, 0], bar, smem_slice1)
tma.async_load(in_desc, [0, 0], bar, smem_slice1)
mbarrier.wait(bar, phase=0)

blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0])
Expand Down Expand Up @@ -3830,12 +3830,12 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale
EXPECTED_BYTES: ttgl.constexpr = (a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta +
b_scale_desc.nbytes_per_cta)
mbarrier.expect(tma_bar, EXPECTED_BYTES)
tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], tma_bar, a_smem, multicast=multicast)
tma.async_copy_global_to_shared(b_desc, [off_n, off_k_b], tma_bar, b_smem, multicast=multicast)
tma.async_copy_global_to_shared(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], tma_bar, a_scale_smem,
multicast=multicast)
tma.async_copy_global_to_shared(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], tma_bar, b_scale_smem,
multicast=multicast)
tma.async_load(a_desc, [off_m, off_k_a], tma_bar, a_smem, multicast=multicast)
tma.async_load(b_desc, [off_n, off_k_b], tma_bar, b_smem, multicast=multicast)
tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], tma_bar, a_scale_smem,
multicast=multicast)
tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], tma_bar, b_scale_smem,
multicast=multicast)
mbarrier.wait(tma_bar, phase_tma, deps=[a_smem, b_smem, a_scale_smem, b_scale_smem])
phase_tma ^= 1

Expand Down
4 changes: 2 additions & 2 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
mbarrier.init(bar, count=1)

tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
tma.async_load(input_desc, [0, 0], bar, smem)
ttgl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
mbarrier.expect(bar, input_desc.block_type.nbytes)
mbarrier.wait(bar, 0)
Expand Down Expand Up @@ -3849,7 +3849,7 @@ def nv_tma_descriptor_load_kernel(input_ptr):
bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
mbarrier.init(bar, count=1)
mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8)
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
tma.async_load(input_desc, [0, 0], bar, smem)

ptr = MockTensor(ttgl.float32)
module = run_parser(nv_tma_descriptor_load_kernel, *make_args(ptr), target)
Expand Down
Loading
Loading