diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 8fabbca7c7bf..fd96306a9fe5 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -307,10 +307,7 @@ struct ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
   matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) {
-      return emitOptionalError(loc,
-                               "expensive view not supported on reshape op");
-    }
+    assert(!isExpensiveView(op.getSrc().getType(), op.getType()));
     auto resultTy = cast<RankedTensorType>(op.getType());
     auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
     auto typeConverter = getTypeConverter();
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index b1022ba72671..3793068d39ea 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -913,18 +913,20 @@ LogicalResult ReshapeOp::verify() {
                      "encodings, or (b) neither does.");
   }

-  if (!srcEnc || getAllowReorder()) {
+  if (!srcEnc) {
     return success();
   }

-  // Check that we can infer the dst encoding from the src encoding
-  // and that the inferred dst encoding is the same as the given dst encoding
-  Attribute inferredDstEnc;
+  // Check that we can infer the dst encoding from the src encoding and that the
+  // inferred dst encoding is the same as the given dst encoding. We pass the
+  // current dst encoding as a hint so that allowReorder reshapes are guaranteed
+  // to produce the current encoding iff it is valid.
+  Attribute inferredDstEnc = dstEnc;
   auto layoutInterface = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
   auto result = layoutInterface->inferReshapeOpEncoding(
       srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
-      /*allowReorder=*/false, getLoc());
+      getAllowReorder(), getLoc());
   if (failed(result))
     return failure();

   return layoutInterface->verifyLayoutsAreEqual(
diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
index a2de394abc2d..83bd767c57ac 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
@@ -248,26 +248,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

+#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
-#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
-#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
+#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: unary_triton_ops_transitive_nonneg
   tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     %c10_i32 = arith.constant 5 : i32
     %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
-    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
-    %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
-    %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
-    %4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
-    %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
+    %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
+    %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
+    %4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
+    %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
     %6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
     %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
-    %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
-    %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
-    %10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
-    %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
+    %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
+    %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
+    %10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
+    %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
     %12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
     %13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
     %14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -293,7 +294,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

-#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: join_cat_transitive_nonneg
diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
index 94413edc06d1..ba97af29a430 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
@@ -288,26 +288,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

+#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
-#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
-#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
+#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: unary_triton_ops_transitive_nonneg
  tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     %c10_i32 = arith.constant 5 : i32
     %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
-    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
-    %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
-    %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
-    %4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
-    %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
+    %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
+    %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
+    %4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
+    %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
     %6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
     %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
-    %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
-    %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
-    %10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
-    %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
+    %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
+    %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
+    %10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
+    %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
     %12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
     %13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
     %14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -333,7 +334,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

-#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: join_cat_transitive_nonneg
diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir
index b55bc22547d3..bd2b9bea64b3 100644
--- a/test/TritonGPU/canonicalize.mlir
+++ b/test/TritonGPU/canonicalize.mlir
@@ -8,7 +8,7 @@
 // CHECK: tt.return %[[V]]
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
 tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
@@ -68,7 +68,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: tt.return %[[V]]
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
 tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 91c6e3af601a..b89d933221d3 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -2199,8 +2199,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr

 // -----

 #blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @permuting_reshape_propagate
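For context, below is a minimal standalone sketch of the pattern the strengthened
verifier now covers: an allow_reorder reshape that carries explicit source and
destination encodings. The layouts and shapes are copied from the updated tests
above; the function name is hypothetical and the module is illustrative only,
not part of the diff.

    #src = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
    #dst = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
    module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
      tt.func @reshape_encoding_check(%arg0: tensor<1x16xi32, #src>) {
        // Before this change, allow_reorder made ReshapeOp::verify() skip
        // encoding verification entirely. Now the verifier always calls
        // inferReshapeOpEncoding, passing #dst as a hint, and accepts the op
        // only if the inferred destination encoding equals #dst.
        %0 = tt.reshape %arg0 allow_reorder : tensor<1x16xi32, #src> -> tensor<8x2xi32, #dst>
        tt.return
      }
    }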