Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
The `ttg.memdesc_reinterpret` operation reinterprets a memory descriptor
as one with a different shape and element type. Because memory descriptors
lack strides, this operation is only valid if the original memory descriptor
is contiguous.
is contiguous. Reinterpretation of subviews is not allowed; reinterpret the
parent descriptor and then take a subview of the reinterpreted descriptor
instead.
}];

let arguments = (ins TTG_MemDescType:$src);
Expand Down
34 changes: 27 additions & 7 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,17 +656,37 @@ OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
LogicalResult MemDescReinterpretOp::verify() {
auto srcTy = getSrc().getType();
auto dstTy = getResult().getType();
auto kBlock = StringAttr::get(getContext(), "block");
auto getNumBroadcastCTADims = [kBlock](MemDescType ty) {
if (srcTy.getMemorySpace() != dstTy.getMemorySpace())
return emitError("source and result must have the same memory space");
if (srcTy.getMutableMemory() != dstTy.getMutableMemory())
return emitError("source and result must have the same mutability");
auto isSubview = [](MemDescType ty) {
auto rank = cast<LayoutEncodingTrait>(ty.getEncoding()).getRank();
return ty.getShape().take_back(rank) != ty.getAllocShape().take_back(rank);
};
if (isSubview(srcTy) || isSubview(dstTy))
return emitError("source and result must not be subviews; reinterpret the "
"parent descriptor and then take a subview");
assert((isa<SharedMemorySpaceAttr, nvidia_gpu::TensorMemorySpaceAttr>(
srcTy.getMemorySpace()) &&
"expected shared or tensor memory"));
auto getViewNumBits = [](MemDescType ty) {
auto rank = cast<LayoutEncodingTrait>(ty.getEncoding()).getRank();
auto layout =
toLinearLayout(ty.getAllocShape().take_back(rank), ty.getEncoding());
auto freeVariableMask = layout.getFreeVariableMasks().lookup(kBlock);
return llvm::popcount<uint32_t>(freeVariableMask);
// Shared memory is allocated by offset and TMEM is allocated by column; the
Comment thread
lezcano marked this conversation as resolved.
// other physical dimensions do not increase or decrease the allocation.
auto *ctx = ty.getContext();
bool isSharedMemory = isa<SharedMemorySpaceAttr>(ty.getMemorySpace());
auto dim = StringAttr::get(ctx, isSharedMemory ? "offset" : "col");
return layout.getInDimSize(dim) * ty.getElementTypeBitWidth();
};
if (getNumBroadcastCTADims(srcTy) != getNumBroadcastCTADims(dstTy))
return emitError(
"source and result must have the same number of broadcast CTA dims");
auto srcNumBits = getViewNumBits(srcTy);
auto dstNumBits = getViewNumBits(dstTy);
if (srcNumBits != dstNumBits)
return emitError() << "source and result must have the same logical "
"storage size ("
<< srcNumBits << " vs " << dstNumBits << ")";
return success();
}

Expand Down
17 changes: 8 additions & 9 deletions python/examples/gluon/01-attention-forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,26 +383,25 @@ def get_loop_bounds(self, STAGE: gl.constexpr):

@gluon.jit
def _borrow_s_as_p(config, s_tmem):
p_tmem = s_tmem.slice(0, config.BLOCK_N // 2)
return p_tmem._reinterpret(config.dtype, config.qk_shape, config.p_tmem_layout)
p_tmem = s_tmem._reinterpret(config.dtype, [config.SPLIT_M, 2 * config.BLOCK_N], config.p_tmem_layout)
return p_tmem.slice(0, config.BLOCK_N)


@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M_PER_CTA, 1], col_stride=1,
cga_layout=config.CGA_LAYOUT, two_ctas=gl.num_ctas() > 1)
return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)
alpha_tmem = s_tmem._reinterpret(layout=alpha_layout)
return alpha_tmem.slice(config.BLOCK_N // 2, 1)


@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M_PER_CTA, 1], col_stride=1, cga_layout=config.CGA_LAYOUT,
two_ctas=gl.num_ctas() > 1)
m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
s_tmem = s_tmem._reinterpret(layout=layout)
m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
Comment thread
lezcano marked this conversation as resolved.
return m_i_tmem, l_i_tmem


