Enable multi-device computation in runtime #1716

Merged 1 commit on Jan 14, 2025
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions

// Option to enable/disable the workaround pass.
//
- Option<bool> layouotWorkaroundsEnabled{
+ Option<bool> layoutWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
}];

let options = [
Option<"layouotWorkaroundsEnabled",
Option<"layoutWorkaroundsEnabled",
"ttnn-enable-layout-workaround-pass",
"bool", /*default=*/"true",
"TTNN Layout Workarounds Pass">,
12 changes: 12 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
@@ -74,6 +74,18 @@ enum BufferType: ushort {
Trace,
}

enum MeshShardDirection: uint32 {
FullToShardShape,
ShardToFullShape,
}

enum MeshShardType: uint32 {
Manual,
Replicate,
Maximal,
Devices,
}

// TODO (#620): Add other fields like core_ranges, shard orientation etc.
table ShardSpec {
shard_shape: [int64];
4 changes: 2 additions & 2 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -322,8 +322,8 @@ table MeshShardOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
- shard_direction: uint32;
- shard_type: uint32;
+ shard_direction: tt.target.MeshShardDirection;
+ shard_type: tt.target.MeshShardType;
shard_shape: [int64];
}

12 changes: 6 additions & 6 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1058,16 +1058,16 @@ class StableHLOToTTIRAllReduceOpConversionPattern
}
}

- // Algorithm here is to search for the first non-one working dimension
+ // Algorithm: search for first non-one working dimension from back
auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
- size_t dim = 0;
- for (auto s : replicaGroupsShape) {
- if (s != 1) {
+ size_t dim = replicaGroupsShape.size() - 1;
+ for (auto s = replicaGroupsShape.rbegin(); s != replicaGroupsShape.rend();
+ ++s, --dim) {
+ if (*s != 1) {
break;
}
- ++dim;
}
- if (dim > replicaGroupsShape.size()) {
+ if (dim < 0) {
// all one shape, then select the fastest dim
dim = replicaGroupsShape.size();
}
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses(
void createTTNNPipelineWorkaroundPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
TTNNWorkaroundsOptions workaroundOptions{
- options.layouotWorkaroundsEnabled,
+ options.layoutWorkaroundsEnabled,
options.decompositionWorkaroundsEnabled};
pm.addPass(createTTNNWorkarounds(workaroundOptions));
pm.addPass(mlir::createCanonicalizerPass());
28 changes: 15 additions & 13 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -274,8 +274,9 @@ class TTNNLayoutDPSOperandsRewriter

LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const final {
- // To layout op is a special case, we don't want to rewrite it
- if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
+ // To layout and mesh_shard op are special cases not to rewrite them
+ if (mlir::isa<ttir::ToLayoutOp>(op.getOperation()) ||
+ mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
Contributor: Any comment on why MeshShardOp is a special one in this regard?

Contributor Author: TTNN mesh shard APIs are currently CPU-only operations. By enforcing tensors to be located in system memory, we can ensure that (1) a tensor can be sharded into multi-device storage on the CPU side, and (2) the shards can later be tiled and transferred to the individual devices.

return failure();
}

@@ -330,15 +331,15 @@ class TTNNLayoutDPSOperandsRewriter
}
};

- // Updates the layout of the operands of a func::ReturnOp.
- // The intent is to move the result to host.
- class TTNNLayoutFuncReturnRewriter
- : public OpRewritePattern<mlir::func::ReturnOp> {
+ // Updates the layout of the operands of the SrcOp such that
+ // the operands reside in host memory.
+ template <typename SrcOp>
+ class TTNNLayoutForceSystemMemoryRewriter : public OpRewritePattern<SrcOp> {
public:
- TTNNLayoutFuncReturnRewriter(MLIRContext *ctx)
- : OpRewritePattern<mlir::func::ReturnOp>(ctx) {}
+ TTNNLayoutForceSystemMemoryRewriter(MLIRContext *ctx)
+ : OpRewritePattern<SrcOp>(ctx) {}

- LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
+ LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const final {
bool modified = false;
for (OpOperand &operand : op->getOpOperands()) {
@@ -355,8 +356,6 @@ class TTNNLayoutFuncReturnRewriter
}
return modified ? success() : failure();
}

- private:
};

class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
@@ -387,9 +386,12 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
// and rewrites its operands and result to have the correct layout
// with respect to operand constraints.
patterns.add<TTNNLayoutDPSOperandsRewriter>(&getContext());
- // Takes func::Return op and sets layout which will
+ // Takes func::Return and ttir::MeshShard ops and set layout which will
// move it's operands to host
- patterns.add<TTNNLayoutFuncReturnRewriter>(&getContext());
+ patterns.add<TTNNLayoutForceSystemMemoryRewriter<ttir::MeshShardOp>>(
+ &getContext());
+ patterns.add<TTNNLayoutForceSystemMemoryRewriter<mlir::func::ReturnOp>>(
+ &getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
GreedyRewriteConfig config = GreedyRewriteConfig();
config.useTopDownTraversal = true;
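To make the review discussion above concrete, here is a minimal sketch of the CPU-side sharding flow that forcing MeshShardOp operands into system memory enables. It reuses only the ttnn::distributed calls that appear in mesh_shard.cpp later in this diff; the helper name shardAcrossMesh and the choice to shard the last two tensor dimensions over mesh rows and columns are illustrative assumptions, not code from this PR.

// Sketch only: shard a host (system-memory) tensor across a 2D device mesh.
#include "ttnn/distributed/distributed_tensor.hpp"

::ttnn::Tensor shardAcrossMesh(const ::ttnn::Tensor &hostTensor,
                               ::ttnn::MeshDevice &meshDevice) {
  // The input must already live in system memory; the layout pass above
  // guarantees this for MeshShardOp operands.
  int lastDim = hostTensor.get_shape().rank() - 1;
  // Map mesh rows to the second-to-last dim and mesh columns to the last dim
  // (hypothetical choice, for illustration only).
  ::ttnn::distributed::Shard2dConfig shard2dConfig{.row_dim = lastDim - 1,
                                                   .col_dim = lastDim};
  // distribute_tensor splits the tensor on the CPU side into multi-device
  // storage; the per-device shards can then be tiled and transferred to the
  // individual devices.
  return ::ttnn::distributed::distribute_tensor(
      hostTensor, meshDevice,
      *::ttnn::distributed::shard_tensor_to_2d_mesh_mapper(
          meshDevice, meshDevice.shape(), shard2dConfig));
}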
lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
@@ -435,7 +435,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
}
- if (layouotWorkaroundsEnabled) {
+ if (layoutWorkaroundsEnabled) {
RewritePatternSet patterns(&getContext());
patterns.add<TTNNOperandsWorkaroundsRewriter>(&getContext());

24 changes: 22 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -499,11 +499,31 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) {
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
auto device = getOperandThroughDPSOps(op.getDevice());
const mlir::tt::MeshShardDirection shardDirection = op.getShardDirection();
const mlir::tt::MeshShardType shardType = op.getShardType();
llvm::ArrayRef<int64_t> shardShape = op.getShardShape().getShape();

::tt::target::MeshShardDirection meshShardDirection;
if (shardDirection == mlir::tt::MeshShardDirection::FullToShard) {
meshShardDirection = ::tt::target::MeshShardDirection::FullToShardShape;
} else if (shardDirection == mlir::tt::MeshShardDirection::ShardToFull) {
meshShardDirection = ::tt::target::MeshShardDirection::ShardToFullShape;
} else {
llvm_unreachable("unhandled mesh_shard direction");
}

::tt::target::MeshShardType meshShardType;
if (shardType == mlir::tt::MeshShardType::Replicate) {
meshShardType = ::tt::target::MeshShardType::Replicate;
} else if (shardType == mlir::tt::MeshShardType::Devices) {
meshShardType = ::tt::target::MeshShardType::Devices;
} else {
llvm_unreachable("unhandled mesh_shard type");
}

return ::tt::target::ttnn::CreateMeshShardOp(
*cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
- static_cast<uint32_t>(op.getShardDirection()),
- static_cast<uint32_t>(op.getShardType()),
+ meshShardDirection, meshShardType,
cache.fbb->CreateVector<int64_t>(shardShape));
}

2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -4,6 +4,8 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
16 changes: 12 additions & 4 deletions runtime/lib/ttnn/operations/ccl/all_gather.cpp
@@ -3,20 +3,28 @@
// SPDX-License-Identifier: Apache-2.0

#include "operations/ccl/all_gather.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/operations/ccl/ccl_host_types.hpp"

namespace tt::runtime::ttnn::operations::ccl {
void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
- int32_t dim = op->dim();
- int32_t num_links = op->num_links();
+ int32_t gatherDim = op->dim();
+ int32_t numLinks = op->num_links();
+ LOG_ASSERT(
+ input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE,
+ "Input of all_gather must be MULTIDEVICE. id:", op->in()->global_id());
::tt::tt_metal::MemoryConfig outputMemoryConfig =
::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
- ::ttnn::Tensor out =
- ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig);
+ ::ttnn::MeshDevice &meshDevice =
+ context.getSubMesh(op->device()->global_id());
+ ::ttnn::Tensor out = ::ttnn::all_gather(
+ input, gatherDim, 1, meshDevice, numLinks, outputMemoryConfig,
+ std::nullopt, std::nullopt, ::ttnn::ccl::Topology::Linear);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ccl
129 changes: 129 additions & 0 deletions runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
@@ -0,0 +1,129 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "operations/ccl/mesh_shard.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/distributed/distributed_tensor.hpp"
#include "ttnn/tensor/xtensor/partition.hpp"

namespace tt::runtime::ttnn::operations::ccl {
void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
::ttnn::MeshDevice &meshDevice,
const ::tt::target::MeshShardType &shardType,
const std::vector<int64_t> &shardShape) {
if (shardType == ::tt::target::MeshShardType::Replicate) {
out = ::ttnn::distributed::distribute_tensor(
input, meshDevice,
*::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
} else {
LOG_ASSERT(
input.get_shape().rank() > 1,
"Sharding requires higher than 2 dimensional tensor. Tensor rank=",
input.get_shape().rank());
auto rowMesh = static_cast<size_t>(shardShape[0]);
auto colMesh = static_cast<size_t>(shardShape[1]);
int lastDim = input.get_shape().rank() - 1;
LOG_ASSERT((rowMesh * colMesh) > 1,
"Sharding requires higher than 1 mesh. shardShape ", rowMesh,
colMesh);

::ttnn::distributed::Shard2dConfig shard2dConfig;
// last tile replicate
if (colMesh == 1) {
if (rowMesh == meshDevice.num_rows()) {
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = (lastDim - 1), .col_dim = std::nullopt};
} else {
// transpose
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = std::nullopt, .col_dim = (lastDim - 1)};
}
} else {
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = (lastDim - 1), .col_dim = lastDim};
}

out = ::ttnn::distributed::distribute_tensor(
input, meshDevice,
*::ttnn::distributed::shard_tensor_to_2d_mesh_mapper(
meshDevice, meshDevice.shape(), shard2dConfig));
}
}

void ShardToFullShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
::ttnn::MeshDevice &meshDevice,
const ::tt::target::MeshShardType &shardType,
const std::vector<int64_t> &shardShape) {
std::vector<::ttnn::Tensor> input_tensors =
::ttnn::distributed::get_tensors_from_multi_device_storage(input);
if (shardType == ::tt::target::MeshShardType::Replicate) {
out = input_tensors[0];
} else {
auto rowMesh = static_cast<size_t>(shardShape[0]);
auto colMesh = static_cast<size_t>(shardShape[1]);
int lastDim = input.get_shape().rank() - 1;
if ((rowMesh * colMesh) ==
(meshDevice.num_rows() * meshDevice.num_cols())) {
// Full multi-device storage concatenation
if (shardShape[0] == 1 || shardShape[1] == 1) {
out = ::ttnn::distributed::aggregate_tensor(
input, *::ttnn::distributed::concat_mesh_to_tensor_composer(
(shardShape[1] == 1 ? (lastDim - 1) : lastDim)));
} else {
out = ::ttnn::distributed::aggregate_tensor(
input, *::ttnn::distributed::concat_2d_mesh_to_tensor_composer(
meshDevice, ::ttnn::distributed::Concat2dConfig{
.row_dim = static_cast<int>(lastDim - 1),
.col_dim = static_cast<int>(lastDim)}));
}
} else {
// Partial multi-device storage concatenation
// Current ttnn api does not support partial multi-device storage
// concatenation. Thus, xtensor APIs are being called directly from here.
std::vector<::ttnn::Tensor> target_tensors;
bool transpose = (rowMesh != meshDevice.num_rows());
size_t iteration = (transpose) ? colMesh : rowMesh;
size_t stride =
(transpose) ? meshDevice.num_rows() : meshDevice.num_cols();
for (size_t i = 0; i < iteration; ++i) {
target_tensors.push_back(input_tensors[i * stride]);
}
out = ::ttnn::experimental::xtensor::concat(target_tensors, lastDim - 1);
}
}
}

void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
const ::tt::target::MeshShardDirection shardDirection = op->shard_direction();
const ::tt::target::MeshShardType shardType = op->shard_type();
const auto *fbShardShape = op->shard_shape();
std::vector<int64_t> shardShape(fbShardShape->begin(), fbShardShape->end());

if (shardDirection != ::tt::target::MeshShardDirection::FullToShardShape &&
shardDirection != ::tt::target::MeshShardDirection::ShardToFullShape) {
throw std::runtime_error("Unsupported shard direction");
}

if (shardType != ::tt::target::MeshShardType::Replicate &&
shardType != ::tt::target::MeshShardType::Devices) {
throw std::runtime_error("Unsupported shard type");
}

::ttnn::MeshDevice &meshDevice =
context.getSubMesh(op->device()->global_id());

::ttnn::Tensor out;
if (shardDirection == ::tt::target::MeshShardDirection::FullToShardShape) {
FullToShardShape(input, out, meshDevice, shardType, shardShape);
} else {
ShardToFullShape(input, out, meshDevice, shardType, shardShape);
}
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ccl
15 changes: 15 additions & 0 deletions runtime/lib/ttnn/operations/ccl/mesh_shard.h
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CCL_MESH_SHARD_H
#define RUNTIME_LIB_TTNN_OPERATIONS_CCL_MESH_SHARD_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ccl {
void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context);
} // namespace tt::runtime::ttnn::operations::ccl

#endif