diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index ffd6e14ba8aa..1210385209ca 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -340,7 +340,9 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT The `ttg.memdesc_reinterpret` operation reinterprets a memory descriptor as one with a different shape and element type. Because memory descriptors lack strides, this operation is only valid if the original memory descriptor - is contiguous. + is contiguous. Reinterpretation of subviews is not allowed; reinterpret the + parent descriptor and then take a subview of the reinterpreted descriptor + instead. }]; let arguments = (ins TTG_MemDescType:$src); diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 8e0b17ce74bb..8d309b3f5d59 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -656,17 +656,37 @@ OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) { LogicalResult MemDescReinterpretOp::verify() { auto srcTy = getSrc().getType(); auto dstTy = getResult().getType(); - auto kBlock = StringAttr::get(getContext(), "block"); - auto getNumBroadcastCTADims = [kBlock](MemDescType ty) { + if (srcTy.getMemorySpace() != dstTy.getMemorySpace()) + return emitError("source and result must have the same memory space"); + if (srcTy.getMutableMemory() != dstTy.getMutableMemory()) + return emitError("source and result must have the same mutability"); + auto isSubview = [](MemDescType ty) { + auto rank = cast(ty.getEncoding()).getRank(); + return ty.getShape().take_back(rank) != ty.getAllocShape().take_back(rank); + }; + if (isSubview(srcTy) || isSubview(dstTy)) + return emitError("source and result must not be subviews; reinterpret the " + "parent descriptor and then take a subview"); + assert((isa( + srcTy.getMemorySpace()) && + "expected shared or tensor memory")); + auto getViewNumBits = [](MemDescType ty) { auto rank = cast(ty.getEncoding()).getRank(); auto layout = toLinearLayout(ty.getAllocShape().take_back(rank), ty.getEncoding()); - auto freeVariableMask = layout.getFreeVariableMasks().lookup(kBlock); - return llvm::popcount(freeVariableMask); + // Shared memory is allocated by offset and TMEM is allocated by column; the + // other physical dimensions do not increase or decrease the allocation. + auto *ctx = ty.getContext(); + bool isSharedMemory = isa(ty.getMemorySpace()); + auto dim = StringAttr::get(ctx, isSharedMemory ? "offset" : "col"); + return layout.getInDimSize(dim) * ty.getElementTypeBitWidth(); }; - if (getNumBroadcastCTADims(srcTy) != getNumBroadcastCTADims(dstTy)) - return emitError( - "source and result must have the same number of broadcast CTA dims"); + auto srcNumBits = getViewNumBits(srcTy); + auto dstNumBits = getViewNumBits(dstTy); + if (srcNumBits != dstNumBits) + return emitError() << "source and result must have the same logical " + "storage size (" + << srcNumBits << " vs " << dstNumBits << ")"; return success(); } diff --git a/python/examples/gluon/01-attention-forward.py b/python/examples/gluon/01-attention-forward.py index 2ea0234f915b..8a0e40c95866 100644 --- a/python/examples/gluon/01-attention-forward.py +++ b/python/examples/gluon/01-attention-forward.py @@ -383,26 +383,25 @@ def get_loop_bounds(self, STAGE: gl.constexpr): @gluon.jit def _borrow_s_as_p(config, s_tmem): - p_tmem = s_tmem.slice(0, config.BLOCK_N // 2) - return p_tmem._reinterpret(config.dtype, config.qk_shape, config.p_tmem_layout) + p_tmem = s_tmem._reinterpret(config.dtype, [config.SPLIT_M, 2 * config.BLOCK_N], config.p_tmem_layout) + return p_tmem.slice(0, config.BLOCK_N) @gluon.jit def _borrow_s_as_alpha(config, s_tmem): - alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1) alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M_PER_CTA, 1], col_stride=1, cga_layout=config.CGA_LAYOUT, two_ctas=gl.num_ctas() > 1) - return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout) + alpha_tmem = s_tmem._reinterpret(layout=alpha_layout) + return alpha_tmem.slice(config.BLOCK_N // 2, 1) @gluon.jit def _borrow_s_for_epilogue(config, s_tmem): - m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1) - l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1) layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M_PER_CTA, 1], col_stride=1, cga_layout=config.CGA_LAYOUT, two_ctas=gl.num_ctas() > 1) - m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) - l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) + s_tmem = s_tmem._reinterpret(layout=layout) + m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1) + l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1) return m_i_tmem, l_i_tmem @@ -880,7 +879,7 @@ def attention_kernel( # q_chnl = get_desc_channel(desc_q, num_buffers=2) kv_chnl = get_desc_channel(desc_k, num_buffers=config.num_kv_buffers) - v_mem = kv_chnl.mem._reinterpret(desc_v.dtype, [config.num_kv_buffers] + desc_v.block_type.shape, desc_v.layout) + v_mem = kv_chnl.mem._reinterpret(layout=desc_v.layout) o_chnl = TensorMemoryChannel.alloc(config.o_shape, gl.float32, config.o_tmem_layout, num_buffers=2, producer_two_ctas=gl.num_ctas() > 1) epi_chnl = SharedMemoryChannel.alloc(config.o_shape, config.dtype, gl.constexpr(desc_o.layout), num_buffers=2) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 539e45b62b67..8db7e1aa7351 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1522,7 +1522,7 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_ tcgen05_commit(barrier) mbarrier.wait(barrier, phase=0) tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1) - tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias) + tmem = tmem._reinterpret(shape=(num_rows, num_cols), layout=tmem_alias) value = tmem.load(blocked) ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value) @@ -1567,17 +1567,19 @@ def kernel(s_ptr, out_ptr): s_tmem.store(s) o_tmem.store(s) - p_tmem = s_tmem.slice(0, N // 2)._reinterpret(ttgl.float16, [BLOCK_M, N], tmem_layout) + p_tmem_parent = s_tmem._reinterpret(ttgl.float16, [BLOCK_M, 2 * N], tmem_layout) + p_tmem = p_tmem_parent.slice(0, N) p_tmem.store(ttgl.full((BLOCK_M, N), 0.0, dtype=ttgl.float16, layout=layout)) d1_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, 2), col_stride=1) - m_tmem = s_tmem.slice(N // 4, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout) + d1_tmem_parent = s_tmem._reinterpret(layout=d1_tmem_layout) + m_tmem = d1_tmem_parent.slice(N // 2, 2) d1_layout: ttgl.constexpr = m_tmem.get_reg_layout() m_tmem.store(ttgl.full((BLOCK_M, 2), 2.0, dtype=ttgl.float32, layout=d1_layout)) - l_tmem = s_tmem.slice(N // 4 + 2, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout) + l_tmem = d1_tmem_parent.slice(N // 2 + 4, 2) l_tmem.store(ttgl.full((BLOCK_M, 2), 3.0, dtype=ttgl.float32, layout=d1_layout)) - a_tmem = s_tmem.slice(N // 4 + 4, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout) + a_tmem = d1_tmem_parent.slice(N // 2 + 8, 2) a_tmem.store(ttgl.full((BLOCK_M, 2), 4.0, dtype=ttgl.float32, layout=d1_layout)) s = s_tmem.load() @@ -1616,7 +1618,7 @@ def kernel(s_ptr, out_ptr): # TMEM[0:16] = [s0, s1] # TMEM[16:32] = [s2, s3] # - # Thus slicing S at N//4 will obtain an offset to the beginning of s1. + # Thus the narrow parent view is sliced at offsets that map back to s1. out_ref[:, 32:34] = 2.0 out_ref[:, 34:36] = 3.0 out_ref[:, 36:38] = 4.0 @@ -1727,7 +1729,8 @@ def kernel(in_ptr, out_ptr): smem_layout_2d: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) smem = ttgl.allocate_shared_memory(ttgl.int8, [BLOCK], smem_layout_1d) smem_slice0 = smem.slice(0, SPLIT_BLOCK) - smem_slice1 = smem.slice(SPLIT_BLOCK, SPLIT_BLOCK)._reinterpret(ttgl.int32, [XBLOCK, YBLOCK], smem_layout_2d) + smem_i32 = smem._reinterpret(ttgl.int32, [2 * XBLOCK, YBLOCK], smem_layout_2d) + smem_slice1 = smem_i32.slice(XBLOCK, XBLOCK, dim=0) offs = ttgl.arange(0, XBLOCK)[:, None] * YBLOCK + ttgl.arange(0, YBLOCK)[None, :] blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, NUM_THREADS], [1, 4], [1, 0]) diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 295ff4d23e29..960cb7f3a3f6 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -445,7 +445,7 @@ def shared_memory_cast_kernel(): smem = ttgl.allocate_shared_memory(ttgl.float16, [32, 1, 4, 64], layout_b) smem.reshape((128, 64)) - smem._reinterpret(ttgl.int8, [1024], ttgl.SwizzledSharedLayout(1, 1, 1, [0])) + smem._reinterpret(ttgl.int8, [16384], ttgl.SwizzledSharedLayout(1, 1, 1, [0])) @pytest.mark.parametrize("target", ALL_TARGETS) @@ -468,7 +468,7 @@ def test_shared_memory_cast(target): tt.call @test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS128_256ASMD(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> () %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable> - %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable> + %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16384xi8, #shared4, #smem, mutable> tt.return } tt.func private @test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS128_256ASMD(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} { diff --git a/python/triton/experimental/gluon/language/_core.py b/python/triton/experimental/gluon/language/_core.py index 91a48472f03e..9a294f1cec88 100644 --- a/python/triton/experimental/gluon/language/_core.py +++ b/python/triton/experimental/gluon/language/_core.py @@ -485,21 +485,22 @@ def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descr return _semantic.memdesc_reshape(self, shape) @builtin - def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + def _reinterpret(self, dtype=None, shape=None, layout=None, + _semantic: GluonSemantic = None) -> shared_memory_descriptor: """ Reinterpret the shared memory descriptor as a different dtype, shape, or layout. Args: - dtype (dtype): The new data type. - shape (List[int]): The new shape. - layout (SharedLayout): The new layout. + dtype (dtype): The new data type. Defaults to the descriptor dtype. + shape (List[int]): The new shape. Defaults to the descriptor shape. + layout (SharedLayout): The new layout. Defaults to the descriptor layout. Returns: shared_memory_descriptor: Descriptor with updated type and layout. """ - dtype = _unwrap_if_constexpr(dtype) - shape = [_unwrap_if_constexpr(s) for s in shape] - layout = _unwrap_if_constexpr(layout) + dtype = self.dtype if dtype is None else _unwrap_if_constexpr(dtype) + shape = self.shape if shape is None else [_unwrap_if_constexpr(s) for s in shape] + layout = self.layout if layout is None else _unwrap_if_constexpr(layout) return _semantic.memdesc_reinterpret(self, dtype, shape, layout) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index dd3d52502809..3da1220133ee 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -395,21 +395,22 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip return ret @builtin - def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor: + def _reinterpret(self, dtype=None, shape=None, layout=None, + _semantic: GluonSemantic = None) -> tensor_memory_descriptor: """ Reinterpret tensor memory descriptor with a new dtype, shape, and layout. Args: - dtype (dtype): The new data type. - shape (Sequence[int]): The new shape. - layout (TensorMemoryLayout): The new layout. + dtype (dtype): The new data type. Defaults to the descriptor dtype. + shape (Sequence[int]): The new shape. Defaults to the descriptor shape. + layout (TensorMemoryLayout): The new layout. Defaults to the descriptor layout. Returns: tensor_memory_descriptor: Descriptor with updated type and layout. """ - dtype = _unwrap_if_constexpr(dtype) - shape = [_unwrap_if_constexpr(s) for s in shape] - layout = _unwrap_if_constexpr(layout) + dtype = self.dtype if dtype is None else _unwrap_if_constexpr(dtype) + shape = self.shape if shape is None else [_unwrap_if_constexpr(s) for s in shape] + layout = self.layout if layout is None else _unwrap_if_constexpr(layout) ty = tensor_memory_descriptor_type(dtype, shape, layout, shape) handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle) diff --git a/python/tutorials/gluon/07-persistence.py b/python/tutorials/gluon/07-persistence.py index d08e76de52fc..f8210d23d33d 100644 --- a/python/tutorials/gluon/07-persistence.py +++ b/python/tutorials/gluon/07-persistence.py @@ -115,7 +115,7 @@ def wait_num_outstanding(self, num_outstanding: gl.constexpr): # Take the result and reset the accumulator. @gluon.jit - def take_result(self): + def take_result(self, splitn: gl.constexpr = False): return self.acc, WGMMA(self.acc, gl.to_tensor(False)) @@ -149,9 +149,13 @@ def wait_num_outstanding(self, num_outstanding: gl.constexpr): return self @gluon.jit - def take_result(self): + def take_result(self, splitn: gl.constexpr = False): next = MMAv5(gl.to_tensor(False), self.acc_tmem, self.bar, self.counter) - return self.acc_tmem.load(), next + if splitn: + layout: gl.constexpr = self.acc_tmem.get_reg_layout(instr_variant="32x32b_splitn") + return self.acc_tmem.load(layout), next + else: + return self.acc_tmem.load(), next def select_mma_impl(): @@ -620,8 +624,9 @@ def issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, stealb: gl.constexpr, @gluon.jit -def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, SchedulerImpl: gl.constexpr, - num_buffers: gl.constexpr, STEALB: gl.constexpr, num_warps: gl.constexpr): +def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, c_half_desc, MMAImpl: gl.constexpr, + SchedulerImpl: gl.constexpr, num_buffers: gl.constexpr, STEALB: gl.constexpr, + num_warps: gl.constexpr): BLOCK_M: gl.constexpr = c_desc.block_type.shape[0] BLOCK_N: gl.constexpr = c_desc.block_type.shape[1] BLOCK_K: gl.constexpr = a_desc.block_type.shape[1] @@ -636,7 +641,8 @@ def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.const if not STEALB: c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout) else: - gl.static_assert(2 * BLOCK_N * BLOCK_K >= BLOCK_M * BLOCK_N, "B tile not large enough to steal") + gl.static_assert(BLOCK_M == BLOCK_K or BLOCK_M == 2 * BLOCK_K, + "expected one or two B tiles to cover the epilogue tile") bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout()) for i in gl.static_range(num_buffers): mbarrier.init(bars.index(i), count=1) @@ -687,18 +693,33 @@ def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.const num_buffers) mma = mma.wait_num_outstanding(0) - c, mma = mma.take_result() + use_split_n_load: gl.constexpr = STEALB and BLOCK_M != BLOCK_K + c, mma = mma.take_result(splitn=use_split_n_load) c = c.to(dtype) if not STEALB: c_buf = c_smem tma.store_wait(pendings=0) + c_buf.store(c) + fence_async_shared() + tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf) + elif BLOCK_M == BLOCK_K: + c_buf = b_bufs.index(producer % (num_buffers + STEALB)) + c_buf.store(c) + fence_async_shared() + tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf) else: # Steal the next 2 B buffers for the epilogue. - c_buf = b_bufs.index(producer % (num_buffers + STEALB))._reinterpret(dtype, c_desc.block_type.shape, - c_desc.layout) - c_buf.store(c) - fence_async_shared() - tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf) + c0, c1 = c.reshape((BLOCK_M, 2, BLOCK_N // 2)).permute(0, 2, 1).split() + c0_buf = b_bufs.index(producer % (num_buffers + STEALB))._reinterpret(shape=c_half_desc.block_type.shape, + layout=c_half_desc.layout) + c1_buf = b_bufs.index( + (producer + 1) % (num_buffers + STEALB))._reinterpret(shape=c_half_desc.block_type.shape, + layout=c_half_desc.layout) + c0_buf.store(c0) + c1_buf.store(c1) + fence_async_shared() + tma.async_copy_shared_to_global(c_half_desc, [epilogue_off_m, epilogue_off_n], c0_buf) + tma.async_copy_shared_to_global(c_half_desc, [epilogue_off_m, epilogue_off_n + BLOCK_N // 2], c1_buf) tma.store_wait(pendings=0) @@ -709,16 +730,19 @@ def persistent_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16) b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16) c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16) + c_half_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N // 2], gl.float16) a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout) b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout) c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout) + c_half_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N // 2], c_half_layout) num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count num_pid = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N) grid = (min(num_sms, num_pid), ) - persistent_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, MMAImpl, SchedulerImpl, num_buffers, - STEALB=num_buffers == 4, num_warps=num_warps) + stealb = num_buffers == 4 + persistent_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, c_half_desc, MMAImpl, SchedulerImpl, num_buffers, + STEALB=stealb, num_warps=num_warps) @pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)]) diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index ebd26997d4f8..6095c66c7eff 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1041,11 +1041,11 @@ tt.func @layout_changed_reinterpret_subslice() { // CHECK: ttg.barrier local // CHECK-NEXT: ttg.local_load %0 = ttg.local_load %subslice1 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> tensor<16x16xf16> - %subslice2 = ttg.memdesc_subslice %alloc [16, 0] : !ttg.memdesc<32x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> - %reinterpreted = ttg.memdesc_reinterpret %subslice2 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable> + %reinterpreted_parent = ttg.memdesc_reinterpret %alloc : !ttg.memdesc<32x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x16xf16, #sharedT, #smem, mutable> + %reinterpreted = ttg.memdesc_subslice %reinterpreted_parent [16, 0] : !ttg.memdesc<32x16xf16, #sharedT, #smem, mutable> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable, 32x16> // CHECK: ttg.barrier local // CHECK-NEXT: ttg.local_store - ttg.local_store %cst_store, %reinterpreted : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable> + ttg.local_store %cst_store, %reinterpreted : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #sharedT, #smem, mutable, 32x16> // CHECK: ttg.barrier local // CHECK-NEXT: ttg.local_load %1 = ttg.local_load %subslice1 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable, 32x16> -> tensor<16x16xf16> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 167edfdfec07..417365b3d9d1 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2681,22 +2681,7 @@ tt.func private @memdesc_reinterpret(%arg0: !ttg.memdesc<4x1024xi64, #shared0, # // CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue %arg0[0] // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) // CHECK: [[PTR:%.*]] = llvm.getelementptr [[BASE_PTR]][[[C0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i64 - ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<4x4x4xi32, #shared1, #ttg.shared_memory, mutable> - // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) - // CHECK: [[S0:%.*]] = llvm.mlir.undef - // CHECK: [[S1:%.*]] = llvm.insertvalue [[PTR]], [[S0]][0] - // CHECK: [[S2:%.*]] = llvm.insertvalue [[C0]], [[S1]][1] - // CHECK: [[S3:%.*]] = llvm.insertvalue [[C0]], [[S2]][2] - // CHECK: [[S4:%.*]] = llvm.insertvalue [[C0]], [[S3]][3] - tt.return -} - -// CHECK-LABEL: @memdesc_reinterpret_affine -tt.func private @memdesc_reinterpret_affine(%arg0: !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable, 32x1024>) { - // CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue %arg0[0] - // CHECK: [[OFFSET:%.*]] = llvm.xor - // CHECK: [[PTR:%.*]] = llvm.getelementptr [[BASE_PTR]][[[OFFSET]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i64 - ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable, 32x1024> -> !ttg.memdesc<4x4x4xi32, #shared1, #ttg.shared_memory, mutable> + ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<4x4x2048xi32, #shared1, #ttg.shared_memory, mutable> // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) // CHECK: [[S0:%.*]] = llvm.mlir.undef // CHECK: [[S1:%.*]] = llvm.insertvalue [[PTR]], [[S0]][0] diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index 5cb80f25c5d1..3e43eaf1b7b4 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -810,10 +810,10 @@ tt.func @tc_gen5_commit(%arg0: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %pr module attributes {"ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @reinterpret -tt.func private @reinterpret(%arg0: !ttg.memdesc<128x32xf32, #tmem_f32, #ttng.tensor_memory>) -> !ttg.memdesc<256x32xf16, #tmem_f16, #ttng.tensor_memory> { - %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x32xf32, #tmem_f32, #ttng.tensor_memory> -> !ttg.memdesc<256x32xf16, #tmem_f16, #ttng.tensor_memory> +tt.func private @reinterpret(%arg0: !ttg.memdesc<128x32xf32, #tmem_f32, #ttng.tensor_memory>) -> !ttg.memdesc<128x32xf16, #tmem_f16, #ttng.tensor_memory> { + %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x32xf32, #tmem_f32, #ttng.tensor_memory> -> !ttg.memdesc<128x32xf16, #tmem_f16, #ttng.tensor_memory> // CHECK-NEXT: return %arg0 - tt.return %0 : !ttg.memdesc<256x32xf16, #tmem_f16, #ttng.tensor_memory> + tt.return %0 : !ttg.memdesc<128x32xf16, #tmem_f16, #ttng.tensor_memory> } } diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 685ab7a8ba94..3bf4f3179dd9 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -209,31 +209,45 @@ tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared1d, #s // ----- -#shared_multicast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> -#shared_local = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 2 : i32} { - tt.func public @memdesc_reinterpret_changed_broadcast_count(%arg0: !ttg.memdesc<128x128xf16, #shared_multicast, #smem>) { - // expected-error @+1 {{source and result must have the same number of broadcast CTA dims}} - %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x128xf16, #shared_multicast, #smem> -> !ttg.memdesc<128x128xf16, #shared_local, #smem> +tt.func public @memdesc_reinterpret_changed_storage_size(%arg0: !ttg.memdesc<8x16xf16, #shared, #smem>) { + // expected-error @+1 {{source and result must have the same logical storage size}} + %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<8x16xf16, #shared, #smem> -> !ttg.memdesc<8x16xf32, #shared, #smem> tt.return - } } // ----- -#shared_bc0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 1], [0, 0]]}> -#shared_bc1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0], [0, 0]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 4 : i32} { - tt.func public @memdesc_reinterpret_changed_broadcast_mask(%arg0: !ttg.memdesc<128x128xf16, #shared_bc0, #smem>) { - %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x128xf16, #shared_bc0, #smem> -> !ttg.memdesc<128x128xf16, #shared_bc1, #smem> +#tmem = #ttng.tensor_memory_encoding +tt.func public @memdesc_reinterpret_changed_memory_space(%arg0: !ttg.memdesc<128x128xf16, #shared, #smem>) { + // expected-error @+1 {{source and result must have the same memory space}} + %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> tt.return - } } // ----- +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @memdesc_reinterpret_changed_mutability(%arg0: !ttg.memdesc<8x16xf16, #shared, #smem>) { + // expected-error @+1 {{source and result must have the same mutability}} + %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<8x16xf16, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem, mutable> + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @memdesc_reinterpret_subview(%arg0: !ttg.memdesc<8x16xf16, #shared, #smem, 16x16>) { + // expected-error @+1 {{source and result must not be subviews}} + %a = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<8x16xf16, #shared, #smem, 16x16> -> !ttg.memdesc<8x16xf16, #shared, #smem> + tt.return +} + #mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1], instrShape = [16, 8]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index cfd844e518f5..0b435997df38 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -736,6 +736,17 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @memdesc_reinterpret_changed_storage_size_tmem(%arg0: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>) { + // expected-error @+1 {{source and result must have the same logical storage size}} + %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> + tt.return + } +} + +// ----- + #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {