Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 131 additions & 46 deletions compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Utils.h"
Expand Down Expand Up @@ -280,10 +282,15 @@ class CollapseInfo {
// Debug print the current operation & reassociation indices
void dump() const;

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
// Update CollapseInfo to ensure that all dimensions collapsable in `this` are
// also collapsable in `consumerInfo`. This means:
// 1. Any dimension not collapsable in `consumerInfo` should not be
// collapsable in `this`
// 2. For any pair of dimensions in `this`, if they are collapsable in
// `consumerInfo`, they must be collapsable into the same dimension in
// `consumerInfo` to be collapsable into the same dimension in `this`.
// Returns true if the operation modified the number of collapsable loops.
bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);
bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo);

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indices accordingly.
Expand All @@ -293,13 +300,18 @@ class CollapseInfo {
// Get `collapsableLoops` after applying the transformation provided by `map`.
// Note: doesn't modify `collapsableLoops`, the transformation is applied to a
// copy.
FailureOr<CollapsableLoopsSet>
getTransformedCollapsableLoops(AffineMap map) const;
CollapsableLoopsSet getTransformedCollapsableLoops(AffineMap map) const;

// Clear internal data
void clear() {
// Get `reassociation` after applying the transformation provided by `map`.
SmallVector<ReassociationIndices>
getTransformedReassociation(AffineMap map) const;

// Clears the internal data (`reassociation` and `collapsableLoops`).
// Returns true if anything changed, i.e. if at least one of the two
// containers held elements before being cleared.
bool clear() {
  // Record the state *before* clearing. The previous expression
  // (`reassociation.empty() || collapsableLoops.empty()`) was inverted: it
  // reported a change exactly when there was nothing to clear.
  const bool hadData = !reassociation.empty() || !collapsableLoops.empty();
  reassociation.clear();
  collapsableLoops.clear();
  return hadData;
}

const CollapsableLoopsSet &getCollapsibleLoops() const {
Expand Down Expand Up @@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() {
// map = affine_map<(d0, d1, d2) -> (d1, d2, d5)>
//
// Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}.
FailureOr<CollapseInfo::CollapsableLoopsSet>
CollapseInfo::CollapsableLoopsSet
CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
if (!map) {
return failure();
}

CollapsableLoopsSet transformedLoops;
for (auto index : collapsableLoops) {
assert(index < map.getNumResults() && "index has no valid mapping");
Expand All @@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
return transformedLoops;
}

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
bool CollapseInfo::updateCollapseViaIntersect(
const CollapsableLoopsSet &otherCollapsable) {
CollapsableLoopsSet toRemove;
for (auto elem : collapsableLoops) {
if (!otherCollapsable.contains(elem)) {
toRemove.insert(elem);
SmallVector<ReassociationIndices>
CollapseInfo::getTransformedReassociation(AffineMap map) const {
SmallVector<ReassociationIndices> transformedReassociation(
reassociation.size());
for (const auto &[i, indicies] : llvm::enumerate(reassociation)) {
for (auto elem : indicies) {
auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult(elem));
if (!dimExpr) {
break;
}
transformedReassociation[i].push_back(dimExpr.getPosition());
}
}
collapsableLoops.set_subtract(toRemove);
updateReassociation();
return toRemove.size();
return transformedReassociation;
}

bool CollapseInfo::updateFromConsumer(OpOperand *operand,
const CollapseInfo &consumerInfo) {
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(*operand);
if (failed(consumerToProducerMap)) {
return this->clear();
}

CollapsableLoopsSet consumerCollapsable =
consumerInfo.getTransformedCollapsableLoops(
consumerToProducerMap.value());

SmallVector<ReassociationIndices> consumerReassoc =
consumerInfo.getTransformedReassociation(consumerToProducerMap.value());

// Get a map from original index to the index it gets collapsed into
llvm::DenseMap<long, long> consumerCollapseMap;
for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) {
for (const auto elem : indicies) {
consumerCollapseMap[elem] = idx;
}
}

// Remove all collapsable loops in `producer` that are not collapsable in
// `consumer` (set intersect)
bool didChange = collapsableLoops.remove_if(
[&](long elem) -> bool { return !consumerCollapsable.contains(elem); });

// Now update the reassociation indices given the updated `collapsableLoops`
// and `consumerCollapseMap`.
// The idea is to reconstruct the reassociation indices, and at each index:
// (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
// `index` to either.
//
// (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps
// `index` to a different collapsed loop than the other indices, split
// `indicies` and insert `index` into the new one.
//
// For example:
// producer reassociation = [[0, 1, 2, 3]]
// consumer reassociation = [[0, 1], [2, 3]]
// then, producer reassociation gets updated to [[0, 1], [2, 3]] because
// [0, 1] and [2, 3] get collapsed into different loops in the consumer
//
// (3) Otherwise, keep the index
constexpr long kUninitialized = -1;
SmallVector<ReassociationIndices> newReassociation;
for (ReassociationIndicesRef indicies : reassociation) {
// Track the loop index that `indicies` get collapsed into.
long collapseIntoIdx = kUninitialized;

// Holds dimensions that should be collapsed together
ReassociationIndices newIndicies;
for (int64_t index : indicies) {
if (!collapsableLoops.contains(index)) {
// (1) Because `index` isn't collapsable, the indices in `newIndicies`
// are no longer adjacent to the upcoming indices. If there is >1 index
// to collapse, add them to the new reassociation. Otherwise, discard it
// because there is no dimension to collapse with.
didChange = true;
if (newIndicies.size() > 1) {
newReassociation.push_back(std::move(newIndicies));
}
newIndicies.clear();
collapseIntoIdx = kUninitialized;
} else if (collapseIntoIdx == kUninitialized) {
// (2) First occurrence of collapsable loop, set collapseIntoIdx.
collapseIntoIdx = consumerCollapseMap.at(index);
newIndicies.push_back(index);
} else if (consumerCollapseMap.at(index) != collapseIntoIdx) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong. I think the indices in the ReassociationIndices should be all different. To be more specific, they are just a sequence from 0 to rank-1 if we flatten it. So this can just be else?

Suggested change
} else if (consumerCollapseMap.at(index) != collapseIntoIdx) {
} else {

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I added some comments and also tried to clean up the if statement logic a bit. That was meant to handle the case where index is collapsible but not collapsible into the current loops.

Copy link
Copy Markdown
Member Author

@IanWood1 IanWood1 Sep 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see the added testcase where there is an elementwise -> generic (2 parallel + 2 reduction).

// (3) `index` is collapsable but not collapsable into the other loops.
// So, split them and look for other loops to collapse `index` into.
didChange = true;
if (newIndicies.size() > 1) {
newReassociation.push_back(std::move(newIndicies));
}
newIndicies.clear();
collapseIntoIdx = consumerCollapseMap[index];
newIndicies.push_back(index);
} else {
// (4) `index` is collapsable and can be collapsed into
// `collapseIntoIdx`.
newIndicies.push_back(index);
}
}

if (newIndicies.size() > 1) {
newReassociation.push_back(newIndicies);
}
}
reassociation = std::move(newReassociation);
return didChange;
}

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
Expand Down Expand Up @@ -694,12 +797,10 @@ static bool updateConsumersFromProducers(
continue;
}

CollapseInfo &producerInfo = opMap.find(producerOp)->second;
FailureOr<CollapseInfo::CollapsableLoopsSet> producerCollapsable =
const CollapseInfo &producerInfo = opMap.at(producerOp);
CollapseInfo::CollapsableLoopsSet producerCollapsable =
producerInfo.getTransformedCollapsableLoops(mapping.value());
if (!failed(producerCollapsable)) {
producerUncollapsable.set_subtract(producerCollapsable.value());
}
producerUncollapsable.set_subtract(producerCollapsable);

didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
Expand All @@ -722,7 +823,7 @@ static bool updateProducersFromConsumers(
for (auto op : llvm::reverse(slice)) {
auto genericConsumer = cast<linalg::GenericOp>(op);
assert(opMap.contains(genericConsumer));
const CollapseInfo &consumerInfo = opMap.find(genericConsumer)->second;
const CollapseInfo &consumerInfo = opMap.at(genericConsumer);

for (auto operand : genericConsumer.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
Expand All @@ -736,26 +837,10 @@ static bool updateProducersFromConsumers(

// Get a mapping from the consumer's iteration space to the producer's.
CollapseInfo &producerInfo = opMap.find(genericProducer)->second;
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(*operand);
if (failed(consumerToProducerMap)) {
didChange |= !producerInfo.getCollapsibleLoops().empty();
producerInfo.clear();
continue;
}

// Use the map to get the consumer's collapsable loops in terms of the
// producer.
auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops(
consumerToProducerMap.value());
if (failed(consumerCollapsable)) {
producerInfo.clear();
continue;
}
// Only loops collapsable in both the consumer and producer may be
// collapsed.
didChange |=
producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
didChange |= producerInfo.updateFromConsumer(operand, consumerInfo);
}
}
return didChange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,44 @@ util.func public @propagate_uncollapsable(%arg0: tensor<2x320x128x128xf32>) -> t
// CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>)
// CHECK: flow.return %[[VAL3]]

// -----

// Test case: an elementwise producer (f16 -> f32 `arith.extf` over four
// parallel loops) feeds a consumer whose loops are two parallel followed by
// two reduction. The expectations after this function require both ops to be
// collapsed to two loops (d0 with d1, d2 with d3), so the producer result
// becomes tensor<64x163840xf32>.
util.func public @dequant_contraction(%arg0: tensor<2x32xf32>, %arg1: tensor<2x32x10x16384xf16>) -> tensor<2x32xf32> {
%0 = flow.dispatch.region -> (tensor<2x32xf32>) {
%1 = tensor.empty() : tensor<2x32xf32>
%cst = arith.constant 0.000000e+00 : f32
%2 = tensor.empty() : tensor<2x32x10x16384xf32>
// Producer: elementwise extension of %arg1 from f16 to f32; all four loops
// are parallel.
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x32x10x16384xf16>) outs(%2 : tensor<2x32x10x16384xf32>) {
^bb0(%in: f16, %out: f32):
%6 = arith.extf %in : f16 to f32
linalg.yield %6 : f32
} -> tensor<2x32x10x16384xf32>
%4 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
// Consumer: accumulates (%in - %in_0)^2 over d2 and d3 (the reduction loops)
// into a zero-filled 2x32 result indexed by d0 and d1.
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %arg0 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%4 : tensor<2x32xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%6 = arith.subf %in, %in_0 : f32
%7 = arith.mulf %6, %6 : f32
%8 = arith.addf %7, %out : f32
linalg.yield %8 : f32
} -> tensor<2x32xf32>
flow.return %5 : tensor<2x32xf32>
}
util.return %0 : tensor<2x32xf32>
}

// CHECK-LABEL: util.func public @dequant_contraction
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x32xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<2x32x10x16384xf16>
// CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-DAG: %[[COLLAPSED_ARG1:.+]] = tensor.collapse_shape %[[ARG1]]
// CHECK: flow.dispatch.region
// CHECK: %[[VAL0:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[COLLAPSED_ARG1]] : tensor<64x163840xf16>)
// CHECK-SAME: outs(%{{.*}} : tensor<64x163840xf32>)
// CHECK: %[[VAL1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[VAL0]], %[[COLLAPSED_ARG0]] : tensor<64x163840xf32>, tensor<64xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<64xf32>)
// CHECK: flow.return %[[VAL1]]