diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index 50e8d44bca39..f66cf9c88b5a 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -29,6 +29,7 @@ #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #define DEBUG_TYPE "iree-dispatch-creation-fold-unit-extent-dims" @@ -184,7 +185,12 @@ struct DropUnitDimsFromCollapseOfExpand // needed. Both ops require a non-identity reassociation (i.e. they can't be // no-ops). Value newExpanded = expandOp.getSrc(); - if (!llvm::all_of(newExpandReassoc, + // Special case: If the source has no shape (e.g., `tensor<f32>`), we + // produce an empty reassociation, but still need to insert an + // `expand_shape`. + bool noSrcShape = srcShape.empty() && !newInterShape.empty(); + if (noSrcShape || + !llvm::all_of(newExpandReassoc, llvm::hasSingleElement)) { newExpanded = tensor::ExpandShapeOp::create( rewriter, expandOp.getLoc(), @@ -258,8 +264,7 @@ struct FoldUnitDimsFromExtractOp : OpRewritePattern<tensor::ExtractOp> { } // namespace -static void -populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) { +static linalg::ControlDropUnitDims getControlDropUnitDimsOptions() { linalg::ControlDropUnitDims options; auto defaultFn = options.controlFn; @@ -273,14 +278,25 @@ populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) { } return defaultFn(op); }; + return options; +} + +/// Populate patterns for the core unit-extent dim folding transformations. +/// These patterns are applied with a walk-based driver. 
+static void populateFoldUnitDimsPatterns(RewritePatternSet &patterns, + linalg::ControlDropUnitDims &options) { + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); + IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(patterns, options); +} - linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); - IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, - options); - linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); - foldUnitDimsPatterns - .insert<DropUnitDimsFromCollapseOfExpand, FoldUnitDimsFromExtractOp>( - foldUnitDimsPatterns.getContext()); +/// Populate canonicalization patterns that clean up after unit-extent dim +/// folding. These patterns are applied with a greedy driver. +static void populateFoldUnitDimsCanonicalizationPatterns( + RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) { + patterns.insert<DropUnitDimsFromCollapseOfExpand, FoldUnitDimsFromExtractOp>( + patterns.getContext()); + linalg::populateMoveInitOperandsToInputPattern(patterns); + linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(patterns, options); } static LogicalResult @@ -405,25 +421,42 @@ void FoldUnitExtentDimsPass::runOnOperation() { } }); - RewritePatternSet foldUnitDimsPatterns(context); - populatefoldUnitDimsPatterns(foldUnitDimsPatterns); - GreedyRewriteConfig rewriterConfig; - rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit); - if (failed(applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns), - rewriterConfig))) { - return signalPassFailure(); + linalg::ControlDropUnitDims options = getControlDropUnitDimsOptions(); + // Apply fold unit extent dims patterns with walk-based driver. + { + RewritePatternSet patterns(context); + populateFoldUnitDimsPatterns(patterns, options); + walkAndApplyPatterns(moduleOp, std::move(patterns)); + } + + // Apply canonicalization patterns with greedy driver. 
+ { + RewritePatternSet patterns(context); + populateFoldUnitDimsCanonicalizationPatterns(patterns, options); + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + return signalPassFailure(); + } } } void FoldUnitExtentDimsForFuncPass::runOnOperation() { MLIRContext *context = &getContext(); - RewritePatternSet foldUnitDimsPatterns(context); - populatefoldUnitDimsPatterns(foldUnitDimsPatterns); - GreedyRewriteConfig rewriterConfig; - rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit); - if (failed(applyPatternsGreedily( - getOperation(), std::move(foldUnitDimsPatterns), rewriterConfig))) { - return signalPassFailure(); + Operation *op = getOperation(); + linalg::ControlDropUnitDims options = getControlDropUnitDimsOptions(); + // Apply fold unit extent dims patterns with walk-based driver. + { + RewritePatternSet patterns(context); + populateFoldUnitDimsPatterns(patterns, options); + walkAndApplyPatterns(op, std::move(patterns)); + } + + // Apply canonicalization patterns with greedy driver. + { + RewritePatternSet patterns(context); + populateFoldUnitDimsCanonicalizationPatterns(patterns, options); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } } } diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir index ea5e4d095823..51b67d90ae33 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir @@ -347,6 +347,21 @@ util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040 // ----- +// This test considers the case where we have no shape at all on the source for the expand, +// but still need one unit dim for the output. 
+util.func @collapse_of_expand_no_source_shape(%arg0: tensor<f32>) -> tensor<1xf32> { + %expanded = tensor.expand_shape %arg0 [] output_shape[1, 1] : tensor<f32> into tensor<1x1xf32> + %collapsed = tensor.collapse_shape %expanded [[0, 1]] : tensor<1x1xf32> into tensor<1xf32> + util.return %collapsed : tensor<1xf32> +} +// CHECK-LABEL: util.func public @collapse_of_expand_no_source_shape +// CHECK-SAME: %[[ARG0:.+]]: tensor<f32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: tensor<f32> into tensor<1xf32> +// CHECK: util.return %[[EXPAND]] : tensor<1xf32> + +// ----- + util.func @fold_unit_dims_from_extract_leading(%arg0: tensor<1x4x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 { %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x4x8xf32> util.return %extracted : f32 diff --git a/runtime/src/iree/builtins/device/device_generic.c b/runtime/src/iree/builtins/device/device_generic.c index beb615da5280..684427826a03 100644 --- a/runtime/src/iree/builtins/device/device_generic.c +++ b/runtime/src/iree/builtins/device/device_generic.c @@ -161,4 +161,16 @@ IREE_DEVICE_EXPORT float __truncsfhf2(float param) { return *((float *)&ret); } +IREE_DEVICE_EXPORT double __extendhfdf2(float param) { + return (double)__extendhfsf2(param); +} + +IREE_DEVICE_EXPORT float __truncdfhf2(double param) { + return __truncsfhf2((float)param); +} + +IREE_DEVICE_EXPORT double fma(double x, double y, double z) { + return x * y + z; +} + #endif // IREE_DEVICE_STANDALONE diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json index caf18983ac1e..0349a08580c3 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json @@ -32,9 +32,6 @@ "onnx/node/generated/test_basic_deform_conv_without_padding", "onnx/node/generated/test_bernoulli_seed", 
"onnx/node/generated/test_bernoulli_seed_expanded", - "onnx/node/generated/test_cast_DOUBLE_to_FLOAT16", - "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16", - "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16_expanded", "onnx/node/generated/test_col2im", "onnx/node/generated/test_col2im_5d", "onnx/node/generated/test_col2im_dilations", diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json index b30ab118f5b3..928b6e993055 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json @@ -33,9 +33,6 @@ "onnx/node/generated/test_basic_deform_conv_without_padding", "onnx/node/generated/test_bernoulli_seed", "onnx/node/generated/test_bernoulli_seed_expanded", - "onnx/node/generated/test_cast_DOUBLE_to_FLOAT16", - "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16", - "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16_expanded", "onnx/node/generated/test_col2im", "onnx/node/generated/test_col2im_5d", "onnx/node/generated/test_col2im_dilations", diff --git a/third_party/llvm-project b/third_party/llvm-project index 53071dd81cdb..1c023cbab959 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 53071dd81cdb2992b59b743e25a9daf6c46f2611 +Subproject commit 1c023cbab95915929b2ebcdc48b2d682921e067a diff --git a/third_party/torch-mlir b/third_party/torch-mlir index ac7b5f5d0feb..3cebce2bd718 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit ac7b5f5d0feb9c07d56ec6a19cb66483b0780f53 +Subproject commit 3cebce2bd718eadccc2cba385400cb3d350134ae