diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index 8d309b3f5d59..76ca1033904b 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -674,12 +674,17 @@ LogicalResult MemDescReinterpretOp::verify() {
     auto rank = cast<LayoutEncodingTrait>(ty.getEncoding()).getRank();
     auto layout =
         toLinearLayout(ty.getAllocShape().take_back(rank), ty.getEncoding());
-    // Shared memory is allocated by offset and TMEM is allocated by column; the
-    // other physical dimensions do not increase or decrease the allocation.
+    int64_t numLayoutCopies = 1;
+    for (int64_t dim : ty.getAllocShape().drop_back(rank))
+      numLayoutCopies *= dim;
+    // Shared memory is allocated by offset and TMEM is allocated by column;
+    // prefix dimensions outside the layout-ranked suffix represent separate
+    // copies of that logical allocation.
     auto *ctx = ty.getContext();
     bool isSharedMemory = isa<SharedMemorySpaceAttr>(ty.getMemorySpace());
     auto dim = StringAttr::get(ctx, isSharedMemory ? "offset" : "col");
-    return layout.getInDimSize(dim) * ty.getElementTypeBitWidth();
+    return numLayoutCopies * layout.getInDimSize(dim) *
+           ty.getElementTypeBitWidth();
   };
   auto srcNumBits = getViewNumBits(srcTy);
   auto dstNumBits = getViewNumBits(dstTy);
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 417365b3d9d1..e99976f1ab5d 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -2681,7 +2681,7 @@ tt.func private @memdesc_reinterpret(%arg0: !ttg.memdesc<4x1024xi64, #shared0, #
   // CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue %arg0[0]
   // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
   // CHECK: [[PTR:%.*]] = llvm.getelementptr [[BASE_PTR]][[[C0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i64
-  ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<4x4x2048xi32, #shared1, #ttg.shared_memory, mutable>
+  ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<4x1024xi64, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<4x1x2048xi32, #shared1, #ttg.shared_memory, mutable>
   // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
   // CHECK: [[S0:%.*]] = llvm.mlir.undef
   // CHECK: [[S1:%.*]] = llvm.insertvalue [[PTR]], [[S0]][0]
diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir
index f292bb9d5010..f547c487109c 100644
--- a/test/TritonGPU/ops.mlir
+++ b/test/TritonGPU/ops.mlir
@@ -159,6 +159,27 @@
 
 // -----
 
+#shared1d = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared2d = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: memdesc_reinterpret_layout_rank_increase
+  // CHECK: ttg.memdesc_reinterpret
+  tt.func @memdesc_reinterpret_layout_rank_increase(%arg0 : !ttg.memdesc<32x2xi32, #shared1d, #smem, mutable>) {
+    %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<32x2xi32, #shared1d, #smem, mutable> -> !ttg.memdesc<32x2xi32, #shared2d, #smem, mutable>
+    tt.return
+  }
+
+  // CHECK-LABEL: memdesc_reinterpret_layout_rank_decrease
+  // CHECK: ttg.memdesc_reinterpret
+  tt.func @memdesc_reinterpret_layout_rank_decrease(%arg0 : !ttg.memdesc<32x2xi32, #shared2d, #smem, mutable>) {
+    %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<32x2xi32, #shared2d, #smem, mutable> -> !ttg.memdesc<32x2xi32, #shared1d, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 // CHECK: #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[1, 0, 0, 0, 0\]\]}}}>
 #shared_rank_5 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0, 0, 0, 0]]}>
 #smem = #ttg.shared_memory