diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index e5e5806b78f9..050485d60e78 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -195,6 +195,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen", "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms:IREECodegenTransforms", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets", diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp index e8b66770b0bc..a4e8a0f8606b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" #include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" @@ -318,6 +319,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFusionFn); IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFusionFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns(patterns, + controlFusionFn); // Add patterns to fold `tensor.empty` operations with its consumers. tensor::populateFoldTensorEmptyPatterns(patterns); // Add some additional patterns that can simplify the IR. @@ -367,6 +370,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFn); IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( bubbleExpandShapePatterns, controlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operation and // "pushed-down" `tensor.collapse_shape` operation with their interface // bindings or `tensor.empty` operations. diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 128b254f6a70..5ec758dba5d0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -229,6 +229,7 @@ iree_cc_library( iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Codegen::Dialect::Codegen::Transforms::IREECodegenTransforms iree::compiler::Codegen::Dialect::Codegen::Utils iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index febda9283240..20a7a3c8700b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -412,6 +413,8 @@ void PropagateReshapesByExpansionPass::runOnOperation() { }; linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, bubbleUpExpansionControlFn); // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index 70ccb68955d8..467d72156786 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -459,3 +459,205 @@ func.func @no_swap_rank_reducing_slice(%arg0: tensor<3x6xi8>) -> tensor<3xi16> { // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8> // CHECK-NEXT: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] // CHECK-NEXT: iree_tensor_ext.bitcast %[[SLICE]] + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op. +// Using proper 2D matmul indexing maps with MFMA_F32_16x16x16_F16 layout. +// Tensor shapes: LHS[outer_m, outer_k, 16, 16], RHS[outer_k, outer_n, 16, 16], ACC[outer_m, outer_n, 16, 16] +#contraction_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled( + %src: tensor<2x3x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapse the first two outer dims of LHS: [2,3] -> [6] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor<2x3x4x16x16xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled +// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor<2x3x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<2x3x2x16x16xf32> into tensor<6x2x16x16xf32> +// CHECK: return %[[COLLAPSED]] + +// ----- + +// Test propagating expand_shape consumer through inner_tiled op. +#contraction_accesses2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled( + %lhs: tensor<6x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<2x3x2x16x16xf32> { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [2, 3, 2, 16, 16] : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> + return %expanded : tensor<2x3x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<6x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x4x16x16xf16> into tensor<2x3x4x16x16xf16> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: return %[[INNER_TILED]] + +// ----- + +// Test that reshape touching inner dimensions is NOT propagated. +#contraction_accesses3 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @no_propagate_inner_dim_reshape( + %src: tensor<6x4x16x2x8xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapsing inner dims [3,4] which are part of inner tile - should NOT propagate. + %collapsed = tensor.collapse_shape %src [[0], [1], [2], [3, 4]] + : tensor<6x4x16x2x8xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses3, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @no_propagate_inner_dim_reshape +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape +// CHECK: iree_codegen.inner_tiled ins(%[[COLLAPSED]], + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op with dynamic outer shapes. +#contraction_accesses_dyn1 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled_dynamic( + %src: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor) + -> tensor { + // Collapse the first two outer dims of LHS: [?, 3] -> [?*3] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor into tensor + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn1, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + return %result : tensor +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled_dynamic +// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[SRC]], %c0 +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor into tensor +// CHECK: return %[[COLLAPSED]] + +// ----- + +// Test propagating expand_shape consumer through inner_tiled op with dynamic outer shapes. +#contraction_accesses_dyn2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled_dynamic( + %lhs: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor, + %dyn_dim: index) + -> tensor { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [%dyn_dim, 3, 2, 16, 16] : tensor into tensor + return %expanded : tensor +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled_dynamic +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[DYN_DIM:[A-Za-z0-9]+]]: index +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 4, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: return %[[INNER_TILED]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel new file mode 100644 index 000000000000..69c2f3cf089d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "IREECodegenTransforms", + srcs = [ + "ReshapeFusion.cpp", + ], + hdrs = [ + "Transforms.h", + ], + deps = [ + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..05bdbb4f53f4 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt @@ -0,0 +1,33 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IREECodegenTransforms + HDRS + "Transforms.h" + SRCS + "ReshapeFusion.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRSupport + MLIRTensorDialect + MLIRTransformUtils + MLIRTransforms + iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp new file mode 100644 index 000000000000..f5b99f7ce9fc --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp @@ -0,0 +1,313 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +namespace { + +/// Check if an InnerTiledOp can be expanded by propagating a reshape through +/// it. The main real condition is that the inner dimensions of the op are not +/// expanded. Otherwise, we artificially restrict to single result inner_tiled +/// ops for now. +static LogicalResult +canExpandInnerTiledOp(InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation) { + // Only single result inner_tiled ops are tested or used anywhere, so restrict + // to single result for now. + if (op->getNumResults() != 1) + return failure(); + + // Only outer dims can be expanded because inner dims depend on the `kind` + // attribute's implementation. + int64_t outerRank = + op.getIndexingMapsArray()[fusedOperand->getOperandNumber()] + .getNumResults(); + if (llvm::any_of(reassociation.drop_front(outerRank), + [](ArrayRef group) { return group.size() != 1; })) { + return failure(); + } + return success(); +} + +/// Expand an InnerTiledOp by propagating a reshape through it. +/// `fusedOperand` is the operand connected to the reshape. +/// `reassociation` describes how the collapsed dims map to expanded dims. +/// `expandedShape` is the full expanded shape (outer + inner dims). +/// `expandedValue` is the expanded value to replace the fused operand. +/// `outputReassociations` will be cleared and filled with the reassociation +/// indices for each output, to be used for collapsing the result back to its +/// original shape. +/// The outer dimensions of the InnerTiledOp are expected to not be expanded, +/// which is enforced by the canExpandInnerTiledOp precondition. +static InnerTiledOp expandInnerTiledOp( + InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation, + ArrayRef expandedShape, Value expandedValue, + SmallVectorImpl> &outputReassociations, + PatternRewriter &rewriter) { + assert(reassociation.size() == + cast(fusedOperand->get().getType()).getRank() && + "expected reassociation rank to match fused operand rank"); + + // Build mapping: iterDim -> list of (expandedIterDim, size). + SmallVector indexingMaps = op.getIndexingMapsArray(); + AffineMap fusedMap = indexingMaps[fusedOperand->getOperandNumber()]; + int64_t numIterDims = fusedMap.getNumDims(); + SmallVector>> iterDimExpansion( + numIterDims); + int64_t expandedDimCounter = 0; + for (auto [resultIdx, expr] : llvm::enumerate(fusedMap.getResults())) { + int64_t iterDim = cast(expr).getPosition(); + for (int64_t expandedOperandIdx : reassociation[resultIdx]) { + iterDimExpansion[iterDim].push_back( + {expandedDimCounter++, expandedShape[expandedOperandIdx]}); + } + } + // Iteration dims outside the fused map's results are independent from the + // expansion, but update their dim position to account for earlier expanded + // dims. Get iteration domain to query sizes of dims not in the fused operand. + SmallVector iterationDomain = op.getIterationDomain(rewriter); + for (int64_t i = 0; i < numIterDims; ++i) { + if (iterDimExpansion[i].empty()) + iterDimExpansion[i].push_back( + {expandedDimCounter++, iterationDomain[i].size}); + } + + SmallVector newIndexingMaps; + SmallVector newOperands; + outputReassociations.clear(); + Location loc = op.getLoc(); + for (OpOperand &operand : op->getOpOperands()) { + AffineMap origMap = indexingMaps[operand.getOperandNumber()]; + auto operandType = cast(operand.get().getType()); + int64_t operandOuterRank = origMap.getNumResults(); + int64_t innerRank = operandType.getRank() - operandOuterRank; + SmallVector newMapResults; + SmallVector operandReassoc; + SmallVector expandedOperandSizes; + int64_t dimCounter = 0; + for (AffineExpr expr : origMap.getResults()) { + int64_t iterDim = cast(expr).getPosition(); + ReassociationIndices group; + for (auto [expandedDim, size] : iterDimExpansion[iterDim]) { + newMapResults.push_back(getAffineDimExpr(expandedDim, op.getContext())); + group.push_back(dimCounter++); + expandedOperandSizes.push_back(size); + } + operandReassoc.push_back(group); + } + // Inner dims are never expanded. + for (int64_t i = 0; i < innerRank; ++i) { + operandReassoc.push_back({dimCounter++}); + expandedOperandSizes.push_back(tensor::getMixedSize( + rewriter, loc, operand.get(), operandOuterRank + i)); + } + newIndexingMaps.push_back( + AffineMap::get(expandedDimCounter, 0, newMapResults, op.getContext())); + + // Store output reassociations for later use. + if (operand.getOperandNumber() >= op.getNumInputs()) { + outputReassociations.push_back(operandReassoc); + } + + if (&operand == fusedOperand) { + newOperands.push_back(expandedValue); + continue; + } + + if (llvm::all_of(operandReassoc, [](ArrayRef group) { + return group.size() == 1; + })) { + newOperands.push_back(operand.get()); + continue; + } + + SmallVector staticShape; + std::tie(staticShape, std::ignore) = + decomposeMixedValues(expandedOperandSizes); + auto expandedType = + RankedTensorType::get(staticShape, operandType.getElementType()); + newOperands.push_back(tensor::ExpandShapeOp::create( + rewriter, loc, expandedType, operand.get(), operandReassoc, + expandedOperandSizes)); + } + + // Expand iterator types. + SmallVector newIterTypes; + for (auto [idx, iterType] : llvm::enumerate(op.getIteratorTypesArray())) { + newIterTypes.append(iterDimExpansion[idx].size(), iterType); + } + + int64_t numInputs = op.getNumInputs(); + SmallVector newInputs(newOperands.begin(), + newOperands.begin() + numInputs); + SmallVector newOutputs(newOperands.begin() + numInputs, + newOperands.end()); + + // Permutations are unchanged, since they are for inner dims, but we need to + // convert from ArrayAttr to SmallVector>. + std::optional>> newPermutations; + if (auto permAttr = op.getPermutations()) { + newPermutations = llvm::map_to_vector( + permAttr->getAsRange(), [](DenseI64ArrayAttr perm) { + return SmallVector(perm.asArrayRef()); + }); + } + + return InnerTiledOp::create(rewriter, loc, newInputs, newOutputs, + newIndexingMaps, newIterTypes, op.getKind(), + op.getSemantics(), newPermutations); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +/// Pattern to propagate a tensor::CollapseShapeOp through a consumer +/// InnerTiledOp. The collapsed dimensions must not include any inner dimensions +/// of the InnerTiledOp. +/// +/// Example: +/// %collapsed = tensor.collapse_shape %src [[0, 1], ...] +/// %result = inner_tiled ins(%collapsed, ...) outs(%out) +/// => +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%src, ...) outs(%expanded_out) +/// %collapsed_result = tensor.collapse_shape %result [[0, 1], ...] +struct FoldProducerCollapseShapeWithInnerTiled + : public OpRewritePattern { + FoldProducerCollapseShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp, + PatternRewriter &rewriter) const override { + if (!collapseOp->hasOneUse()) { + return failure(); + } + OpOperand &use = *collapseOp->use_begin(); + auto innerTiledOp = dyn_cast(use.getOwner()); + if (!innerTiledOp || !controlFn(&use)) { + return failure(); + } + if (failed(canExpandInnerTiledOp(innerTiledOp, &use, + collapseOp.getReassociationIndices()))) { + return failure(); + } + + SmallVector expandedShape = tensor::getMixedSizes( + rewriter, collapseOp.getLoc(), collapseOp.getSrc()); + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, &use, collapseOp.getReassociationIndices(), expandedShape, + collapseOp.getSrc(), outputReassociations, rewriter); + + SmallVector results; + for (auto [idx, result] : llvm::enumerate(expandedOp.getResults())) { + auto resultType = + cast(innerTiledOp.getResultTypes()[idx]); + results.push_back(tensor::CollapseShapeOp::create( + rewriter, innerTiledOp.getLoc(), resultType, result, + outputReassociations[idx])); + } + rewriter.replaceOp(innerTiledOp, results); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +/// Pattern to propagate a tensor::ExpandShapeOp consumer back through an +/// InnerTiledOp producer. The expanded dimensions must not include any inner +/// dimensions of the InnerTiledOp. +/// +/// Example: +/// %result = inner_tiled ins(%lhs, ...) outs(%out) +/// %expanded = tensor.expand_shape %result [[0, 1], ...] +/// => +/// %expanded_lhs = tensor.expand_shape %lhs [[0, 1], ...] +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%expanded_lhs, ...) outs(%expanded_out) +struct FoldConsumerExpandShapeWithInnerTiled + : public OpRewritePattern { + FoldConsumerExpandShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto producerResult = dyn_cast(expandOp.getSrc()); + if (!producerResult) { + return failure(); + } + auto innerTiledOp = dyn_cast(producerResult.getOwner()); + if (!innerTiledOp || !controlFn(&expandOp.getSrcMutable())) { + return failure(); + } + + int64_t resultIdx = producerResult.getResultNumber(); + OpOperand *outputOperand = innerTiledOp.getDpsInitOperand(resultIdx); + if (failed(canExpandInnerTiledOp(innerTiledOp, outputOperand, + expandOp.getReassociationIndices()))) { + return failure(); + } + + // The DPS init will be expanded in the same way as the result, so insert + // the expand_shape on the init first in order to reuse the + // expandInnerTiledOp transformation utility. + SmallVector expandedShape = expandOp.getMixedOutputShape(); + SmallVector staticShape; + std::tie(staticShape, std::ignore) = decomposeMixedValues(expandedShape); + auto sourceType = cast(outputOperand->get().getType()); + auto expandedType = + RankedTensorType::get(staticShape, sourceType.getElementType()); + auto expandedInit = tensor::ExpandShapeOp::create( + rewriter, expandOp.getLoc(), expandedType, outputOperand->get(), + expandOp.getReassociationIndices(), expandedShape); + + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, outputOperand, expandOp.getReassociationIndices(), + expandedShape, expandedInit, outputReassociations, rewriter); + rewriter.replaceOp(expandOp, expandedOp.getResult(resultIdx)); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Populate Functions +//===----------------------------------------------------------------------===// + +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes) { + patterns.add(patterns.getContext(), + controlFoldingReshapes); +} + +} // namespace mlir::iree_compiler::IREE::Codegen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h new file mode 100644 index 000000000000..ef83658add96 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h @@ -0,0 +1,28 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Populate functions. +//===----------------------------------------------------------------------===// + +/// Populate patterns to propagate reshapes by expansion. This folds +/// tensor.expand_shape and tensor.collapse_shape ops with their producer +/// and consumer operations respectively. +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes); + +} // namespace mlir::iree_compiler::IREE::Codegen + +#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_