Expand Down Expand Up @@ -880,7 +879,7 @@ def attention_kernel( #

q_chnl = get_desc_channel(desc_q, num_buffers=2)
kv_chnl = get_desc_channel(desc_k, num_buffers=config.num_kv_buffers)
v_mem = kv_chnl.mem._reinterpret(desc_v.dtype, [config.num_kv_buffers] + desc_v.block_type.shape, desc_v.layout)
v_mem = kv_chnl.mem._reinterpret(layout=desc_v.layout)
o_chnl = TensorMemoryChannel.alloc(config.o_shape, gl.float32, config.o_tmem_layout, num_buffers=2,
producer_two_ctas=gl.num_ctas() > 1)
epi_chnl = SharedMemoryChannel.alloc(config.o_shape, config.dtype, gl.constexpr(desc_o.layout), num_buffers=2)
Expand Down
17 changes: 10 additions & 7 deletions python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,7 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
tcgen05_commit(barrier)
mbarrier.wait(barrier, phase=0)
tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1)
tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias)
tmem = tmem._reinterpret(shape=(num_rows, num_cols), layout=tmem_alias)
value = tmem.load(blocked)
ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value)

Expand Down Expand Up @@ -1567,17 +1567,19 @@ def kernel(s_ptr, out_ptr):
s_tmem.store(s)
o_tmem.store(s)

p_tmem = s_tmem.slice(0, N // 2)._reinterpret(ttgl.float16, [BLOCK_M, N], tmem_layout)
p_tmem_parent = s_tmem._reinterpret(ttgl.float16, [BLOCK_M, 2 * N], tmem_layout)
p_tmem = p_tmem_parent.slice(0, N)
p_tmem.store(ttgl.full((BLOCK_M, N), 0.0, dtype=ttgl.float16, layout=layout))

d1_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, 2), col_stride=1)

