address comments
matthewfl committed Oct 21, 2024
1 parent 46d3066 commit af01815
Showing 3 changed files with 34 additions and 23 deletions.
2 changes: 1 addition & 1 deletion include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -664,7 +664,7 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]

def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {

- let summary = "Gather elements from input based on indices over numtiple dimentions";
+ let summary = "Gather elements from input based on indices over multiple dimensions";

let description = [{
Gathers elements from a given tensor based on indices that index along multiple dimensions.
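To make the semantics above concrete, here is a minimal standalone C++ model of gather_nd over flattened row-major buffers. The helper name and the flat-buffer representation are illustrative assumptions, not code from this repository:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference model of the gather_nd behavior described above, over flattened
// row-major buffers. With indicesShape = [..., k], each k-tuple in `indices`
// selects a slice of `input` along its first k axes, so the result shape is
// indicesShape[:-1] ++ inputShape[k:].
std::vector<float> gatherND(const std::vector<float> &input,
                            const std::vector<int64_t> &inputShape,
                            const std::vector<int64_t> &indices,
                            const std::vector<int64_t> &indicesShape) {
  int64_t k = indicesShape.back();
  assert(k <= (int64_t)inputShape.size());

  // Number of gathered slices and the number of elements in each slice.
  int64_t numSlices = 1;
  for (size_t i = 0; i + 1 < indicesShape.size(); i++)
    numSlices *= indicesShape[i];
  int64_t sliceSize = 1;
  for (size_t i = k; i < inputShape.size(); i++)
    sliceSize *= inputShape[i];

  // Row-major strides over the input's first k axes, measured in slices.
  std::vector<int64_t> strides(k, 1);
  for (int64_t i = k - 2; i >= 0; i--)
    strides[i] = strides[i + 1] * inputShape[i + 1];

  std::vector<float> result(numSlices * sliceSize);
  for (int64_t s = 0; s < numSlices; s++) {
    int64_t offset = 0; // which slice of the input this k-tuple selects
    for (int64_t j = 0; j < k; j++)
      offset += indices[s * k + j] * strides[j];
    for (int64_t e = 0; e < sliceSize; e++)
      result[s * sliceSize + e] = input[offset * sliceSize + e];
  }
  return result;
}
```

For example, with an input of shape 5x6x7x8 and indices of shape 9x4x3x2 (k = 2), each index 2-tuple selects one 7x8 slice, producing a 9x4x3x7x8 result.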
25 changes: 18 additions & 7 deletions lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -91,6 +91,17 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
}
};

+/**
+ * tcp.gather_nd is lowered to linalg.generic, which allows us to define every
+ * element in the result tensor using a programmatic expression. The last
+ * dimension of the indices tensor is used to index into the input tensor.
+ *
+ * For example, if we have an indices tensor of shape 9x4x3x2 and an input
+ * tensor of shape 5x6x7x8, then the resulting tensor will be of shape
+ * 9x4x3x7x8, where the first three dimensions of the result are used to
+ * index into the indices tensor, and the last dimension of the indices
+ * tensor (the dimension of size 2) is used to index into the first two
+ * dimensions of the input tensor.
+ */
class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -105,12 +116,12 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {

auto inputTensor = adaptor.getInput();
auto indicesTensor = adaptor.getIndices();
- auto indiciesType = cast<RankedTensorType>(indicesTensor.getType());
+ auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
auto inputType = cast<RankedTensorType>(inputTensor.getType());
- int numGatherAxes = indiciesType.getShape()[indiciesType.getRank() - 1];
+ int numGatherAxes = indicesType.getShape().back();

SmallVector<Value> resultDimSizes;
- for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+ for (int i = 0; i < indicesType.getRank() - 1; i++) {
resultDimSizes.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
}
@@ -127,7 +138,7 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {

auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
SmallVector<Value> valueIndices, gatherIndices;
- for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+ for (int i = 0; i < indicesType.getRank() - 1; i++) {
auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
b.getI64IntegerAttr(i));
gatherIndices.push_back(idx);
@@ -136,14 +147,14 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
SmallVector<Value> gi = gatherIndices;
auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
gi.push_back(gidx);
- assert(gi.size() == indiciesType.getRank());
+ assert(gi.size() == indicesType.getRank());
auto idxExtract = b.create<tensor::ExtractOp>(
- loc, indiciesType.getElementType(), indicesTensor, gi);
+ loc, indicesType.getElementType(), indicesTensor, gi);
auto idxCast =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
valueIndices.push_back(idxCast);
}
- for (int i = indiciesType.getRank() - 1; i < resultTensorType.getRank();
+ for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
i++) {
auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
b.getI64IntegerAttr(i));
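The shape arithmetic in the doc comment above can be sanity-checked with a few lines of C++; the helper below is a hypothetical illustration, not part of the conversion code:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Shape rule from the comment above: drop the last dimension of the indices
// shape, then append the input dimensions that were not consumed by the
// gather (everything past the first k axes, where k = indicesShape.back()).
std::vector<int64_t>
gatherNDResultShape(const std::vector<int64_t> &inputShape,
                    const std::vector<int64_t> &indicesShape) {
  int64_t k = indicesShape.back();
  std::vector<int64_t> result(indicesShape.begin(), indicesShape.end() - 1);
  result.insert(result.end(), inputShape.begin() + k, inputShape.end());
  return result;
}

int main() {
  // The example from the lowering comment: indices 9x4x3x2 over input
  // 5x6x7x8 yields 9x4x3x7x8.
  for (int64_t d : gatherNDResultShape({5, 6, 7, 8}, {9, 4, 3, 2}))
    std::cout << d << " "; // prints: 9 4 3 7 8
  std::cout << "\n";
}
```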
30 changes: 15 additions & 15 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -180,12 +180,12 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
}

// figure out what the shape should be for each dim
- struct ShapeInfo {
+ struct DimInfo {
Value value;
bool found = false;
- int64_t static_value = 1;
+ int64_t staticValue = 1;
};
- SmallVector<ShapeInfo> shapes(maxRank);
+ SmallVector<DimInfo> resultShape(maxRank);

for (auto v : ret) {
auto t = cast<RankedTensorType>(v.getType());
@@ -194,29 +194,29 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
if (shape[i] != 1) {
// meaning that this is not something that is already 1, and therefore
// would get broadcast
- if (shapes[i].found) {
+ if (resultShape[i].found) {
// then there are multiple inputs which have non-1 values for this
// axis we should check that the size is the same. If there are
// different shapes then this would result in an error when
// broadcasting
if (shape[i] != ShapedType::kDynamic &&
- shapes[i].static_value != ShapedType::kDynamic &&
- shapes[i].static_value != shape[i]) {
+ resultShape[i].staticValue != ShapedType::kDynamic &&
+ resultShape[i].staticValue != shape[i]) {
// the broadcast failed as there are two different shapes for this
llvm::errs()
<< "failed with broadcasting, have two different shapes "
- << shape[i] << " " << shapes[i].static_value << "\n";
+ << shape[i] << " " << resultShape[i].staticValue << "\n";
return {};
}
} else {
- shapes[i].found = true;
+ resultShape[i].found = true;
if (shape[i] == ShapedType::kDynamic) {
- shapes[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
- shapes[i].static_value = ShapedType::kDynamic;
+ resultShape[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
+ resultShape[i].staticValue = ShapedType::kDynamic;
} else {
- shapes[i].value = rewriter.create<arith::ConstantOp>(
+ resultShape[i].value = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(shape[i]));
- shapes[i].static_value = shape[i];
+ resultShape[i].staticValue = shape[i];
}
}
}
@@ -231,11 +231,11 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> sizes;
SmallVector<int64_t> staticShape;
for (int64_t j = 0; j < maxRank; j++) {
- if (t.getShape()[j] == 1 && shapes[j].found) {
+ if (t.getShape()[j] == 1 && resultShape[j].found) {
axes.push_back(j);
- sizes.push_back(shapes[j].value);
+ sizes.push_back(resultShape[j].value);
}
- staticShape.push_back(shapes[j].static_value);
+ staticShape.push_back(resultShape[j].staticValue);
}
if (!axes.empty()) {
// there is something to broadcast here, so add the op
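The per-dimension rule implemented by broadcastManyToMatchShape above can be summarized in a small standalone sketch. The function name and the kDynamic sentinel are assumptions for illustration; the real code works on MLIR types and emits tensor.dim / arith.constant ops rather than plain integers:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic; the sentinel value differs across MLIR
// versions, so this is illustrative only.
constexpr int64_t kDynamic = -1;

// Per-dimension broadcast resolution mirroring the logic above, assuming all
// shapes have already been expanded to the same rank. A size of 1 broadcasts,
// two non-1 static sizes must agree, and a dynamic size is accepted as-is.
// Returns std::nullopt when two different static sizes collide.
std::optional<std::vector<int64_t>>
resolveBroadcastShape(const std::vector<std::vector<int64_t>> &shapes,
                      int64_t maxRank) {
  std::vector<int64_t> result(maxRank, 1);
  for (const auto &shape : shapes) {
    for (int64_t i = 0; i < maxRank; i++) {
      int64_t s = shape[i];
      if (s == 1)
        continue; // a 1-sized axis is the one being broadcast
      if (result[i] == 1) {
        result[i] = s; // first non-1 size seen for this axis wins
      } else if (s != kDynamic && result[i] != kDynamic && result[i] != s) {
        return std::nullopt; // two different static sizes: cannot broadcast
      }
    }
  }
  return result;
}
```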
