4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
IndexAttr:$gather_axis
));
let results = (outs
AnyNon0RankedTensor:$result
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder);

// Get the process linear index along the given grid axes.
TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder);
TypedValue<IndexType>
createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
ArrayRef<GridAxis> gridAxes = {});
// Get the process linear index from a multi-index along the given grid axes.
TypedValue<IndexType>
createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder);
createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
ValueRange processInGroupMultiIndex,
ArrayRef<GridAxis> gridAxes = {});

} // namespace shard
} // namespace mlir
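The reordered signatures above put the builder first and default `gridAxes` to empty. A minimal sketch of the two call styles under the new API; the builder `b` and the grid symbol name "grid2d" are assumptions, and an empty `gridAxes` is taken here to mean "all axes of the grid":

// Hypothetical call sites for the updated createProcessLinearIndex.
// Assumes an ImplicitLocOpBuilder `b` set to a valid insertion point and
// a shard.grid symbol named "grid2d" in scope.
TypedValue<IndexType> full =
    createProcessLinearIndex(b, "grid2d");      // gridAxes defaults to {}
TypedValue<IndexType> alongAxis0 =
    createProcessLinearIndex(b, "grid2d", {0}); // only grid axis 0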
190 changes: 120 additions & 70 deletions mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
@@ -507,103 +508,152 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
}
}

struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
using OpConversionPattern::OpConversionPattern;
template <typename CommOp>
struct CommOpPattern : public OpConversionPattern<CommOp> {
using OpConversionPattern<CommOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
auto grid = adaptor.getGrid();
mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
if (!gridOp)
return op->emitError() << "No grid found for AllReduceOp";
if (ShapedType::isDynamicShape(gridOp.getShape()))
return op->emitError()
<< "Dynamic grid shape not supported in AllReduceOp";

ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = adaptor.getInput();
auto inputShape = cast<ShapedType>(input.getType()).getShape();
MemRefType getMemrefType(ShapedType tensorType) const {
return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
}

Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
auto itype = input.getType();
// If the source is a tensor, convert it to a memref.
if (isa<RankedTensorType>(input.getType())) {
auto memrefType = MemRefType::get(
inputShape, cast<ShapedType>(input.getType()).getElementType());
if (isa<RankedTensorType>(itype)) {
auto memrefType = getMemrefType(cast<ShapedType>(itype));
input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
} else {
assert(isa<MemRefType>(itype) &&
"expected input to be of MemRefType or TensorType");
}
MemRefType inType = cast<MemRefType>(input.getType());
return input;
}

// Get the actual shape to allocate the buffer.
SmallVector<OpFoldResult> shape(inType.getRank());
for (auto i = 0; i < inType.getRank(); ++i) {
auto s = inputShape[i];
if (ShapedType::isDynamic(s))
shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
else
shape[i] = iBuilder.getIndexAttr(s);
}
FailureOr<GridOp> checkGrid(CommOp op,
SymbolTableCollection &symbolTableCollection,
bool allowDynamic = false) const {
GridOp gridOp = getGrid(op, symbolTableCollection);
if (!gridOp)
return op->emitError() << "Missing grid symbol.";
if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
return op->emitError() << "Dynamic grid shape not supported.";
return gridOp;
}

// Allocate buffer and copy input to buffer.
Value buffer = memref::AllocOp::create(
iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for a given grid and axes.
// The color is the linear index of the process along the grid axes not in
// 'gridAxes'; the key is its linear index along 'gridAxes'.
Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
ImplicitLocOpBuilder &iBuilder) const {
size_t gridDims = gridOp.getShape().size();
auto commType = mpi::CommType::get(gridOp->getContext());
Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);

// Get an MPI_Comm_split for the AllReduce operation.
// The color is the linear index of the process in the grid along the
// non-reduced axes. The key is the linear index of the process in the grid
// along the reduced axes.
SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
.getResult();
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
if (gridAxes.empty() || gridAxes.size() >= gridDims) {
return commWorld;
}

auto redAxes = adaptor.getGridAxes();
for (auto axis : redAxes) {
multiKey[axis] = myMultiIndex[axis];
myMultiIndex[axis] = zero;
SmallVector<GridAxis> otherAxes;
for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
if (!llvm::is_contained(gridAxes, i))
otherAxes.emplace_back(i);
}

SmallVector<Type> indexResultTypes(otherAxes.size(),
iBuilder.getIndexType());

Value color =
createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);

Value key =
createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

// Finally split the communicator
auto commType = mpi::CommType::get(op->getContext());
Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
auto comm =
mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
.getNewcomm();

Value buffer1d = buffer;
// Collapse shape to 1d if needed
if (inType.getRank() > 1) {
ReassociationIndices reassociation(inType.getRank());
std::iota(reassociation.begin(), reassociation.end(), 0);
buffer1d = memref::CollapseShapeOp::create(
iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
}
return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
.getNewcomm();
}
};

struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
using CommOpPattern::CommOpPattern;

LogicalResult
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
if (failed(gridOp))
return failure();
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = getAsMemref(adaptor.getInput(), iBuilder);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");

// Allocate buffer and copy input to buffer.
Value buffer = memref::AllocOp::create(iBuilder, outType);
linalg::CopyOp::create(iBuilder, input, buffer);
// Get the right communicator
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
// Create the MPI AllReduce operation.
mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
getMPIReductionOp(adaptor.getReductionAttr()),
comm);

// If the destination is a memref, cast it to a tensor
// If the destination is a tensor, convert the buffer back to a tensor.
if (isa<RankedTensorType>(op.getType()))
buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
true);

rewriter.replaceOp(op, buffer);
return success();
}
};

struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;

LogicalResult
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
if (failed(gridOp))
return failure();
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = getAsMemref(adaptor.getInput(), iBuilder);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");

// Get the right communicator
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
// Allocate output buffer
Value output = memref::AllocOp::create(iBuilder, outType);
// Create the MPI AllGather operation.
mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);

// If the destination is a tensor, convert the output back to a tensor.
if (isa<RankedTensorType>(op.getType()))
output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
true);
rewriter.replaceOp(op, output);
return success();
}
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;

@@ -895,8 +945,8 @@ struct ConvertShardToMPIPass

patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
ctxt);
ConvertAllGatherOp, ConvertAllReduceOp,
ConvertProcessLinearIndexOp>(typeConverter, ctxt);
SymbolTableCollection stc;
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
populateAllSliceOpLoweringPatterns(patterns, stc);
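For intuition, the color/key scheme in `getComm` mirrors the standard MPI_Comm_split idiom: processes that should gather or reduce together pass the same color, and the key orders ranks inside each new communicator. A rough plain-MPI analogue for an R x C process grid split along axis 1 (illustrative only; this is not the code the lowering emits):

// Illustrative analogue of getComm's color/key computation for a 2D grid,
// splitting along grid axis 1. Hypothetical helper, written against the
// standard MPI C API.
#include <mpi.h>

MPI_Comm splitAlongAxis1(int R, int C) {
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  int row = rank / C; // linear index along the non-split axes -> color
  int col = rank % C; // linear index along the split axes     -> key
  MPI_Comm newComm;
  MPI_Comm_split(MPI_COMM_WORLD, /*color=*/row, /*key=*/col, &newComm);
  return newComm; // C processes per communicator, one per grid row
}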
@@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand(
ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
ImplicitLocOpBuilder &builder) {
Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
gridOp.getSymName(), reductionGridAxes, builder);
builder, gridOp.getSymName(), reductionGridAxes);
Value zero = arith::ConstantIndexOp::create(builder, 0);
Value isLeadProcess = arith::CmpIOp::create(
builder, builder.getI1Type(), arith::CmpIPredicate::eq,
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
}

TypedValue<IndexType>
createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
ValueRange processInGroupMultiIndex,
ArrayRef<GridAxis> gridAxes) {
Operation::result_range processGroupShape =
GridShapeOp::create(builder, grid, gridAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
return cast<TypedValue<IndexType>>(res);
}

TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
TypedValue<IndexType> createProcessLinearIndex(ImplicitLocOpBuilder &builder,
StringRef grid,
ArrayRef<GridAxis> gridAxes) {
return createProcessLinearIndex(
grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
gridAxes, builder);
builder, grid,
ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
gridAxes);
}
} // namespace mlir::shard
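The linearization these overloads delegate to `affine::linearizeIndex` is plain row-major arithmetic over the group shape. A tiny self-contained sketch of the same computation (hypothetical helper, not part of this patch):

#include <cstdint>
#include <vector>

// Row-major linearization, the arithmetic affine::linearizeIndex folds to:
// e.g. multiIndex {1, 2} in a group of shape {3, 4} -> 1 * 4 + 2 = 6.
int64_t linearizeRowMajor(const std::vector<int64_t> &multiIndex,
                          const std::vector<int64_t> &groupShape) {
  int64_t linear = 0;
  for (size_t i = 0; i < multiIndex.size(); ++i)
    linear = linear * groupShape[i] + multiIndex[i];
  return linear;
}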