m_tmem = s_tmem.slice(N // 4, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout)
d1_tmem_parent = s_tmem._reinterpret(layout=d1_tmem_layout)
m_tmem = d1_tmem_parent.slice(N // 2, 2)
d1_layout: ttgl.constexpr = m_tmem.get_reg_layout()
m_tmem.store(ttgl.full((BLOCK_M, 2), 2.0, dtype=ttgl.float32, layout=d1_layout))
l_tmem = s_tmem.slice(N // 4 + 2, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout)
l_tmem = d1_tmem_parent.slice(N // 2 + 4, 2)
l_tmem.store(ttgl.full((BLOCK_M, 2), 3.0, dtype=ttgl.float32, layout=d1_layout))
a_tmem = s_tmem.slice(N // 4 + 4, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout)
a_tmem = d1_tmem_parent.slice(N // 2 + 8, 2)
a_tmem.store(ttgl.full((BLOCK_M, 2), 4.0, dtype=ttgl.float32, layout=d1_layout))

s = s_tmem.load()
Expand Down Expand Up @@ -1616,7 +1618,7 @@ def kernel(s_ptr, out_ptr):
# TMEM[0:16] = [s0, s1]
# TMEM[16:32] = [s2, s3]
#
# Thus slicing S at N//4 will obtain an offset to the beginning of s1.
# Thus the narrow parent view is sliced at offsets that map back to s1.
out_ref[:, 32:34] = 2.0
out_ref[:, 34:36] = 3.0
out_ref[:, 36:38] = 4.0
Expand Down Expand Up @@ -1727,7 +1729,8 @@ def kernel(in_ptr, out_ptr):
smem_layout_2d: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])
smem = ttgl.allocate_shared_memory(ttgl.int8, [BLOCK], smem_layout_1d)
smem_slice0 = smem.slice(0, SPLIT_BLOCK)
smem_slice1 = smem.slice(SPLIT_BLOCK, SPLIT_BLOCK)._reinterpret(ttgl.int32, [XBLOCK, YBLOCK], smem_layout_2d)
smem_i32 = smem._reinterpret(ttgl.int32, [2 * XBLOCK, YBLOCK], smem_layout_2d)
smem_slice1 = smem_i32.slice(XBLOCK, XBLOCK, dim=0)

offs = ttgl.arange(0, XBLOCK)[:, None] * YBLOCK + ttgl.arange(0, YBLOCK)[None, :]
blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, NUM_THREADS], [1, 4], [1, 0])
Expand Down
4 changes: 2 additions & 2 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def shared_memory_cast_kernel():
smem = ttgl.allocate_shared_memory(ttgl.float16, [32, 1, 4, 64], layout_b)
smem.reshape((128, 64))

smem._reinterpret(ttgl.int8, [1024], ttgl.SwizzledSharedLayout(1, 1, 1, [0]))
smem._reinterpret(ttgl.int8, [16384], ttgl.SwizzledSharedLayout(1, 1, 1, [0]))


@pytest.mark.parametrize("target", ALL_TARGETS)
Expand All @@ -468,7 +468,7 @@ def test_shared_memory_cast(target):
tt.call @test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS128_256ASMD(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> ()
%3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable>
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16384xi8, #shared4, #smem, mutable>
tt.return
}
tt.func private @test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS128_256ASMD(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} {
Expand Down
15 changes: 8 additions & 7 deletions python/triton/experimental/gluon/language/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,21 +485,22 @@ def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descr
return _semantic.memdesc_reshape(self, shape)

@builtin
def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
def _reinterpret(self, dtype=None, shape=None, layout=None,
_semantic: GluonSemantic = None) -> shared_memory_descriptor:
"""
Reinterpret the shared memory descriptor as a different dtype, shape, or layout.

Args:
dtype (dtype): The new data type.
shape (List[int]): The new shape.
layout (SharedLayout): The new layout.
dtype (dtype): The new data type. Defaults to the descriptor dtype.
shape (List[int]): The new shape. Defaults to the descriptor shape.
layout (SharedLayout): The new layout. Defaults to the descriptor layout.

Returns:
shared_memory_descriptor: Descriptor with updated type and layout.
"""
dtype = _unwrap_if_constexpr(dtype)
shape = [_unwrap_if_constexpr(s) for s in shape]
layout = _unwrap_if_constexpr(layout)
dtype = self.dtype if dtype is None else _unwrap_if_constexpr(dtype)
shape = self.shape if shape is None else [_unwrap_if_constexpr(s) for s in shape]
layout = self.layout if layout is None else _unwrap_if_constexpr(layout)

return _semantic.memdesc_reinterpret(self, dtype, shape, layout)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,22 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip
return ret

@builtin
def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
def _reinterpret(self, dtype=None, shape=None, layout=None,
_semantic: GluonSemantic = None) -> tensor_memory_descriptor:
"""
Reinterpret the tensor memory descriptor as a different dtype, shape, or layout.

Args:
dtype (dtype): The new data type.
shape (Sequence[int]): The new shape.
layout (TensorMemoryLayout): The new layout.
dtype (dtype): The new data type. Defaults to the descriptor dtype.
shape (Sequence[int]): The new shape. Defaults to the descriptor shape.
layout (TensorMemoryLayout): The new layout. Defaults to the descriptor layout.

Returns:
tensor_memory_descriptor: Descriptor with updated type and layout.
"""
dtype = _unwrap_if_constexpr(dtype)
shape = [_unwrap_if_constexpr(s) for s in shape]
layout = _unwrap_if_constexpr(layout)
dtype = self.dtype if dtype is None else _unwrap_if_constexpr(dtype)
shape = self.shape if shape is None else [_unwrap_if_constexpr(s) for s in shape]
layout = self.layout if layout is None else _unwrap_if_constexpr(layout)

ty = tensor_memory_descriptor_type(dtype, shape, layout, shape)
handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle)
Expand Down
52 changes: 38 additions & 14 deletions python/tutorials/gluon/07-persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def wait_num_outstanding(self, num_outstanding: gl.constexpr):

# Take the result and reset the accumulator.
@gluon.jit
def take_result(self):
def take_result(self, splitn: gl.constexpr = False):
return self.acc, WGMMA(self.acc, gl.to_tensor(False))


Expand Down Expand Up @@ -149,9 +149,13 @@ def wait_num_outstanding(self, num_outstanding: gl.constexpr):
return self

@gluon.jit
def take_result(self):
def take_result(self, splitn: gl.constexpr = False):
next = MMAv5(gl.to_tensor(False), self.acc_tmem, self.bar, self.counter)
return self.acc_tmem.load(), next
if splitn:
layout: gl.constexpr = self.acc_tmem.get_reg_layout(instr_variant="32x32b_splitn")
return self.acc_tmem.load(layout), next
Comment thread
lezcano marked this conversation as resolved.
else:
return self.acc_tmem.load(), next


def select_mma_impl():
Expand Down Expand Up @@ -620,8 +624,9 @@ def issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, stealb: gl.constexpr,


@gluon.jit
def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, SchedulerImpl: gl.constexpr,
num_buffers: gl.constexpr, STEALB: gl.constexpr, num_warps: gl.constexpr):
def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, c_half_desc, MMAImpl: gl.constexpr,
SchedulerImpl: gl.constexpr, num_buffers: gl.constexpr, STEALB: gl.constexpr,
num_warps: gl.constexpr):
BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
Expand All @@ -636,7 +641,8 @@ def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.const
if not STEALB:
c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
else:
gl.static_assert(2 * BLOCK_N * BLOCK_K >= BLOCK_M * BLOCK_N, "B tile not large enough to steal")
gl.static_assert(BLOCK_M == BLOCK_K or BLOCK_M == 2 * BLOCK_K,
"expected one or two B tiles to cover the epilogue tile")
bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
for i in gl.static_range(num_buffers):
mbarrier.init(bars.index(i), count=1)
Expand Down Expand Up @@ -687,18 +693,33 @@ def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.const
num_buffers)

