diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 31b4d65651d8..b3d8395a2b72 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -1025,6 +1025,55 @@ struct DistributionInfo { bool vectorizable = false; }; +// Generally parallel loops are partitionable, however if the dims for it +// are not present in a producer of the compute op within the dispatch and the +// results of that producer op are also returned from the dispatch +// then that dim is not partitioned as codegen for this is unsupported. +static SmallVector +getSupportedPartitionableLoops(linalg::LinalgOp linalgOp) { + SmallVector partitionableLoops; + linalgOp.getParallelDims(partitionableLoops); + SmallVector producerOperands; + for (auto operand : linalgOp.getDpsInputOperands()) { + auto producerOp = operand->get().getDefiningOp(); + if (!producerOp) { + continue; + } + + for (Operation *user : producerOp->getUsers()) { + if (isa(user)) { + producerOperands.push_back(operand); + break; + } + } + } + if (producerOperands.empty()) { + return partitionableLoops; + } + // If we have producer operands then we need to confirm that all of them + // also have the the partitionableLoop dims if not we skip that dim. + SmallVector finalPartitionableLoops; + for (auto dim : partitionableLoops) { + bool dimFound = false; + for (auto operand : producerOperands) { + AffineMap IndexingMap = linalgOp.getMatchingIndexingMap(operand); + if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { + auto dimExpr = dyn_cast(expr); + return dimExpr && dimExpr.getPosition() == dim; + })) { + dimFound = true; + } else { + dimFound = false; + break; + } + } + if (dimFound) { + finalPartitionableLoops.push_back(dim); + } + } + return finalPartitionableLoops; +} + static FailureOr collectOpDistributionInfo(Operation *op) { DistributionInfo distInfo; // MapScatterOp doesn't fit the LinalgOp interface, so use special case logic @@ -1054,19 +1103,17 @@ static FailureOr collectOpDistributionInfo(Operation *op) { } auto linalgOp = dyn_cast(op); - // Bail out on multi result cases as consumer fusion currently does not - // support multi result ops. - if (!linalgOp || linalgOp.getNumDpsInits() != 1) { + if (!linalgOp) { return failure(); } // This pipeline requires tensor semantics. Also fail for gather semantics // for now to simplify tile + fuse. - if (!linalgOp.hasPureTensorSemantics() || linalgOp.hasIndexSemantics()) { + if (!linalgOp.hasPureTensorSemantics() || + LinalgExt::isGatherlikeOp(linalgOp)) { return failure(); } - - linalgOp.getParallelDims(distInfo.partitionableLoops); + distInfo.partitionableLoops = getSupportedPartitionableLoops(linalgOp); // Bail out if op is not tilable. if (distInfo.partitionableLoops.empty()) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index b40703dc26ab..bad0e8500011 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -870,3 +870,120 @@ func.func @dyn_parallel_reduction(%arg0: tensor) -> tensor { // CHECK-SAME: reduction = [0, 4] // CHECK-SAME: thread = [1, 0] // CHECK-SAME: workgroup = [64, 0] + +// ----- +func.func @multi_result_index_generic_with_scatterfusion(%arg0: tensor<4x?x32x8xf16>, %arg1: tensor<4x?xi64>) -> (tensor, tensor<4x?x32x8xf8E4M3FNUZ>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim = tensor.dim %arg0, %c1 : tensor<4x?x32x8xf16> + %0 = tensor.empty(%dim) : tensor<4x?x32x8xf8E4M3FNUZ> + %1 = tensor.empty(%dim) : tensor<4x?x8x32xf8E4M3FNUZ> + %2 = tensor.empty(%dim) : tensor + %3:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<4x?x32x8xf16>) + outs(%0, %1 : tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) { + ^bb0(%in: f16, %out: f8E4M3FNUZ, %out_0: f8E4M3FNUZ): + %5 = linalg.index 1 : index + %6 = arith.mulf %in, %in : f16 + %7 = arith.cmpi eq, %5, %c0 : index + %8 = arith.select %7, %6, %in : f16 + %9 = arith.truncf %8 : f16 to f8E4M3FNUZ + linalg.yield %9, %9 : f8E4M3FNUZ, f8E4M3FNUZ + } -> (tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) + %4 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%3#1, %arg1 : tensor<4x?x8x32xf8E4M3FNUZ>, tensor<4x?xi64>) outs(%2 : tensor) { + ^bb0(%arg2: f8E4M3FNUZ, %arg3: f8E4M3FNUZ): + iree_linalg_ext.yield %arg2 : f8E4M3FNUZ + } -> tensor + return %4, %3#0 : tensor, tensor<4x?x32x8xf8E4M3FNUZ> +} + +// CHECK-LABEL: func.func @multi_result_index_generic_with_scatterfusion +// CHECK-SAME: #iree_codegen.translation_info +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: thread = [1, 1, 1, 4] +// CHECK-SAME: workgroup = [1, 1, 32, 8] + +// ----- +func.func @producer_broadcasted(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>) -> tensor<4x8xi64> { + %c59_i64 = arith.constant 59 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %0 = tensor.empty() : tensor<4x8xi64> + %1 = tensor.empty() : tensor<4xi64> + %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : tensor<4xi64>) outs(%1 : tensor<4xi64>) { + ^bb0(%in: i64, %out: i64): + %4 = arith.addi %in, %c59_i64 : i64 + %5 = arith.muli %4, %c2_i64 : i64 + linalg.yield %5 : i64 + } -> tensor<4xi64> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%2, %arg1 : tensor<4xi64>, tensor<4xi64>) outs(%0 : tensor<4x8xi64>) { + ^bb0(%in: i64, %in_0: i64, %out: i64): + %4 = linalg.index 1 : index + %5 = arith.index_cast %4 : index to i64 + %6 = arith.muli %in, %c8_i64 : i64 + %7 = arith.addi %6, %5 : i64 + %8 = arith.muli %7, %c32_i64 : i64 + %9 = arith.addi %8, %in_0 : i64 + linalg.yield %9 : i64 + } -> tensor<4x8xi64> + return %3 : tensor<4x8xi64> +} + +// CHECK-LABEL: func.func @producer_broadcasted +// CHECK-SAME: #iree_codegen.translation_info +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK: thread = [1, 1] +// CHECK-SAME: workgroup = [4, 16] + +// ----- +func.func @producer_broadcasted_and_stored_to_buffer(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>, %arg2 : memref<4xi64>) -> tensor<4x8xi64> { + %c59_i64 = arith.constant 59 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %0 = tensor.empty() : tensor<4x8xi64> + %1 = tensor.empty() : tensor<4xi64> + %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : tensor<4xi64>) outs(%1 : tensor<4xi64>) { + ^bb0(%in: i64, %out: i64): + %4 = arith.addi %in, %c59_i64 : i64 + %5 = arith.muli %4, %c2_i64 : i64 + linalg.yield %5 : i64 + } -> tensor<4xi64> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%2, %arg1 : tensor<4xi64>, tensor<4xi64>) outs(%0 : tensor<4x8xi64>) { + ^bb0(%in: i64, %in_0: i64, %out: i64): + %4 = linalg.index 1 : index + %5 = arith.index_cast %4 : index to i64 + %6 = arith.muli %in, %c8_i64 : i64 + %7 = arith.addi %6, %5 : i64 + %8 = arith.muli %7, %c32_i64 : i64 + %9 = arith.addi %8, %in_0 : i64 + linalg.yield %9 : i64 + } -> tensor<4x8xi64> + iree_codegen.store_to_buffer %2, %arg2 : tensor<4xi64> into memref<4xi64> + return %3 : tensor<4x8xi64> +} + +// CHECK-LABEL: func.func @producer_broadcasted +// CHECK-SAME: #iree_codegen.translation_info +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK: reduction = [0, 4] +// CHECK-SAME: thread = [1, 0] +// CHECK-SAME: workgroup = [64, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 9ee3577f0b00..1d43a769538f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -1390,3 +1390,64 @@ hal.executable public @main { // CHECK: scf.forall {{.*}} in (2048, 8) { // CHECK: %[[READ:.+]] = vector.transfer_read{{.*}}: memref<2048x2048xf32, #amdgpu.address_space>, vector<4xf32> // CHECK: vector.scatter {{.*}} %[[READ]] : memref<4194304xf32, #amdgpu.address_space>, vector<4xindex>, vector<4xi1>, vector<4xf32> + + +// ----- +#translation_info = #iree_codegen.translation_info +#lowering_config = #iree_gpu.lowering_config<{thread = [1, 1, 1, 4], workgroup = [1, 1, 32, 8]}> + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +hal.executable public @multi_result_index_generic_with_scatter_fusion { + hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export public @multi_result_index_generic_with_scatter_fusion ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) { + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%arg1, %arg2, %arg3) + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @multi_result_index_generic_with_scatter_fusion() attributes {translation_info = #translation_info} { + %c32_i64 = arith.constant 32 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %24 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c1) flags("ReadOnly|Indirect") : memref<4x?x32x8xf16, #hal.descriptor_type>{%c1} + %25 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c1) flags("ReadOnly|Indirect") : memref<4x?xi64, #hal.descriptor_type>{%c1} + %26 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref>{%c1} + %27 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c1) flags(Indirect) : memref<4x?x32x8xf8E4M3FNUZ, strided<[?, 256, 8, 1], offset: ?>, #hal.descriptor_type>{%c1} + %28 = iree_codegen.load_from_buffer %24 : memref<4x?x32x8xf16, #hal.descriptor_type> -> tensor<4x?x32x8xf16> + %29 = iree_codegen.load_from_buffer %25 : memref<4x?xi64, #hal.descriptor_type> -> tensor<4x?xi64> + %30 = tensor.empty(%c1) : tensor + %31 = tensor.empty(%c1) : tensor<4x?x8x32xf8E4M3FNUZ> + %32 = tensor.empty(%c1) : tensor<4x?x32x8xf8E4M3FNUZ> + %33:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%28 : tensor<4x?x32x8xf16>) outs(%32, %31 : tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) attrs = {lowering_config = #lowering_config} { + ^bb0(%in: f16, %out: f8E4M3FNUZ, %out_0: f8E4M3FNUZ): + %35 = linalg.index 1 : index + %36 = arith.mulf %in, %in : f16 + %37 = arith.cmpi eq, %35, %c0 : index + %38 = arith.select %37, %36, %in : f16 + %39 = arith.truncf %38 : f16 to f8E4M3FNUZ + linalg.yield %39, %39 : f8E4M3FNUZ, f8E4M3FNUZ + } -> (tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) + %34 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%33#1, %29 : tensor<4x?x8x32xf8E4M3FNUZ>, tensor<4x?xi64>) outs(%30 : tensor) { + ^bb0(%arg0: f8E4M3FNUZ, %arg1: f8E4M3FNUZ): + iree_linalg_ext.yield %arg0 : f8E4M3FNUZ + } -> tensor + iree_codegen.store_to_buffer %34, %26 : tensor into memref> + iree_codegen.store_to_buffer %33#0, %27 : tensor<4x?x32x8xf8E4M3FNUZ> into memref<4x?x32x8xf8E4M3FNUZ, strided<[?, 256, 8, 1], offset: ?>, #hal.descriptor_type> + return + } + } + } +} +// CHECK-LABEL: func @multi_result_index_generic_with_scatter_fusion +// CHECK: scf.forall (%arg0, %arg1) in (4, 1) { +// CHECK: vector.transfer_read +// CHECK: arith.mulf {{.*}} vector<4xf16> +// CHECK: arith.truncf {{.*}} vector<4xf16> to vector<4xf8E4M3FNUZ> +// CHECK: vector.transfer_write +// CHECK: vector.transfer_write +// CHECK: iree_linalg_ext.scatter