-
Notifications
You must be signed in to change notification settings - Fork 873
[DispatchCreation] CollapseDimensions patch #18424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
582b647
16964b8
892d2be
c78e02d
399f8d8
14540d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -10,7 +10,9 @@ | |||||
| #include "iree/compiler/DispatchCreation/Passes.h" | ||||||
| #include "llvm/ADT/STLExtras.h" | ||||||
| #include "llvm/ADT/SetVector.h" | ||||||
| #include "llvm/ADT/SmallVectorExtras.h" | ||||||
| #include "llvm/Support/Debug.h" | ||||||
| #include "llvm/Support/LogicalResult.h" | ||||||
| #include "llvm/Support/raw_ostream.h" | ||||||
| #include "mlir/Analysis/SliceAnalysis.h" | ||||||
| #include "mlir/Dialect/Affine/Utils.h" | ||||||
|
|
@@ -280,10 +282,15 @@ class CollapseInfo { | |||||
| // Debug print the current operation & reassociation indicies | ||||||
| void dump() const; | ||||||
|
|
||||||
| // Update `collapsableLoops` by taking the set intersection with | ||||||
| // `otherCollapsable` and update the reassociation indicies accordingly. | ||||||
| // Update CollapseInfo to ensure that all dimensions collapsable in `this` are | ||||||
| // also collapsable in `consumerInfo`. This means: | ||||||
| // 1. Any dimension not collapsable in `consumerInfo` should not be | ||||||
| // collapsable in `this` | ||||||
| // 2. For any pair of dimensions in `this`, if they are collapsable in | ||||||
| // `consumerInfo`, they must be collapsable into the same dimension in | ||||||
| // `consumerInfo` to be collapsable into the same dimension in `this`. | ||||||
| // Returns true if the operation modified the number of collapsable loops. | ||||||
| bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable); | ||||||
| bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo); | ||||||
|
|
||||||
| // Update `collapsableLoops` by subtracting `uncollapsable` and update the | ||||||
| // reassociation indicies accordingly. | ||||||
|
|
@@ -293,13 +300,18 @@ class CollapseInfo { | |||||
| // Get `collapsableLoops` after applying the transformation provided by `map`. | ||||||
| // Note: doesn't modify `collapsableLoops`, the tranformation is applied to a | ||||||
| // copy. | ||||||
| FailureOr<CollapsableLoopsSet> | ||||||
| getTransformedCollapsableLoops(AffineMap map) const; | ||||||
| CollapsableLoopsSet getTransformedCollapsableLoops(AffineMap map) const; | ||||||
|
|
||||||
| // Clear internal data | ||||||
| void clear() { | ||||||
| // Get `reassociation` after applying the transformation provided by `map`. | ||||||
| SmallVector<ReassociationIndices> | ||||||
| getTransformedReassociation(AffineMap map) const; | ||||||
|
|
||||||
| // Clear internal data and returns if anything changed. | ||||||
| bool clear() { | ||||||
| bool isNotEmpty = reassociation.empty() || collapsableLoops.empty(); | ||||||
| reassociation.clear(); | ||||||
| collapsableLoops.clear(); | ||||||
| return isNotEmpty; | ||||||
| } | ||||||
|
|
||||||
| const CollapsableLoopsSet &getCollapsibleLoops() const { | ||||||
|
|
@@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() { | |||||
| // map = affine_map<(d0, d1, d2) -> (d1, d2, d5)> | ||||||
| // | ||||||
| // Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}. | ||||||
| FailureOr<CollapseInfo::CollapsableLoopsSet> | ||||||
| CollapseInfo::CollapsableLoopsSet | ||||||
| CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const { | ||||||
| if (!map) { | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| CollapsableLoopsSet transformedLoops; | ||||||
| for (auto index : collapsableLoops) { | ||||||
| assert(index < map.getNumResults() && "index has no valid mapping"); | ||||||
|
|
@@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const { | |||||
| return transformedLoops; | ||||||
| } | ||||||
|
|
||||||
| // Update `collapsableLoops` by taking the set intersection with | ||||||
| // `otherCollapsable` and update the reassociation indicies accordingly. | ||||||
| bool CollapseInfo::updateCollapseViaIntersect( | ||||||
| const CollapsableLoopsSet &otherCollapsable) { | ||||||
| CollapsableLoopsSet toRemove; | ||||||
| for (auto elem : collapsableLoops) { | ||||||
| if (!otherCollapsable.contains(elem)) { | ||||||
| toRemove.insert(elem); | ||||||
| SmallVector<ReassociationIndices> | ||||||
| CollapseInfo::getTransformedReassociation(AffineMap map) const { | ||||||
| SmallVector<ReassociationIndices> transformedReassociation( | ||||||
| reassociation.size()); | ||||||
| for (const auto &[i, indicies] : llvm::enumerate(reassociation)) { | ||||||
| for (auto elem : indicies) { | ||||||
| auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult(elem)); | ||||||
| if (!dimExpr) { | ||||||
| break; | ||||||
| } | ||||||
| transformedReassociation[i].push_back(dimExpr.getPosition()); | ||||||
| } | ||||||
| } | ||||||
| collapsableLoops.set_subtract(toRemove); | ||||||
| updateReassociation(); | ||||||
| return toRemove.size(); | ||||||
| return transformedReassociation; | ||||||
| } | ||||||
|
|
||||||
| bool CollapseInfo::updateFromConsumer(OpOperand *operand, | ||||||
| const CollapseInfo &consumerInfo) { | ||||||
| FailureOr<AffineMap> consumerToProducerMap = | ||||||
| getConsumerLoopToProducerLoopsMap(*operand); | ||||||
| if (failed(consumerToProducerMap)) { | ||||||
| return this->clear(); | ||||||
| } | ||||||
|
|
||||||
| CollapsableLoopsSet consumerCollapsable = | ||||||
| consumerInfo.getTransformedCollapsableLoops( | ||||||
| consumerToProducerMap.value()); | ||||||
|
|
||||||
| SmallVector<ReassociationIndices> consumerReassoc = | ||||||
| consumerInfo.getTransformedReassociation(consumerToProducerMap.value()); | ||||||
|
|
||||||
| // Get a map from original index to the index it gets collapsed into | ||||||
| llvm::DenseMap<long, long> consumerCollapseMap; | ||||||
| for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) { | ||||||
| for (const auto elem : indicies) { | ||||||
| consumerCollapseMap[elem] = idx; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Remove all collapsable loops in `producer` that are not collapsable in | ||||||
| // `consumer` (set intersect) | ||||||
| bool didChange = collapsableLoops.remove_if( | ||||||
| [&](long elem) -> bool { return !consumerCollapsable.contains(elem); }); | ||||||
|
|
||||||
| // Now update the reassociation indicies given the updated `collapsableLoops` | ||||||
| // and `consumerCollapsableMap`. | ||||||
| // The idea is to reconstruct the reassociation indicies, and at each index: | ||||||
| // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add | ||||||
| // `index` to either. | ||||||
| // | ||||||
| // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps | ||||||
| // `index` to a different collapsed loop then the other indicies, split | ||||||
| // `indicies` and insert `index` into the new one. | ||||||
| // | ||||||
| // For example: | ||||||
| // producer reassociation = [[0, 1], [2, 3]] | ||||||
| // consumer reassociation = [0, 1, 2, 3] | ||||||
| // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because | ||||||
| // [0, 1] and [2, 3] get collapsed into different loops | ||||||
| // | ||||||
| // (3) Otherwise, keep the index | ||||||
| constexpr long kUninitialized = -1; | ||||||
| SmallVector<ReassociationIndices> newReassociation; | ||||||
| for (ReassociationIndicesRef indicies : reassociation) { | ||||||
| // Track the loop index that `indicies` get collapsed into. | ||||||
| long collapseIntoIdx = kUninitialized; | ||||||
|
|
||||||
| // Holds dimensions that should be collapsed together | ||||||
| ReassociationIndices newIndicies; | ||||||
| for (int64_t index : indicies) { | ||||||
| if (!collapsableLoops.contains(index)) { | ||||||
| // (1) Because `index` isn't collapsable, the indicies in `newIndicies` | ||||||
| // are no longer adjacent to the upcoming indicies. If there is >1 index | ||||||
| // to collapse, add it to the new reassociation. Otherwise, discard it | ||||||
| // because there is no dimension to collapse with. | ||||||
| didChange = true; | ||||||
| if (newIndicies.size() > 1) { | ||||||
| newReassociation.push_back(std::move(newIndicies)); | ||||||
| } | ||||||
| newIndicies.clear(); | ||||||
| collapseIntoIdx = kUninitialized; | ||||||
| } else if (collapseIntoIdx == kUninitialized) { | ||||||
| // (2) First occurance of collapsable loop, set collapseIntoIdx. | ||||||
| collapseIntoIdx = consumerCollapseMap.at(index); | ||||||
| newIndicies.push_back(index); | ||||||
| } else if (consumerCollapseMap.at(index) != collapseIntoIdx) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct me if I'm wrong. I think the indices in the ReassociationIndices should be all different. To be more specific, they are just a sequence from 0 to rank-1 if we flatten it. So this can just be
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay I added some comments and also tried to clean up the if statement logic a bit. That was meant to handle the case where
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see the added testcase where there is an elementwise -> generic (2 parallel + 2 reduction). |
||||||
| // (3) `index` is collapsable but not collapsable into the other loops. | ||||||
| // So, split them and look for other loops to collapse `index` into. | ||||||
| didChange = true; | ||||||
| if (newIndicies.size() > 1) { | ||||||
| newReassociation.push_back(std::move(newIndicies)); | ||||||
| } | ||||||
| newIndicies.clear(); | ||||||
| collapseIntoIdx = consumerCollapseMap[index]; | ||||||
| newIndicies.push_back(index); | ||||||
| } else { | ||||||
| // (4) `index` is collapsable and can be collapsed into | ||||||
| // `collapseIntoIndex`. | ||||||
| newIndicies.push_back(index); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| if (newIndicies.size() > 1) { | ||||||
| newReassociation.push_back(newIndicies); | ||||||
| } | ||||||
| } | ||||||
| reassociation = std::move(newReassociation); | ||||||
| return didChange; | ||||||
| } | ||||||
|
|
||||||
| // Update `collapsableLoops` by subtracting `uncollapsable` and update the | ||||||
|
|
@@ -694,12 +797,10 @@ static bool updateConsumersFromProducers( | |||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| CollapseInfo &producerInfo = opMap.find(producerOp)->second; | ||||||
| FailureOr<CollapseInfo::CollapsableLoopsSet> producerCollapsable = | ||||||
| const CollapseInfo &producerInfo = opMap.at(producerOp); | ||||||
| CollapseInfo::CollapsableLoopsSet producerCollapsable = | ||||||
| producerInfo.getTransformedCollapsableLoops(mapping.value()); | ||||||
| if (!failed(producerCollapsable)) { | ||||||
| producerUncollapsable.set_subtract(producerCollapsable.value()); | ||||||
| } | ||||||
| producerUncollapsable.set_subtract(producerCollapsable); | ||||||
|
|
||||||
| didChange |= | ||||||
| consumerInfo.updateCollapseViaSubtract(producerUncollapsable); | ||||||
|
|
@@ -722,7 +823,7 @@ static bool updateProducersFromConsumers( | |||||
| for (auto op : llvm::reverse(slice)) { | ||||||
| auto genericConsumer = cast<linalg::GenericOp>(op); | ||||||
| assert(opMap.contains(genericConsumer)); | ||||||
| const CollapseInfo &consumerInfo = opMap.find(genericConsumer)->second; | ||||||
| const CollapseInfo &consumerInfo = opMap.at(genericConsumer); | ||||||
|
|
||||||
| for (auto operand : genericConsumer.getDpsInputOperands()) { | ||||||
| auto definingOp = operand->get().getDefiningOp(); | ||||||
|
|
@@ -736,26 +837,10 @@ static bool updateProducersFromConsumers( | |||||
|
|
||||||
| // Get a mapping from the consumer's iteration space to the producer's. | ||||||
| CollapseInfo &producerInfo = opMap.find(genericProducer)->second; | ||||||
| FailureOr<AffineMap> consumerToProducerMap = | ||||||
| getConsumerLoopToProducerLoopsMap(*operand); | ||||||
| if (failed(consumerToProducerMap)) { | ||||||
| didChange |= !producerInfo.getCollapsibleLoops().empty(); | ||||||
| producerInfo.clear(); | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Use the map to get the consumer's collapsable loops in terms of the | ||||||
| // producer. | ||||||
| auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops( | ||||||
| consumerToProducerMap.value()); | ||||||
| if (failed(consumerCollapsable)) { | ||||||
| producerInfo.clear(); | ||||||
| continue; | ||||||
| } | ||||||
| // Only loops collapsable in both the consumer and producer may be | ||||||
| // collapsed. | ||||||
| didChange |= | ||||||
| producerInfo.updateCollapseViaIntersect(consumerCollapsable.value()); | ||||||
| didChange |= producerInfo.updateFromConsumer(operand, consumerInfo); | ||||||
| } | ||||||
| } | ||||||
| return didChange; | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.