mma = mma.wait_num_outstanding(0)
c, mma = mma.take_result()
use_split_n_load: gl.constexpr = STEALB and BLOCK_M != BLOCK_K
c, mma = mma.take_result(splitn=use_split_n_load)
c = c.to(dtype)
if not STEALB:
c_buf = c_smem
tma.store_wait(pendings=0)
c_buf.store(c)
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf)
elif BLOCK_M == BLOCK_K:
c_buf = b_bufs.index(producer % (num_buffers + STEALB))
c_buf.store(c)
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf)
else:
# Steal the next 2 B buffers for the epilogue.
c_buf = b_bufs.index(producer % (num_buffers + STEALB))._reinterpret(dtype, c_desc.block_type.shape,
c_desc.layout)
c_buf.store(c)
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf)
c0, c1 = c.reshape((BLOCK_M, 2, BLOCK_N // 2)).permute(0, 2, 1).split()
c0_buf = b_bufs.index(producer % (num_buffers + STEALB))._reinterpret(shape=c_half_desc.block_type.shape,
layout=c_half_desc.layout)
c1_buf = b_bufs.index(
(producer + 1) % (num_buffers + STEALB))._reinterpret(shape=c_half_desc.block_type.shape,
layout=c_half_desc.layout)
c0_buf.store(c0)
c1_buf.store(c1)
fence_async_shared()
tma.async_copy_shared_to_global(c_half_desc, [epilogue_off_m, epilogue_off_n], c0_buf)
tma.async_copy_shared_to_global(c_half_desc, [epilogue_off_m, epilogue_off_n + BLOCK_N // 2], c1_buf)
tma.store_wait(pendings=0)


Expand All @@ -709,16 +730,19 @@ def persistent_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers,
a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
c_half_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N // 2], gl.float16)

a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
c_half_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N // 2], c_half_layout)

num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
num_pid = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
grid = (min(num_sms, num_pid), )
persistent_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, MMAImpl, SchedulerImpl, num_buffers,
STEALB=num_buffers == 4, num_warps=num_warps)
stealb = num_buffers == 4
persistent_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, c_half_desc, MMAImpl, SchedulerImpl, num_buffers,
STEALB=stealb, num_warps=num_warps)


@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
Expand Down
6 changes: 3 additions & 3 deletions test/Analysis/test-membar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1041,11 +1041,11 @@ tt.func @layout_changed_reinterpret_subslice() {
// CHECK: ttg.barrier local
// CHECK-NEXT: ttg.local_load
%0 = ttg.local_load %subslice1 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> tensor<16x16xf16>
%subslice2 = ttg.memdesc_subslice %alloc [16, 0] : !ttg.memdesc<32x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16>
%reinterpreted = ttg.memdesc_reinterpret %subslice2 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable>
%reinterpreted_parent = ttg.memdesc_reinterpret %alloc : !ttg.memdesc<32x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x16xf16, #sharedT, #smem, mutable>
%reinterpreted = ttg.memdesc_subslice %reinterpreted_parent [16, 0] : !ttg.memdesc<32x16xf16, #sharedT, #smem, mutable> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable, 32x16>
// CHECK: ttg.barrier local
// CHECK-NEXT: ttg.local_store
ttg.local_store %cst_store, %reinterpreted : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable>
ttg.local_store %cst_store, %reinterpreted : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable, 32x16>
// CHECK: ttg.barrier local
// CHECK-NEXT: ttg.local_load
%1 = ttg.local_load %subslice1 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> tensor<16x16xf16>
Expand Down
Loading
Loading