diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 6c0c9fcc9341..931777bfa3a6 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -116,18 +116,20 @@ struct MoveBroadcastAfterElementwisePattern auto operands = op->getOperands(); bool seenBroadcast = false; - Type srcType; + ArrayRef srcShape; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); if (!definingOp) { return mlir::failure(); } - + auto getSrcShape = [](triton::BroadcastOp b) { + return b.getSrc().getType().cast().getShape(); + }; if (auto broadcastOp = llvm::dyn_cast(definingOp)) { if (!seenBroadcast) { seenBroadcast = true; - srcType = broadcastOp.getSrc().getType(); - } else if (srcType != broadcastOp.getSrc().getType()) { + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { // If the broadcast have different types we cannot re-order. return mlir::failure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2d8ca362465a..b7f88948b982 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -929,7 +929,7 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { auto isExtOrBroadcastOp = [](Operation *op) { return isa(op); + triton::BroadcastOp, triton::ExpandDimsOp>(op); }; // 1. Take a backward slice of all the tensor dependencies. SetVector slice; @@ -950,8 +950,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { if (isExtOrBroadcastOp(op)) { SetVector tempSlice; DenseMap tempLayout; + std::optional srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; LogicalResult result = getRematerializableSlice( - op->getOperand(0), layout[v], tempSlice, tempLayout); + op->getOperand(0), *srcEncoding, tempSlice, tempLayout); // If we can rematerialize the rest of the ext slice we can ignore this // ext as it won't need a convert. if (result.succeeded()) { @@ -969,13 +972,16 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { if (extOrBroadcatOp == nullptr) return; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]); + if (!srcEncoding) + return; // Move the convert before the ext op and rewrite the slice. OpBuilder builder(extOrBroadcatOp); auto tensorType = extOrBroadcatOp->getOperand(0).getType().cast(); - auto newType = - RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), - layout[extOrBroadcatOp->getResult(0)]); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); IRMapping mapping; diff --git a/test/Triton/reorder-broadcast.mlir b/test/Triton/reorder-broadcast.mlir index d5e054337a08..201b81b1e746 100644 --- a/test/Triton/reorder-broadcast.mlir +++ b/test/Triton/reorder-broadcast.mlir @@ -53,3 +53,15 @@ tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tenso tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32> } + +// CHECK-LABEL: @test_broadcast_mix_type_op_pattern +tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) { + // CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32> + %cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1> + %sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32> + + tt.return %sel : tensor<128x128xf32> +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 5c3fdd6b9ba6..8f5685ae8649 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1189,10 +1189,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: reduce_cvt2 // Match the reduction +// CHECK-NOT: triton_gpu.convert_layout // CHECK: tt.reduce // CHECK-SAME: axis = 1 -// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> +// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>> // CHECK: triton_gpu.convert_layout +// CHECK: tt.expand_dims // CHECK-NOT: triton_gpu.convert_layout // CHECK: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>