diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 4c900b3..0bafcba 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -664,7 +664,7 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
 
 def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {
-  let summary = "Gather elements from input based on indices over numtiple dimentions";
+  let summary = "Gather elements from input based on indices over multiple dimensions";
 
   let description = [{
     Gathers elements from a given tensor based on indices that index along multiple dimensions.
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
index c3a1329..c9e7442 100644
--- a/lib/Conversion/TcpToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -91,6 +91,17 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
   }
 };
 
+/**
+ * tcp.gather_nd is lowered to linalg.generic, which lets us define every
+ * element of the result tensor with a programmatic expression. The last
+ * dimension of the indices tensor is used to index into the input tensor.
+ *
+ * For example, if we have an indices tensor of shape 9x4x3x2 and an input
+ * tensor of shape 5x6x7x8, the resulting tensor has shape 9x4x3x7x8. The
+ * first three dimensions of the result correspond to the leading dimensions
+ * of the indices tensor, and the last dimension of the indices tensor (the
+ * dimension of size 2) is used to index into the input tensor.
+ */
 class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -105,12 +116,12 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
     auto inputTensor = adaptor.getInput();
     auto indicesTensor = adaptor.getIndices();
-    auto indiciesType = cast<RankedTensorType>(indicesTensor.getType());
+    auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
     auto inputType = cast<RankedTensorType>(inputTensor.getType());
 
-    int numGatherAxes = indiciesType.getShape()[indiciesType.getRank() - 1];
+    int numGatherAxes = indicesType.getShape().back();
 
     SmallVector<Value> resultDimSizes;
-    for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+    for (int i = 0; i < indicesType.getRank() - 1; i++) {
       resultDimSizes.push_back(
           rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
     }
@@ -127,7 +138,7 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
 
     auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       SmallVector<Value> valueIndices, gatherIndices;
-      for (int i = 0; i < indiciesType.getRank() - 1; i++) {
+      for (int i = 0; i < indicesType.getRank() - 1; i++) {
         auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                              b.getI64IntegerAttr(i));
         gatherIndices.push_back(idx);
@@ -136,14 +147,14 @@
         SmallVector<Value> gi = gatherIndices;
         auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
         gi.push_back(gidx);
-        assert(gi.size() == indiciesType.getRank());
+        assert(gi.size() == indicesType.getRank());
         auto idxExtract = b.create<tensor::ExtractOp>(
-            loc, indiciesType.getElementType(), indicesTensor, gi);
+            loc, indicesType.getElementType(), indicesTensor, gi);
         auto idxCast =
             b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
         valueIndices.push_back(idxCast);
       }
-      for (int i = indiciesType.getRank() - 1; i < resultTensorType.getRank();
+      for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
            i++) {
         auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                              b.getI64IntegerAttr(i));
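
As a reference for the gather_nd lowering above, here is a small standalone C++ sketch of the indexing rule described in the new comment, using the same example shapes (a 9x4x3x2 indices tensor over a 5x6x7x8 input producing a 9x4x3x7x8 result). This is an illustration only and not part of the patch; the function name and the row-major flattened layout are assumptions made for the example.

// Reference sketch (illustration only, not repo code):
// out[a][b][c][x][y] = input[indices[a][b][c][0]][indices[a][b][c][1]][x][y]
#include <cstdint>
#include <vector>

std::vector<float> gatherNDReference(const std::vector<float> &input,       // 5x6x7x8, row-major
                                     const std::vector<int64_t> &indices) { // 9x4x3x2, row-major
  std::vector<float> out(9 * 4 * 3 * 7 * 8);
  for (int64_t a = 0; a < 9; ++a)
    for (int64_t b = 0; b < 4; ++b)
      for (int64_t c = 0; c < 3; ++c) {
        // The last dimension of `indices` (size 2) selects the two leading
        // axes of `input`; the trailing 7x8 slice is copied through unchanged.
        int64_t base = ((a * 4 + b) * 3 + c) * 2;
        int64_t i0 = indices[base + 0];
        int64_t i1 = indices[base + 1];
        for (int64_t x = 0; x < 7; ++x)
          for (int64_t y = 0; y < 8; ++y) {
            int64_t src = ((i0 * 6 + i1) * 7 + x) * 8 + y;
            int64_t dst = (((a * 4 + b) * 3 + c) * 7 + x) * 8 + y;
            out[dst] = input[src];
          }
      }
  return out;
}
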
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index da4c8ef..65fe948 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -180,12 +180,12 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
   }
 
   // figure out what the shape should be for each dim
-  struct ShapeInfo {
+  struct DimInfo {
     Value value;
     bool found = false;
-    int64_t static_value = 1;
+    int64_t staticValue = 1;
   };
-  SmallVector<ShapeInfo> shapes(maxRank);
+  SmallVector<DimInfo> resultShape(maxRank);
 
   for (auto v : ret) {
     auto t = cast<RankedTensorType>(v.getType());
@@ -194,29 +194,29 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
       if (shape[i] != 1) {
         // meaning that this is not something that is already 1, and therefore
         // would get broadcast
-        if (shapes[i].found) {
+        if (resultShape[i].found) {
           // then there are multiple inputs which have non-1 values for this
           // axis we should check that the size is the same. If there are
           // different shapes then this would result in an error when
           // broadcasting
           if (shape[i] != ShapedType::kDynamic &&
-              shapes[i].static_value != ShapedType::kDynamic &&
-              shapes[i].static_value != shape[i]) {
+              resultShape[i].staticValue != ShapedType::kDynamic &&
+              resultShape[i].staticValue != shape[i]) {
            // the broadcast failed as there are two different shapes for this
            llvm::errs() << "failed with broadcasting, have two different shapes "
-                        << shape[i] << " " << shapes[i].static_value << "\n";
+                        << shape[i] << " " << resultShape[i].staticValue << "\n";
            return {};
          }
        } else {
-          shapes[i].found = true;
+          resultShape[i].found = true;
          if (shape[i] == ShapedType::kDynamic) {
-            shapes[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
-            shapes[i].static_value = ShapedType::kDynamic;
+            resultShape[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
+            resultShape[i].staticValue = ShapedType::kDynamic;
          } else {
-            shapes[i].value = rewriter.create<arith::ConstantOp>(
+            resultShape[i].value = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getIndexAttr(shape[i]));
-            shapes[i].static_value = shape[i];
+            resultShape[i].staticValue = shape[i];
          }
        }
      }
@@ -231,11 +231,11 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
     SmallVector<Value> sizes;
     SmallVector<int64_t> staticShape;
     for (int64_t j = 0; j < maxRank; j++) {
-      if (t.getShape()[j] == 1 && shapes[j].found) {
+      if (t.getShape()[j] == 1 && resultShape[j].found) {
        axes.push_back(j);
-        sizes.push_back(shapes[j].value);
+        sizes.push_back(resultShape[j].value);
      }
-      staticShape.push_back(shapes[j].static_value);
+      staticShape.push_back(resultShape[j].staticValue);
    }
    if (!axes.empty()) {
      // there is something to broadcast here, so add the op
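
The DimInfo bookkeeping above boils down to a shape-merging rule, sketched below in standalone C++. This is illustrative only: the helper name and the plain-integer representation are assumptions, and it assumes every operand shape has already been padded out to maxRank, as the operands are at this point in the function. The real code additionally materialises each dimension size as a Value so it can be handed to the broadcast op created when axes is non-empty.

// Illustrative sketch (assumed helper, not repo code): the static part of the
// merge rule. Each result dim starts at 1; the first non-1 size seen for a dim
// wins, and two different non-1 static sizes are a broadcasting error.
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// stand-in for ShapedType::kDynamic
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

std::optional<std::vector<int64_t>>
mergeBroadcastShapes(const std::vector<std::vector<int64_t>> &shapes,
                     int64_t maxRank) {
  std::vector<int64_t> result(maxRank, 1);
  for (const auto &shape : shapes) {
    for (int64_t i = 0; i < maxRank; ++i) {
      if (shape[i] == 1)
        continue; // size-1 dims are the ones that get broadcast
      if (result[i] == 1) {
        result[i] = shape[i]; // first non-1 size (static or dynamic) wins
      } else if (shape[i] != kDynamic && result[i] != kDynamic &&
                 result[i] != shape[i]) {
        return std::nullopt; // two different non-1 static sizes: error
      }
    }
  }
  return result;
}

// e.g. merging {1, 4, kDynamic} with {3, 4, 1} yields {3, 4, kDynamic}.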