@@ -29,6 +29,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-dispatch-creation-fold-unit-extent-dims"
 
@@ -184,7 +185,12 @@ struct DropUnitDimsFromCollapseOfExpand
     // needed. Both ops require a non-identity reassociation (i.e. they can't be
     // no-ops).
     Value newExpanded = expandOp.getSrc();
-    if (!llvm::all_of(newExpandReassoc,
+    // Special case: If the source has no shape (e.g., `tensor<f32>`), we
+    // produce an empty reassociation, but still need to insert an
+    // `expand_shape`.
+    bool noSrcShape = srcShape.empty() && !newInterShape.empty();
+    if (noSrcShape ||
+        !llvm::all_of(newExpandReassoc,
                       llvm::hasSingleElement<ReassociationIndicesRef>)) {
       newExpanded = tensor::ExpandShapeOp::create(
           rewriter, expandOp.getLoc(),
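
To make the guard concrete: the pattern still has to materialize an expand_shape either when some reassociation group is non-trivial, or when the source is rank 0 but the intermediate shape is not, because a rank-0 source yields an empty reassociation even though the op does real work. Below is a minimal standalone sketch of that condition; needsExpandShape and its parameters are hypothetical names for illustration, not the pass's actual API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the guard above: does the rewritten collapse(expand(x)) still
// need an explicit tensor.expand_shape?
static bool needsExpandShape(
    const std::vector<std::vector<std::int64_t>> &reassoc,
    std::size_t srcRank, std::size_t interRank) {
  // Empty reassociation but a non-empty intermediate shape, e.g.
  // tensor<f32> -> tensor<1xf32>: the op is required despite the empty map.
  if (srcRank == 0 && interRank != 0)
    return true;
  // Any group with more than one index splits a source dim, so the
  // reassociation is not an identity and the expand is not a no-op.
  for (const auto &group : reassoc)
    if (group.size() != 1)
      return true;
  return false;  // Identity reassociation: the expand_shape would be a no-op.
}

The new `collapse_of_expand_no_source_shape` test further down exercises exactly the rank-0 case.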
@@ -258,8 +264,7 @@ struct FoldUnitDimsFromExtractOp : OpRewritePattern<tensor::ExtractOp> {
 
 } // namespace
 
-static void
-populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) {
+static linalg::ControlDropUnitDims getControlDropUnitDimsOptions() {
   linalg::ControlDropUnitDims options;
   auto defaultFn = options.controlFn;
 
@@ -273,14 +278,25 @@ populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) {
     }
     return defaultFn(op);
   };
+  return options;
+}
 
-  linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
-  IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
-                                                      options);
-  linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
-  foldUnitDimsPatterns
-      .insert<DropUnitDimsFromCollapseOfExpand, FoldUnitDimsFromExtractOp>(
-          foldUnitDimsPatterns.getContext());
+/// Populate patterns for the core unit-extent dim folding transformations.
+/// These patterns are applied with a walk-based driver.
+static void populateFoldUnitDimsPatterns(RewritePatternSet &patterns,
+                                         linalg::ControlDropUnitDims &options) {
+  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+  IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(patterns, options);
+}
+
+/// Populate canonicalization patterns that clean up after unit-extent dim
+/// folding. These patterns are applied with a greedy driver.
+static void populateFoldUnitDimsCanonicalizationPatterns(
+    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
+  patterns.insert<DropUnitDimsFromCollapseOfExpand, FoldUnitDimsFromExtractOp>(
+      patterns.getContext());
+  linalg::populateMoveInitOperandsToInputPattern(patterns);
+  linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(patterns, options);
 }
 
 static LogicalResult
@@ -405,25 +421,42 @@ void FoldUnitExtentDimsPass::runOnOperation() {
     }
   });
 
-  RewritePatternSet foldUnitDimsPatterns(context);
-  populatefoldUnitDimsPatterns(foldUnitDimsPatterns);
-  GreedyRewriteConfig rewriterConfig;
-  rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit);
-  if (failed(applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns),
-                                   rewriterConfig))) {
-    return signalPassFailure();
+  linalg::ControlDropUnitDims options = getControlDropUnitDimsOptions();
+  // Apply fold unit extent dims patterns with walk-based driver.
+  {
+    RewritePatternSet patterns(context);
+    populateFoldUnitDimsPatterns(patterns, options);
+    walkAndApplyPatterns(moduleOp, std::move(patterns));
+  }
+
+  // Apply canonicalization patterns with greedy driver.
+  {
+    RewritePatternSet patterns(context);
+    populateFoldUnitDimsCanonicalizationPatterns(patterns, options);
+    if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 }
 
 void FoldUnitExtentDimsForFuncPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  RewritePatternSet foldUnitDimsPatterns(context);
-  populatefoldUnitDimsPatterns(foldUnitDimsPatterns);
-  GreedyRewriteConfig rewriterConfig;
-  rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit);
-  if (failed(applyPatternsGreedily(
-          getOperation(), std::move(foldUnitDimsPatterns), rewriterConfig))) {
-    return signalPassFailure();
+  Operation *op = getOperation();
+  linalg::ControlDropUnitDims options = getControlDropUnitDimsOptions();
+  // Apply fold unit extent dims patterns with walk-based driver.
+  {
+    RewritePatternSet patterns(context);
+    populateFoldUnitDimsPatterns(patterns, options);
+    walkAndApplyPatterns(op, std::move(patterns));
+  }
+
+  // Apply canonicalization patterns with greedy driver.
+  {
+    RewritePatternSet patterns(context);
+    populateFoldUnitDimsCanonicalizationPatterns(patterns, options);
+    if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 }

@@ -347,6 +347,21 @@ util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040
 
 // -----
 
+// This test covers the case where the source of the expand has no shape at
+// all, but the output still needs one unit dim.
+util.func @collapse_of_expand_no_source_shape(%arg0: tensor<f32>) -> tensor<1xf32> {
+  %expanded = tensor.expand_shape %arg0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
+  util.return %collapsed : tensor<1xf32>
+}
+// CHECK-LABEL: util.func public @collapse_of_expand_no_source_shape
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<f32>
+// CHECK:         %[[EXPAND:.*]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:      tensor<f32> into tensor<1xf32>
+// CHECK:         util.return %[[EXPAND]] : tensor<1xf32>
+
+// -----
+
 util.func @fold_unit_dims_from_extract_leading(%arg0: tensor<1x4x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
   %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x4x8xf32>
   util.return %extracted : f32
runtime/src/iree/builtins/device/device_generic.c (12 additions, 0 deletions)

@@ -161,4 +161,16 @@ IREE_DEVICE_EXPORT float __truncsfhf2(float param) {
   return *((float *)&ret);
 }
 
+IREE_DEVICE_EXPORT double __extendhfdf2(float param) {
+  return (double)__extendhfsf2(param);
+}
+
+IREE_DEVICE_EXPORT float __truncdfhf2(double param) {
+  return __truncsfhf2((float)param);
+}
+
+IREE_DEVICE_EXPORT double fma(double x, double y, double z) {
+  return x * y + z;
+}
+
 #endif  // IREE_DEVICE_STANDALONE
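
These additions fill out the standalone device library's soft-float support for half precision: __extendhfdf2 composes the existing half-to-float extension with a float-to-double cast, __truncdfhf2 truncates through float (so it rounds twice, double to float and then float to half), and fma falls back to an unfused multiply-add. The sketch below is illustrative and standalone, not part of the runtime; it contrasts the unfused fallback with a correctly rounded fused multiply-add via std::fma.

#include <cmath>
#include <cstdio>

// Same strategy as the runtime's fallback: round after the multiply, then
// again after the add. The named temporary keeps the compiler from
// contracting the whole expression into a hardware fma at default settings.
static double fma_fallback(double x, double y, double z) {
  double product = x * y;  // first rounding
  return product + z;      // second rounding
}

int main() {
  // x * y = 1 + 2^-29 + 2^-60; the 2^-60 term is lost when the product is
  // rounded to double, but a fused multiply-add preserves it.
  double x = 1.0 + 0x1p-30;
  double y = 1.0 + 0x1p-30;
  double z = -1.0;
  std::printf("unfused: %.17g\n", fma_fallback(x, y, z));
  std::printf("fused:   %.17g\n", std::fma(x, y, z));
  return 0;
}

The same contraction caveat applies to the runtime's fma itself: whether x * y + z compiles to a fused instruction depends on the target and the floating-point contraction mode, so the fallback can differ from a true fma in the last bit.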

@@ -32,9 +32,6 @@
     "onnx/node/generated/test_basic_deform_conv_without_padding",
     "onnx/node/generated/test_bernoulli_seed",
     "onnx/node/generated/test_bernoulli_seed_expanded",
-    "onnx/node/generated/test_cast_DOUBLE_to_FLOAT16",
-    "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16",
-    "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16_expanded",
     "onnx/node/generated/test_col2im",
     "onnx/node/generated/test_col2im_5d",
     "onnx/node/generated/test_col2im_dilations",

@@ -33,9 +33,6 @@
     "onnx/node/generated/test_basic_deform_conv_without_padding",
     "onnx/node/generated/test_bernoulli_seed",
     "onnx/node/generated/test_bernoulli_seed_expanded",
-    "onnx/node/generated/test_cast_DOUBLE_to_FLOAT16",
-    "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16",
-    "onnx/node/generated/test_castlike_DOUBLE_to_FLOAT16_expanded",
     "onnx/node/generated/test_col2im",
     "onnx/node/generated/test_col2im_5d",
     "onnx/node/generated/test_col2im_dilations",

third_party/llvm-project (1 addition, 1 deletion)
Submodule llvm-project updated 655 files