Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter,
}

// Skip the tiling if the size is already 1.
RankedTensorType srcType = packOp.getSourceType();
ShapedType srcType = packOp.getSourceType();
for (auto [idx, val] : llvm::enumerate(tileSizes)) {
if (val && srcType.getDimSize(idx) == 1) {
return;
Expand Down Expand Up @@ -95,7 +95,7 @@ static void tileNonPackedDimsFor5DPUnpackOps(RewriterBase &rewriter,
}

// Skip the tiling if the size is already 1.
RankedTensorType destType = unpackOp.getDestType();
ShapedType destType = unpackOp.getDestType();
for (auto [idx, val] : llvm::enumerate(tileSizes)) {
if (val && destType.getDimSize(idx) == 1) {
return;
Expand Down Expand Up @@ -304,13 +304,14 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern<linalg::PackOp> {
}

Location loc = packOp.getLoc();
auto reducedSrcType =
RankedTensorType::Builder(packOp.getSourceType()).dropDim(srcPos);
RankedTensorType sourceType =
cast<RankedTensorType>(packOp.getSourceType());
auto reducedSrcType = RankedTensorType::Builder(sourceType).dropDim(srcPos);
auto reducedSrc = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, packOp.getSource(), reducedSrcType);

auto reducedDestType =
RankedTensorType::Builder(packOp.getDestType()).dropDim(destPos);
RankedTensorType destType = cast<RankedTensorType>(packOp.getDestType());
auto reducedDestType = RankedTensorType::Builder(destType).dropDim(destPos);
auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, packOp.getDest(), reducedDestType);

Expand Down Expand Up @@ -385,13 +386,14 @@ struct Convert5DUnPackto4DUnPackPattern
}

Location loc = unpackOp.getLoc();
auto reducedSrcType =
RankedTensorType::Builder(unpackOp.getSourceType()).dropDim(srcPos);
RankedTensorType sourceType =
cast<RankedTensorType>(unpackOp.getSourceType());
auto reducedSrcType = RankedTensorType::Builder(sourceType).dropDim(srcPos);
auto reducedSrc = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, unpackOp.getSource(), reducedSrcType);

auto reducedDestType =
RankedTensorType::Builder(unpackOp.getDestType()).dropDim(destPos);
RankedTensorType destType = cast<RankedTensorType>(unpackOp.getDestType());
auto reducedDestType = RankedTensorType::Builder(destType).dropDim(destPos);
auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, unpackOp.getDest(), reducedDestType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ struct EncodingLayoutMaterializerAttrExternalModelBase
}
}
auto packedType =
cast<RankedTensorType>(linalg::PackOp::inferPackedType(
cast<RankedTensorType>(linalg::PackOp::inferPackedTensorType(
type, innerTileSizesVector, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ createTransposeAsTensorPack(
SmallVector<AffineExpr> mapResults(inputMap.getResults());
AffineMap transposedMap;

Value packedOperand = packedInput;
Value packedOperand;
if (!packedInput.getResults().empty()) {
packedOperand = packedInput.getResult();
}
// Collapse the unit dims created by linalg.pack if the pack is just a
// transpose.
if (tilingFactor <= 0) {
Expand Down
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule llvm-project updated 662 files
Loading