Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 42 additions & 71 deletions compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"

#define DEBUG_TYPE "tiling-config"
Expand Down Expand Up @@ -169,88 +171,57 @@ SizesAndScalableFlags TilingConfig::getVectorTileSizes() {
return std::make_pair(vectorSizes, scalableFlags);
}

/// Returns a new `LoweringConfigAttr`, with the tile sizes of vector
/// dimensions, set to `sizes`, and the corresponding scalability set to
/// `scalableFlags`.
IREE::Codegen::LoweringConfigAttr
TilingConfig::getLoweringConfigWithNewVectorSizes(
IREE::CPU::LoweringConfigAttr TilingConfig::getLoweringConfigWithNewVectorSizes(
ArrayRef<int64_t> sizes, ArrayRef<bool> scalableFlags) {
unsigned numDims = getNumDimensions();
(void)numDims;
assert(sizes.size() == numDims &&
"expected `sizes` to match number of dimensions");
assert((scalableFlags.empty() || scalableFlags.size() == numDims) &&
"expected `scalableFlags` to match "
"number of dimensions (or be empty)");

// Make a map from tiling levels to vector dims at that level.
std::array<SmallVector<unsigned, 4>, TilingLevel::MaxNumTileLevels>
tilingLevelToDimsMap;
for (unsigned dimPos = 0; dimPos < numDims; ++dimPos) {
auto tilingLevelIndex = getTilingLevelForVectorDimPosition(dimPos);
assert((tilingLevelIndex.has_value() || sizes[dimPos] == 0) &&
"attempting to set vector size for dim with underspecified tiling "
"level (zero is the only valid tile size)");
if (tilingLevelIndex.has_value())
tilingLevelToDimsMap[*tilingLevelIndex].push_back(dimPos);
}

MLIRContext *context = loweringConfig.getContext();
SmallVector<IREE::Codegen::LoweringConfigTilingLevelAttr> tilingLevels;
for (unsigned i = 0, e = getNumTilingLevels(); i < e; ++i) {
tilingLevels.push_back(cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
loweringConfig.getTilingLevelAttr(i)));
}
SmallVector<IREE::Codegen::LoweringConfigTilingLevelAttr> newTilingLevelsList(
tilingLevels.begin(), tilingLevels.end());

// For each vector tiling level:
for (auto [tilingLevelIndex, tilingLevelDims] :
llvm::enumerate(tilingLevelToDimsMap)) {
if (tilingLevelDims.empty())
MLIRContext *ctx = loweringConfig.getContext();
SmallVector<NamedAttribute> items;
for (unsigned i = 0, e = TilingLevel::MaxNumTileLevels; i < e; ++i) {
auto level = static_cast<TilingLevel>(i);
if (!isValidLevel(level)) {
continue;
auto level = tilingLevels[tilingLevelIndex];
SmallVector<int64_t> newSizes(level.getSizes());
SmallVector<bool> newScalableFlags(level.getScalableFlags());
newScalableFlags.resize(numDims);
// 1. Update all the vector sizes within that tiling level.
for (unsigned dimPos : tilingLevelDims) {
std::tie(newSizes[dimPos], newScalableFlags[dimPos]) =
getTileSizeAtIndex(sizes, scalableFlags, dimPos);
}
// 2. Then create a new tiling level attribute for that level.
auto newLevel = IREE::Codegen::LoweringConfigTilingLevelAttr::get(
context, newSizes, level.getInterchange(), newScalableFlags);
newTilingLevelsList[tilingLevelIndex] = newLevel;
}

// Create a new `lowering_config` attribute.
auto newTilingLevels = IREE::Codegen::LoweringConfigTilingLevelsAttr::get(
context, newTilingLevelsList);
return IREE::Codegen::LoweringConfigAttr::get(context, newTilingLevels);
}

/// Returns a list with the tiling levels that can be fused for this
/// configuration.
SmallVector<int64_t> TilingConfig::getFusableLevels() {
switch (getNumTilingLevels()) {
case 0:
return {};
case 1:
// Only distribution level.
return {0};
case 3:
// Only distribution level + vector common parallel levels.
return {0, 1};
case 4:
// Distribution + vector common parallel levels + vector inner parallel
// levels.
return {0, 1, 3};
case 6:
// Distribution + cache parallel levels.
return {0, 1, 3, 5};
default:
llvm_unreachable("Unexpected number of tiling levels");
switch (level) {
case TilingLevel::DistributionTiles:
case TilingLevel::CacheParallelTiles:
case TilingLevel::CacheReductionTiles: {
items.emplace_back(IREE::CPU::getTilingLevelName(level),
getTilingLevelAttr(i));
break;
}
case TilingLevel::VectorCommonParallelTiles:
case TilingLevel::VectorReductionTiles:
case TilingLevel::VectorInnerParallelTiles: {
auto attr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
loweringConfig.getTilingLevelAttr(i));
SmallVector<int64_t> newSizes(attr.getSizes());
SmallVector<bool> newScalableFlags(attr.getScalableFlags());
newScalableFlags.resize(newSizes.size(), false);
for (auto [idx, size] : llvm::enumerate(newSizes)) {
if (size == 0) {
continue;
}
newSizes[idx] = sizes[idx];
newScalableFlags[idx] = scalableFlags[idx];
}
auto newLevel = IREE::Codegen::LoweringConfigTilingLevelAttr::get(
ctx, newSizes, attr.getInterchange(), newScalableFlags);
items.emplace_back(IREE::CPU::getTilingLevelName(level), newLevel);
break;
}
case TilingLevel::MaxNumTileLevels:
case TilingLevel::InvalidLevel:
break;
};
}
return IREE::CPU::LoweringConfigAttr::get(ctx, items);
}

/// Returns the actual level in the configuration for this level of tiling.
Expand Down
27 changes: 17 additions & 10 deletions compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,26 @@ class TilingConfig {
/// Returns a new `LoweringConfigAttr`, with the tile sizes of vector
/// dimensions, set to `sizes`, and the corresponding scalability set to
/// `scalableFlags`.
IREE::Codegen::LoweringConfigAttr
IREE::CPU::LoweringConfigAttr
getLoweringConfigWithNewVectorSizes(ArrayRef<int64_t> sizes,
ArrayRef<bool> scalableFlags = {});

/// Returns a list with the tiling levels that can be fused for this
/// configuration.
SmallVector<int64_t> getFusableLevels();

// TODO(dcaballe): Revisit if these features are ever used.
SmallVector<int64_t> getTileInterchangeSizes(unsigned level) {
auto attr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
loweringConfig.getTilingLevelAttr(level));
return SmallVector<int64_t>(attr.getInterchange());
/// Returns the `level`-th valid tiling attribute. Returns an empty vector if
/// it does not exist.
IREE::Codegen::LoweringConfigTilingLevelAttr
getTilingLevelAttr(int64_t level) {
for (auto [idx, mappedLevel] :
llvm::enumerate(tilingLevelToActualLevelMap)) {
if (mappedLevel == TilingLevel::InvalidLevel) {
continue;
}
if (--level >= 0) {
continue;
}
return cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
loweringConfig.getTilingLevelAttr(mappedLevel));
}
return {};
}

private:
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TosaToArith",
"@llvm-project//mlir:TransformDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ iree_cc_library(
MLIRSCFUtils
MLIRTensorDialect
MLIRTensorTransforms
MLIRTilingInterface
MLIRTosaDialect
MLIRTosaToArith
MLIRTransformDialect
Expand Down
Loading
Loading