
Commit 2a566ed

Enable multi-device computation in runtime

* Allow ttnn runtime operations including reduce_scatter, mesh_shard, and all_gather
* Force mesh_shard ops to use system memory because they are host-side operations
* Use strongly-typed sharding options for mesh_shard ops
* Add Silicon multi-device test cases
* Fix bug in determining axis of all_reduce when converting from StableHLO
* Fix typo in ttnn workaround pass

1 parent dca29f4 commit 2a566ed

File tree

25 files changed: +383 -55 lines changed


include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h (+1 -1)

@@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions

   // Option to enable/disable the workaround pass.
   //
-  Option<bool> layouotWorkaroundsEnabled{
+  Option<bool> layoutWorkaroundsEnabled{
       *this, "enable-layout-workaround-pass",
       llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

include/ttmlir/Dialect/TTNN/Transforms/Passes.td (+1 -1)

@@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
   }];

   let options = [
-    Option<"layouotWorkaroundsEnabled",
+    Option<"layoutWorkaroundsEnabled",
            "ttnn-enable-layout-workaround-pass",
            "bool", /*default=*/"true",
            "TTNN Layout Workarounds Pass">,

include/ttmlir/Target/Common/types.fbs (+12)

@@ -74,6 +74,18 @@ enum BufferType: ushort {
   Trace,
 }

+enum MeshShardDirection: uint32 {
+  FullToShardShape,
+  ShardToFullShape,
+}
+
+enum MeshShardType: uint32 {
+  Manual,
+  Replicate,
+  Maximal,
+  Devices,
+}
+
 // TODO (#620): Add other fields like core_ranges, shard orientation etc.
 table ShardSpec {
   shard_shape: [int64];

include/ttmlir/Target/TTNN/program.fbs (+2 -2)

@@ -316,8 +316,8 @@ table MeshShardOp {
   in: tt.target.TensorRef;
   out: tt.target.TensorRef;
   device: tt.target.DeviceRef;
-  shard_direction: uint32;
-  shard_type: uint32;
+  shard_direction: tt.target.MeshShardDirection;
+  shard_type: tt.target.MeshShardType;
   shard_shape: [int64];
 }

lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp (+6 -6)

@@ -1094,16 +1094,16 @@ class StableHLOToTTIRAllReduceOpConversionPattern
       }
     }

-    // Algorithm here is to search for the first non-one working dimension
+    // Algorithm: search for first non-one working dimension from back
     auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
-    size_t dim = 0;
-    for (auto s : replicaGroupsShape) {
-      if (s != 1) {
+    size_t dim = replicaGroupsShape.size() - 1;
+    for (auto s = replicaGroupsShape.rbegin(); s != replicaGroupsShape.rend();
+         ++s, --dim) {
+      if (*s != 1) {
         break;
       }
-      ++dim;
     }
-    if (dim > replicaGroupsShape.size()) {
+    if (dim < 0) {
       // all one shape, then select the fastest dim
       dim = replicaGroupsShape.size();
     }

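The all_reduce fix above walks the replica_groups shape from the back and falls back to the shape rank when every dimension is 1. A self-contained sketch of that selection rule, using a plain std::vector and a hypothetical findAllReduceDim helper in place of the MLIR types in the pattern:

```cpp
// Standalone sketch of the dimension search above; findAllReduceDim is an
// illustrative name, the real pattern operates on MLIR shapes.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Walk the shape from the back and return the index of the last non-one
// dimension; fall back to shape.size() when every dimension is 1, matching
// the "all one shape" branch in the conversion pattern.
static std::size_t findAllReduceDim(const std::vector<std::int64_t> &shape) {
  for (std::size_t i = shape.size(); i > 0; --i) {
    if (shape[i - 1] != 1) {
      return i - 1;
    }
  }
  return shape.size();
}

int main() {
  std::printf("%zu\n", findAllReduceDim({1, 8})); // 1: last non-one dimension
  std::printf("%zu\n", findAllReduceDim({4, 1})); // 0: trailing ones skipped
  std::printf("%zu\n", findAllReduceDim({1, 1})); // 2: all-one fallback
  return 0;
}
```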
lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp (+1 -1)

@@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses(
 void createTTNNPipelineWorkaroundPass(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
   TTNNWorkaroundsOptions workaroundOptions{
-      options.layouotWorkaroundsEnabled,
+      options.layoutWorkaroundsEnabled,
       options.decompositionWorkaroundsEnabled};
   pm.addPass(createTTNNWorkarounds(workaroundOptions));
   pm.addPass(mlir::createCanonicalizerPass());

lib/Dialect/TTNN/Transforms/TTNNLayout.cpp (+13 -11)

@@ -275,7 +275,8 @@ class TTNNLayoutDPSOperandsRewriter
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
     // To layout op is a special case, we don't want to rewrite it
-    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
+    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation()) ||
+        mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
       return failure();
     }

@@ -330,15 +331,15 @@ class TTNNLayoutDPSOperandsRewriter
   }
 };

-// Updates the layout of the operands of a func::ReturnOp.
-// The intent is to move the result to host.
-class TTNNLayoutFuncReturnRewriter
-    : public OpRewritePattern<mlir::func::ReturnOp> {
+// Updates the layout of the operands of the SrcOp such that
+// the operands reside in host memory.
+template <typename SrcOp>
+class TTNNLayoutForceSystemMemory : public OpRewritePattern<SrcOp> {
 public:
-  TTNNLayoutFuncReturnRewriter(MLIRContext *ctx)
-      : OpRewritePattern<mlir::func::ReturnOp>(ctx) {}
+  TTNNLayoutForceSystemMemory(MLIRContext *ctx)
+      : OpRewritePattern<SrcOp>(ctx) {}

-  LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
+  LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const final {
     bool modified = false;
     for (OpOperand &operand : op->getOpOperands()) {

@@ -355,8 +356,6 @@ class TTNNLayoutFuncReturnRewriter
     }
     return modified ? success() : failure();
   }
-
-private:
 };

 class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {

@@ -389,7 +388,10 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
     patterns.add<TTNNLayoutDPSOperandsRewriter>(&getContext());
     // Takes func::Return op and sets layout which will
     // move it's operands to host
-    patterns.add<TTNNLayoutFuncReturnRewriter>(&getContext());
+    patterns.add<TTNNLayoutForceSystemMemory<ttir::MeshShardOp>>(
+        &getContext());
+    patterns.add<TTNNLayoutForceSystemMemory<mlir::func::ReturnOp>>(
+        &getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
     GreedyRewriteConfig config = GreedyRewriteConfig();
     config.useTopDownTraversal = true;

lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp (+1 -1)

@@ -435,7 +435,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
       runRewritePatterns(std::move(patterns),
                          GreedyRewriteConfig::kNoLimit /*maxIterations*/);
     }
-    if (layouotWorkaroundsEnabled) {
+    if (layoutWorkaroundsEnabled) {
       RewritePatternSet patterns(&getContext());
       patterns.add<TTNNOperandsWorkaroundsRewriter>(&getContext());

lib/Target/TTNN/TTNNToFlatbuffer.cpp (+22 -2)

@@ -499,11 +499,31 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op)
   auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                   kHostAllocatedAddress, kHostAllocatedSize);
   auto device = getOperandThroughDPSOps(op.getDevice());
+  const mlir::tt::MeshShardDirection shardDirection = op.getShardDirection();
+  const mlir::tt::MeshShardType shardType = op.getShardType();
   llvm::ArrayRef<int64_t> shardShape = op.getShardShape().getShape();
+
+  ::tt::target::MeshShardDirection meshShardDirection;
+  if (shardDirection == mlir::tt::MeshShardDirection::FullToShard) {
+    meshShardDirection = ::tt::target::MeshShardDirection::FullToShardShape;
+  } else if (shardDirection == mlir::tt::MeshShardDirection::ShardToFull) {
+    meshShardDirection = ::tt::target::MeshShardDirection::ShardToFullShape;
+  } else {
+    llvm_unreachable("unhandled mesh_shard direction");
+  }
+
+  ::tt::target::MeshShardType meshShardType;
+  if (shardType == mlir::tt::MeshShardType::Replicate) {
+    meshShardType = ::tt::target::MeshShardType::Replicate;
+  } else if (shardType == mlir::tt::MeshShardType::Devices) {
+    meshShardType = ::tt::target::MeshShardType::Devices;
+  } else {
+    llvm_unreachable("unhandled mesh_shard type");
+  }
+
   return ::tt::target::ttnn::CreateMeshShardOp(
       *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
-      static_cast<uint32_t>(op.getShardDirection()),
-      static_cast<uint32_t>(op.getShardType()),
+      meshShardDirection, meshShardType,
       cache.fbb->CreateVector<int64_t>(shardShape));
 }

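The flatbuffer serialization above maps the dialect enums to the schema enums with if/else chains and llvm_unreachable fallbacks. An exhaustive switch is one alternative, sketched below with stand-in enums; the enumerator names come from this diff, but the types and the toTarget helper are illustrative, not the real mlir::tt::* / ::tt::target::* declarations:

```cpp
// Self-contained sketch of a switch-based enum mapping, as an alternative to
// the if/else chain above. The enums are stand-ins for the dialect and
// flatbuffer types referenced in the diff.
#include <cstdio>

enum class DialectShardDirection { FullToShard, ShardToFull };
enum class TargetShardDirection { FullToShardShape, ShardToFullShape };

static TargetShardDirection toTarget(DialectShardDirection direction) {
  switch (direction) {
  case DialectShardDirection::FullToShard:
    return TargetShardDirection::FullToShardShape;
  case DialectShardDirection::ShardToFull:
    return TargetShardDirection::ShardToFullShape;
  }
  // Unreachable for valid enumerators; the real code calls llvm_unreachable.
  return TargetShardDirection::FullToShardShape;
}

int main() {
  std::printf("%d\n", static_cast<int>(
                          toTarget(DialectShardDirection::ShardToFull))); // 1
  return 0;
}
```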
runtime/lib/ttnn/operations/CMakeLists.txt (+2)

@@ -4,6 +4,8 @@ set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp

runtime/lib/ttnn/operations/ccl/all_gather.cpp (+14 -4)

@@ -3,20 +3,30 @@
 // SPDX-License-Identifier: Apache-2.0

 #include "operations/ccl/all_gather.h"
+#include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
 #include "tt/runtime/ttnn/operations/utils.h"
 #include "tt/runtime/ttnn/utils.h"
+#include "ttnn/operations/ccl/ccl_host_types.hpp"

 namespace tt::runtime::ttnn::operations::ccl {
+
 void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
-  int32_t dim = op->dim();
-  int32_t num_links = op->num_links();
+  int32_t gatherDim = op->dim();
+  int32_t numLinks = op->num_links();
+  LOG_ASSERT(
+      input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE,
+      "Input of all_gather must be MULTIDEVICE. id:", op->in()->global_id());
   ::tt::tt_metal::MemoryConfig outputMemoryConfig =
       ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
-  ::ttnn::Tensor out =
-      ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig);
+  ::ttnn::MeshDevice &meshDevice =
+      context.getSubMesh(op->device()->global_id());
+  ::ttnn::Tensor out = ::ttnn::all_gather(
+      input, gatherDim, 1, meshDevice, numLinks, outputMemoryConfig,
+      std::nullopt, std::nullopt, ::ttnn::ccl::Topology::Linear);
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
+
 } // namespace tt::runtime::ttnn::operations::ccl

runtime/lib/ttnn/operations/ccl/mesh_shard.cpp (new file, +131)

@@ -0,0 +1,131 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "mesh_shard.h"
+#include "tt/runtime/detail/logger.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+#include "tt/runtime/ttnn/utils.h"
+#include "ttnn/distributed/distributed_tensor.hpp"
+#include "ttnn/tensor/xtensor/partition.hpp"
+
+namespace tt::runtime::ttnn::operations::ccl {
+
+void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
+                      ::ttnn::MeshDevice &meshDevice,
+                      const ::tt::target::MeshShardType &shardType,
+                      const std::vector<int64_t> &shardShape) {
+  if (shardType == ::tt::target::MeshShardType::Replicate) {
+    out = ::ttnn::distributed::distribute_tensor(
+        input, meshDevice,
+        *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
+  } else {
+    LOG_ASSERT(
+        input.get_shape().rank() > 1,
+        "Sharding requires higher than 2 dimensional tensor. Tensor rank=",
+        input.get_shape().rank());
+    auto rowMesh = static_cast<size_t>(shardShape[0]);
+    auto colMesh = static_cast<size_t>(shardShape[1]);
+    int lastDim = input.get_shape().rank() - 1;
+    LOG_ASSERT((rowMesh * colMesh) > 1,
+               "Sharding requires higher than 1 mesh. shardShape ", rowMesh,
+               colMesh);
+
+    ::ttnn::distributed::Shard2dConfig shard2dConfig;
+    // last tile replicate
+    if (colMesh == 1) {
+      if (rowMesh == meshDevice.num_rows()) {
+        shard2dConfig = ::ttnn::distributed::Shard2dConfig{
+            .row_dim = (lastDim - 1), .col_dim = std::nullopt};
+      } else {
+        // transpose
+        shard2dConfig = ::ttnn::distributed::Shard2dConfig{
+            .row_dim = std::nullopt, .col_dim = (lastDim - 1)};
+      }
+    } else {
+      shard2dConfig = ::ttnn::distributed::Shard2dConfig{
+          .row_dim = (lastDim - 1), .col_dim = lastDim};
+    }
+
+    out = ::ttnn::distributed::distribute_tensor(
+        input, meshDevice,
+        *::ttnn::distributed::shard_tensor_to_2d_mesh_mapper(
+            meshDevice, meshDevice.shape(), shard2dConfig));
+  }
+}
+
+void ShardToFullShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
+                      ::ttnn::MeshDevice &meshDevice,
+                      const ::tt::target::MeshShardType &shardType,
+                      const std::vector<int64_t> &shardShape) {
+  std::vector<::ttnn::Tensor> input_tensors =
+      ::ttnn::distributed::get_tensors_from_multi_device_storage(input);
+  if (shardType == ::tt::target::MeshShardType::Replicate) {
+    out = input_tensors[0];
+  } else {
+    auto rowMesh = static_cast<size_t>(shardShape[0]);
+    auto colMesh = static_cast<size_t>(shardShape[1]);
+    int lastDim = input.get_shape().rank() - 1;
+    if ((rowMesh * colMesh) ==
+        (meshDevice.num_rows() * meshDevice.num_cols())) {
+      // Full multi-device storage concatenation
+      if (shardShape[0] == 1 || shardShape[1] == 1) {
+        out = ::ttnn::distributed::aggregate_tensor(
+            input, *::ttnn::distributed::concat_mesh_to_tensor_composer(
+                       (shardShape[1] == 1 ? (lastDim - 1) : lastDim)));
+      } else {
+        out = ::ttnn::distributed::aggregate_tensor(
+            input, *::ttnn::distributed::concat_2d_mesh_to_tensor_composer(
+                       meshDevice, ::ttnn::distributed::Concat2dConfig{
+                                       .row_dim = static_cast<int>(lastDim - 1),
+                                       .col_dim = static_cast<int>(lastDim)}));
+      }
+    } else {
+      // Partial multi-device storage concatenation
+      // Current ttnn api does not support partial multi-device storage
+      // concatenation. Thus, xtensor APIs are being called directly from here.
+      std::vector<::ttnn::Tensor> target_tensors;
+      bool transpose = (rowMesh != meshDevice.num_rows());
+      size_t iteration = (transpose) ? colMesh : rowMesh;
+      size_t stride =
+          (transpose) ? meshDevice.num_rows() : meshDevice.num_cols();
+      for (size_t i = 0; i < iteration; ++i) {
+        target_tensors.push_back(input_tensors[i * stride]);
+      }
+      out = ::ttnn::experimental::xtensor::concat(target_tensors, lastDim - 1);
+    }
+  }
+}
+
+void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
+  const ::tt::target::MeshShardDirection shardDirection = op->shard_direction();
+  const ::tt::target::MeshShardType shardType = op->shard_type();
+  const auto *fbShardShape = op->shard_shape();
+  std::vector<int64_t> shardShape(fbShardShape->begin(), fbShardShape->end());
+
+  if (shardDirection != ::tt::target::MeshShardDirection::FullToShardShape &&
+      shardDirection != ::tt::target::MeshShardDirection::ShardToFullShape) {
+    throw std::runtime_error("Unsupported shard direction");
+  }
+
+  if (shardType != ::tt::target::MeshShardType::Replicate &&
+      shardType != ::tt::target::MeshShardType::Devices) {
+    throw std::runtime_error("Unsupported shard type");
+  }
+
+  ::ttnn::MeshDevice &meshDevice =
+      context.getSubMesh(op->device()->global_id());
+
+  ::ttnn::Tensor out;
+  if (shardDirection == ::tt::target::MeshShardDirection::FullToShardShape) {
+    FullToShardShape(input, out, meshDevice, shardType, shardShape);
+  } else {
+    ShardToFullShape(input, out, meshDevice, shardType, shardShape);
+  }
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
+} // namespace tt::runtime::ttnn::operations::ccl
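In mesh_shard.cpp above, FullToShardShape chooses which tensor dimensions map onto the mesh rows and columns before calling shard_tensor_to_2d_mesh_mapper. The decision is easier to follow in isolation; below is a self-contained sketch with a stand-in ShardDims struct in place of ::ttnn::distributed::Shard2dConfig and an illustrative pickShardDims helper:

```cpp
// Self-contained sketch of the shard-dimension selection in FullToShardShape.
// ShardDims stands in for ::ttnn::distributed::Shard2dConfig; pickShardDims is
// an illustrative name, not an API from the runtime.
#include <cstddef>
#include <cstdio>
#include <optional>

struct ShardDims {
  std::optional<int> rowDim;
  std::optional<int> colDim;
};

// Shard the last two tensor dims across a 2D mesh; when colMesh == 1, shard
// only the second-to-last dim, placing it on mesh rows or (if the shard shape
// is transposed relative to the mesh) on mesh columns.
static ShardDims pickShardDims(std::size_t rowMesh, std::size_t colMesh,
                               std::size_t meshRows, int lastDim) {
  if (colMesh == 1) {
    if (rowMesh == meshRows) {
      return {lastDim - 1, std::nullopt};
    }
    return {std::nullopt, lastDim - 1};
  }
  return {lastDim - 1, lastDim};
}

int main() {
  // A [2, 4] shard shape on a rank-4 tensor: shard dims 2 (rows) and 3 (cols).
  ShardDims dims = pickShardDims(2, 4, /*meshRows=*/2, /*lastDim=*/3);
  std::printf("row=%d col=%d\n", dims.rowDim.value_or(-1),
              dims.colDim.value_or(-1)); // row=2 col=3
  return 0;
}
```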

runtime/lib/ttnn/operations/ccl/mesh_shard.h (new file, +15)

@@ -0,0 +1,15 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTNN_RUNTIME_MESH_SHARD_H
+#define TTNN_RUNTIME_MESH_SHARD_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::ccl {
+void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context);
+} // namespace tt::runtime::ttnn::operations::ccl
+
+#endif
