diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp index f0b0322f6b09..9eef94d7b8ac 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp @@ -10,7 +10,9 @@ #include "iree/compiler/DispatchCreation/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Utils.h" @@ -280,10 +282,15 @@ class CollapseInfo { // Debug print the current operation & reassociation indicies void dump() const; - // Update `collapsableLoops` by taking the set intersection with - // `otherCollapsable` and update the reassociation indicies accordingly. + // Update CollapseInfo to ensure that all dimensions collapsable in `this` are + // also collapsable in `consumerInfo`. This means: + // 1. Any dimension not collapsable in `consumerInfo` should not be + // collapsable in `this` + // 2. For any pair of dimensions in `this`, if they are collapsable in + // `consumerInfo`, they must be collapsable into the same dimension in + // `consumerInfo` to be collapsable into the same dimension in `this`. // Returns true if the operation modified the number of collapsable loops. - bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable); + bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo); // Update `collapsableLoops` by subtracting `uncollapsable` and update the // reassociation indicies accordingly. @@ -293,13 +300,18 @@ class CollapseInfo { // Get `collapsableLoops` after applying the transformation provided by `map`. // Note: doesn't modify `collapsableLoops`, the tranformation is applied to a // copy. - FailureOr - getTransformedCollapsableLoops(AffineMap map) const; + CollapsableLoopsSet getTransformedCollapsableLoops(AffineMap map) const; - // Clear internal data - void clear() { + // Get `reassociation` after applying the transformation provided by `map`. + SmallVector + getTransformedReassociation(AffineMap map) const; + + // Clear internal data and returns if anything changed. + bool clear() { + bool isNotEmpty = reassociation.empty() || collapsableLoops.empty(); reassociation.clear(); collapsableLoops.clear(); + return isNotEmpty; } const CollapsableLoopsSet &getCollapsibleLoops() const { @@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() { // map = affine_map<(d0, d1, d2) -> (d1, d2, d5)> // // Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}. -FailureOr +CollapseInfo::CollapsableLoopsSet CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const { - if (!map) { - return failure(); - } - CollapsableLoopsSet transformedLoops; for (auto index : collapsableLoops) { assert(index < map.getNumResults() && "index has no valid mapping"); @@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const { return transformedLoops; } -// Update `collapsableLoops` by taking the set intersection with -// `otherCollapsable` and update the reassociation indicies accordingly. -bool CollapseInfo::updateCollapseViaIntersect( - const CollapsableLoopsSet &otherCollapsable) { - CollapsableLoopsSet toRemove; - for (auto elem : collapsableLoops) { - if (!otherCollapsable.contains(elem)) { - toRemove.insert(elem); +SmallVector +CollapseInfo::getTransformedReassociation(AffineMap map) const { + SmallVector transformedReassociation( + reassociation.size()); + for (const auto &[i, indicies] : llvm::enumerate(reassociation)) { + for (auto elem : indicies) { + auto dimExpr = dyn_cast(map.getResult(elem)); + if (!dimExpr) { + break; + } + transformedReassociation[i].push_back(dimExpr.getPosition()); } } - collapsableLoops.set_subtract(toRemove); - updateReassociation(); - return toRemove.size(); + return transformedReassociation; +} + +bool CollapseInfo::updateFromConsumer(OpOperand *operand, + const CollapseInfo &consumerInfo) { + FailureOr consumerToProducerMap = + getConsumerLoopToProducerLoopsMap(*operand); + if (failed(consumerToProducerMap)) { + return this->clear(); + } + + CollapsableLoopsSet consumerCollapsable = + consumerInfo.getTransformedCollapsableLoops( + consumerToProducerMap.value()); + + SmallVector consumerReassoc = + consumerInfo.getTransformedReassociation(consumerToProducerMap.value()); + + // Get a map from original index to the index it gets collapsed into + llvm::DenseMap consumerCollapseMap; + for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) { + for (const auto elem : indicies) { + consumerCollapseMap[elem] = idx; + } + } + + // Remove all collapsable loops in `producer` that are not collapsable in + // `consumer` (set intersect) + bool didChange = collapsableLoops.remove_if( + [&](long elem) -> bool { return !consumerCollapsable.contains(elem); }); + + // Now update the reassociation indicies given the updated `collapsableLoops` + // and `consumerCollapsableMap`. + // The idea is to reconstruct the reassociation indicies, and at each index: + // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add + // `index` to either. + // + // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps + // `index` to a different collapsed loop then the other indicies, split + // `indicies` and insert `index` into the new one. + // + // For example: + // producer reassociation = [[0, 1], [2, 3]] + // consumer reassociation = [0, 1, 2, 3] + // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because + // [0, 1] and [2, 3] get collapsed into different loops + // + // (3) Otherwise, keep the index + constexpr long kUninitialized = -1; + SmallVector newReassociation; + for (ReassociationIndicesRef indicies : reassociation) { + // Track the loop index that `indicies` get collapsed into. + long collapseIntoIdx = kUninitialized; + + // Holds dimensions that should be collapsed together + ReassociationIndices newIndicies; + for (int64_t index : indicies) { + if (!collapsableLoops.contains(index)) { + // (1) Because `index` isn't collapsable, the indicies in `newIndicies` + // are no longer adjacent to the upcoming indicies. If there is >1 index + // to collapse, add it to the new reassociation. Otherwise, discard it + // because there is no dimension to collapse with. + didChange = true; + if (newIndicies.size() > 1) { + newReassociation.push_back(std::move(newIndicies)); + } + newIndicies.clear(); + collapseIntoIdx = kUninitialized; + } else if (collapseIntoIdx == kUninitialized) { + // (2) First occurance of collapsable loop, set collapseIntoIdx. + collapseIntoIdx = consumerCollapseMap.at(index); + newIndicies.push_back(index); + } else if (consumerCollapseMap.at(index) != collapseIntoIdx) { + // (3) `index` is collapsable but not collapsable into the other loops. + // So, split them and look for other loops to collapse `index` into. + didChange = true; + if (newIndicies.size() > 1) { + newReassociation.push_back(std::move(newIndicies)); + } + newIndicies.clear(); + collapseIntoIdx = consumerCollapseMap[index]; + newIndicies.push_back(index); + } else { + // (4) `index` is collapsable and can be collapsed into + // `collapseIntoIndex`. + newIndicies.push_back(index); + } + } + + if (newIndicies.size() > 1) { + newReassociation.push_back(newIndicies); + } + } + reassociation = std::move(newReassociation); + return didChange; } // Update `collapsableLoops` by subtracting `uncollapsable` and update the @@ -694,12 +797,10 @@ static bool updateConsumersFromProducers( continue; } - CollapseInfo &producerInfo = opMap.find(producerOp)->second; - FailureOr producerCollapsable = + const CollapseInfo &producerInfo = opMap.at(producerOp); + CollapseInfo::CollapsableLoopsSet producerCollapsable = producerInfo.getTransformedCollapsableLoops(mapping.value()); - if (!failed(producerCollapsable)) { - producerUncollapsable.set_subtract(producerCollapsable.value()); - } + producerUncollapsable.set_subtract(producerCollapsable); didChange |= consumerInfo.updateCollapseViaSubtract(producerUncollapsable); @@ -722,7 +823,7 @@ static bool updateProducersFromConsumers( for (auto op : llvm::reverse(slice)) { auto genericConsumer = cast(op); assert(opMap.contains(genericConsumer)); - const CollapseInfo &consumerInfo = opMap.find(genericConsumer)->second; + const CollapseInfo &consumerInfo = opMap.at(genericConsumer); for (auto operand : genericConsumer.getDpsInputOperands()) { auto definingOp = operand->get().getDefiningOp(); @@ -736,26 +837,10 @@ static bool updateProducersFromConsumers( // Get a mapping from the consumer's iteration space to the producer's. CollapseInfo &producerInfo = opMap.find(genericProducer)->second; - FailureOr consumerToProducerMap = - getConsumerLoopToProducerLoopsMap(*operand); - if (failed(consumerToProducerMap)) { - didChange |= !producerInfo.getCollapsibleLoops().empty(); - producerInfo.clear(); - continue; - } - // Use the map to get the consumer's collapsable loops in terms of the - // producer. - auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops( - consumerToProducerMap.value()); - if (failed(consumerCollapsable)) { - producerInfo.clear(); - continue; - } // Only loops collapsable in both the consumer and producer may be // collapsed. - didChange |= - producerInfo.updateCollapseViaIntersect(consumerCollapsable.value()); + didChange |= producerInfo.updateFromConsumer(operand, consumerInfo); } } return didChange; diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir index 5e21f869f82d..e6166dbc7b73 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir @@ -518,3 +518,44 @@ util.func public @propagate_uncollapsable(%arg0: tensor<2x320x128x128xf32>) -> t // CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>) // CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>) // CHECK: flow.return %[[VAL3]] + +// ----- + +util.func public @dequant_contraction(%arg0: tensor<2x32xf32>, %arg1: tensor<2x32x10x16384xf16>) -> tensor<2x32xf32> { + %0 = flow.dispatch.region -> (tensor<2x32xf32>) { + %1 = tensor.empty() : tensor<2x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %2 = tensor.empty() : tensor<2x32x10x16384xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x32x10x16384xf16>) outs(%2 : tensor<2x32x10x16384xf32>) { + ^bb0(%in: f16, %out: f32): + %6 = arith.extf %in : f16 to f32 + linalg.yield %6 : f32 + } -> tensor<2x32x10x16384xf32> + %4 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32> + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %arg0 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%4 : tensor<2x32xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.subf %in, %in_0 : f32 + %7 = arith.mulf %6, %6 : f32 + %8 = arith.addf %7, %out : f32 + linalg.yield %8 : f32 + } -> tensor<2x32xf32> + flow.return %5 : tensor<2x32xf32> + } + util.return %0 : tensor<2x32xf32> +} + +// CHECK-LABEL: util.func public @dequant_contraction +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x32xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x32x10x16384xf16> +// CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK-DAG: %[[COLLAPSED_ARG1:.+]] = tensor.collapse_shape %[[ARG1]] +// CHECK: flow.dispatch.region +// CHECK: %[[VAL0:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSED_ARG1]] : tensor<64x163840xf16>) +// CHECK-SAME: outs(%{{.*}} : tensor<64x163840xf32>) +// CHECK: %[[VAL1:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: ins(%[[VAL0]], %[[COLLAPSED_ARG0]] : tensor<64x163840xf32>, tensor<64xf32>) +// CHECK-SAME: outs(%{{.*}} : tensor<64xf32>) +// CHECK: flow.return %[[VAL1]]