[mlir][shard,mpi] Lowering shard.allgather to MPI #177202
Conversation
@llvm/pr-subscribers-mlir-linalg

Author: Frank Schlimbach (fschlimb)

Changes

Patch is 21.64 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/177202.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 5e68f75ee08bf..6ef7c72d305ee 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
```
}];
let arguments = !con(commonArgs, (ins
- AnyNon0RankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
IndexAttr:$gather_axis
));
let results = (outs
- AnyNon0RankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index 57d65e687ea35..1ddd1985389bc 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder);
// Get process linear index along the given grid axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ArrayRef<GridAxis> gridAxes = {});
// Get process linear index from a multi-index along the given grid axes .
TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder);
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes = {});
} // namespace shard
} // namespace mlir
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index b0831dc05abb1..1865914de9d84 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
@@ -507,103 +508,147 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
}
}
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename CommOp>
+struct CommOpPattern : public OpConversionPattern<CommOp> {
+ using OpConversionPattern<CommOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SymbolTableCollection symbolTableCollection;
- auto grid = adaptor.getGrid();
- mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
- if (!gridOp)
- return op->emitError() << "No grid found for AllReduceOp";
- if (ShapedType::isDynamicShape(gridOp.getShape()))
- return op->emitError()
- << "Dynamic grid shape not supported in AllReduceOp";
-
- ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
- Value input = adaptor.getInput();
- auto inputShape = cast<ShapedType>(input.getType()).getShape();
+ MemRefType getMemrefType(ShapedType tensorType) const {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ }
+ Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+ auto itype = input.getType();
// If the source is a memref, cast it to a tensor.
- if (isa<RankedTensorType>(input.getType())) {
- auto memrefType = MemRefType::get(
- inputShape, cast<ShapedType>(input.getType()).getElementType());
+ if (isa<RankedTensorType>(itype)) {
+ auto memrefType = getMemrefType(cast<ShapedType>(itype));
input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+ } else {
+ assert(isa<MemRefType>(itype) &&
+ "expected input to be of MemRefType or TensorType");
}
- MemRefType inType = cast<MemRefType>(input.getType());
-
- // Get the actual shape to allocate the buffer.
- SmallVector<OpFoldResult> shape(inType.getRank());
- for (auto i = 0; i < inType.getRank(); ++i) {
- auto s = inputShape[i];
- if (ShapedType::isDynamic(s))
- shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
- else
- shape[i] = iBuilder.getIndexAttr(s);
- }
+ return input;
+ }
- // Allocate buffer and copy input to buffer.
- Value buffer = memref::AllocOp::create(
- iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
- linalg::CopyOp::create(iBuilder, input, buffer);
+ FailureOr<GridOp> checkGrid(CommOp op,
+ SymbolTableCollection &symbolTableCollection,
+ bool allowDynamic = false) const {
+ GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "Missing grid symbol.";
+ if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
+ return op->emitError() << "Dynamic grid shape not supported.";
+ return gridOp;
+ }
- // Get an MPI_Comm_split for the AllReduce operation.
+ Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
+ ImplicitLocOpBuilder &iBuilder) const {
+ // Get an MPI_Comm_split for a given grid and axes.
// The color is the linear index of the process in the grid along the
- // non-reduced axes. The key is the linear index of the process in the grid
- // along the reduced axes.
- SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
- iBuilder.getIndexType());
- SmallVector<Value> myMultiIndex =
- ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
- .getResult();
- Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
- SmallVector<Value> multiKey(myMultiIndex.size(), zero);
-
- auto redAxes = adaptor.getGridAxes();
- for (auto axis : redAxes) {
- multiKey[axis] = myMultiIndex[axis];
- myMultiIndex[axis] = zero;
+ // non-'grid-axes'. The key is the linear index of the process in the grid
+ // along the grid-axes.
+ SmallVector<GridAxis> otherAxes;
+ for (GridAxis i = 0; i < static_cast<GridAxis>(gridOp.getShape().size());
+ ++i) {
+ if (!llvm::is_contained(gridAxes, i))
+ otherAxes.emplace_back(i);
}
+ SmallVector<Type> indexResultTypes(otherAxes.size(),
+ iBuilder.getIndexType());
+
Value color =
- createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
+ createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
+
+ Value key =
+ createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
- auto commType = mpi::CommType::get(op->getContext());
+ auto commType = mpi::CommType::get(gridOp->getContext());
Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
- auto comm =
- mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
- .getNewcomm();
-
- Value buffer1d = buffer;
- // Collapse shape to 1d if needed
- if (inType.getRank() > 1) {
- ReassociationIndices reassociation(inType.getRank());
- std::iota(reassociation.begin(), reassociation.end(), 0);
- buffer1d = memref::CollapseShapeOp::create(
- iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
- }
+ return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
+ .getNewcomm();
+ }
+};
+struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
+ using CommOpPattern::CommOpPattern;
+
+ LogicalResult
+ matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+ if (failed(gridOp))
+ return failure();
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = getAsMemref(adaptor.getInput(), iBuilder);
+ MemRefType inType = cast<MemRefType>(input.getType());
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+ MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+
+ // Allocate buffer and copy input to buffer.
+ Value buffer = memref::AllocOp::create(iBuilder, outType);
+ linalg::CopyOp::create(iBuilder, input, buffer);
+ // Get the right communicator
+ Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
// Create the MPI AllReduce operation.
- mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+ mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
getMPIReductionOp(adaptor.getReductionAttr()),
comm);
- // If the destination is a memref, cast it to a tensor
+ // If the destination is a tensor, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
true);
-
rewriter.replaceOp(op, buffer);
return success();
}
};
+struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
+ using CommOpPattern::CommOpPattern;
+
+ LogicalResult
+ matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+ if (failed(gridOp))
+ return failure();
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = getAsMemref(adaptor.getInput(), iBuilder);
+ MemRefType inType = cast<MemRefType>(input.getType());
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+ MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+
+ // Get the right communicator
+ Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
+ // Allocate output buffer
+ Value output = memref::AllocOp::create(iBuilder, outType);
+ // Create the MPI AllGather operation.
+ mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+
+ // If the destination is a tensor, cast it to a tensor
+ if (isa<RankedTensorType>(op.getType()))
+ output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
+ true);
+ rewriter.replaceOp(op, output);
+ return success();
+ }
+};
+
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
@@ -895,8 +940,8 @@ struct ConvertShardToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
- ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
- ctxt);
+ ConvertAllGatherOp, ConvertAllReduceOp,
+ ConvertProcessLinearIndexOp>(typeConverter, ctxt);
SymbolTableCollection stc;
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
populateAllSliceOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 0ae2a9cc0318c..d0165595f9fb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand(
ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
ImplicitLocOpBuilder &builder) {
Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
- gridOp.getSymName(), reductionGridAxes, builder);
+ builder, gridOp.getSymName(), reductionGridAxes);
Value zero = arith::ConstantIndexOp::create(builder, 0);
Value isLeadProcess = arith::CmpIOp::create(
builder, builder.getI1Type(), arith::CmpIPredicate::eq,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index b433b8b0be7b2..835bc443d4b2a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
}
TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder) {
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes) {
Operation::result_range processGroupShape =
GridShapeOp::create(builder, grid, gridAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
return cast<TypedValue<IndexType>>(res);
}
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder) {
+TypedValue<IndexType> createProcessLinearIndex(ImplicitLocOpBuilder &builder,
+ StringRef grid,
+ ArrayRef<GridAxis> gridAxes) {
return createProcessLinearIndex(
- grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
- gridAxes, builder);
+ builder, grid,
+ ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+ gridAxes);
}
} // namespace mlir::shard
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index a0b6bfaf6fd3d..d4741102e4d3f 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -102,15 +102,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
- // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
// CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
// CHECK: return [[v2]] : tensor<3x4xf32>
@@ -121,14 +120,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_memref(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
- // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
// CHECK: return [[valloc]] : memref<3x4xf32>
return %0 : memref<3x4xf32>
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_new_type(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
- // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
// CHECK: return [[valloc]] : memref<3x4xf64>
return %0 : memref<3x4xf64>
}
+
+ // CHECK-LABEL: func @allgather_tensor
+ func.func @allgather_tensor(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+ // CHECK-SAME: -> tensor<3x20xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+ // CHECK: [[vc2_i32:%.*]]...
[truncated]
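Since the test for the new lowering is cut off above, here is a rough end-to-end sketch of what `shard.all_gather` is expected to lower to, reconstructed from the CHECK lines visible in this PR; the `grid_axes` value, the concrete color/key constants, and the SSA names are assumptions for illustration, not copied from the test file.

```mlir
// Assumed source op: gathering five 3x4 shards along tensor axis 1.
%res = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x20xf32>

// Roughly what the new ConvertAllGatherOp pattern emits (SSA names and the
// concrete color/key constants are illustrative):
%in    = bufferization.to_buffer %arg0 : tensor<3x4xf32> to memref<3x4xf32>
%color = arith.constant 1 : i32  // process linear index along the non-gather grid axes
%key   = arith.constant 2 : i32  // process linear index along the gather grid axes
%world = mpi.comm_world : !mpi.comm
%comm  = mpi.comm_split(%world, %color, %key) : !mpi.comm
%out   = memref.alloc() : memref<3x20xf32>
mpi.allgather(%in, %out, %comm) : memref<3x4xf32>, memref<3x20xf32>
%t     = bufferization.to_tensor %out restrict : memref<3x20xf32> to tensor<3x20xf32>
```

The color/key choice mirrors `getComm` in the patch: the color groups processes that share the same position along the non-gather axes, and the key orders the processes of each group along the gather axes.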
tkarna left a comment:
Looks good at a high level. Going forward we should ensure that bufferization works as expected.
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
Moving forward we should have a mechanism to deallocate the output buffer if safe to do so.
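A minimal sketch of what such a deallocation could look like once ownership is known, assuming the result stays a memref and the gathered buffer is not returned to the caller:

```mlir
%out = memref.alloc() : memref<3x20xf32>
mpi.allgather(%in, %out, %comm) : memref<3x4xf32>, memref<3x20xf32>
// ... last use of %out ...
memref.dealloc %out : memref<3x20xf32>  // only legal if no other user of %out remains
```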
%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
Could add a read_only attribute if the access pattern is always the same.
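A sketch of that suggestion, assuming the `read_only` unit attribute on `bufferization.to_buffer` is the mechanism meant here (the all-gather only reads the input buffer):

```mlir
// Mark the bufferized input as read-only, since mpi.allgather never writes to it.
%in = bufferization.to_buffer %arg0 read_only : tensor<3x4xf32> to memref<3x4xf32>
```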
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32>
This might need a `writable` attribute to avoid copies later on.
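For example, assuming the freshly allocated result buffer is exclusively owned by the lowering, so marking it `writable` (in addition to `restrict`) would be safe:

```mlir
// Tell bufferization the resulting tensor may be written in place, avoiding a later copy.
%t = bufferization.to_tensor %out restrict writable : memref<3x20xf32> to tensor<3x20xf32>
```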
- lowering `shard.allgather` to `mpi.allgather`
- fixing lowering of `shard.allreduce`
- minor refactoring
Reverts carried forward:
* Local revert of llvm/llvm-project#169614 due to #22649

Other changes:
* cast shard.all_gather result type to `ShapedType` due to its result changing types to be `AnyTypeOf<[AnyMemRef, AnyRankedTensor]>` in llvm/llvm-project#177202