Enable multi-device computation in runtime
* Allow ttnn runtime operations including reduce_scatter, mesh_shard,
  and all_gather

* Force mesh_shard ops to use system memory because they are host-side
  operations

* Use strongly-typed sharding options for mesh_shard ops

* Add Silicon multi-device test cases

* Fix bug in determining the axis of all_reduce when converting from
  StableHLO

* Fix typo in ttnn workaround pass
wooseokTT committed Jan 10, 2025
1 parent 2fcd37a commit 2fd39a1
Showing 25 changed files with 383 additions and 55 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions

// Option to enable/disable the workaround pass.
//
-  Option<bool> layouotWorkaroundsEnabled{
+  Option<bool> layoutWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
}];

let options = [
Option<"layouotWorkaroundsEnabled",
Option<"layoutWorkaroundsEnabled",
"ttnn-enable-layout-workaround-pass",
"bool", /*default=*/"true",
"TTNN Layout Workarounds Pass">,
12 changes: 12 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
@@ -74,6 +74,18 @@ enum BufferType: ushort {
Trace,
}

+enum MeshShardDirection: uint32 {
+  FullToShardShape,
+  ShardToFullShape,
+}
+
+enum MeshShardType: uint32 {
+  Manual,
+  Replicate,
+  Maximal,
+  Devices,
+}

// TODO (#620): Add other fields like core_ranges, shard orientation etc.
table ShardSpec {
shard_shape: [int64];
4 changes: 2 additions & 2 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -322,8 +322,8 @@ table MeshShardOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
-  shard_direction: uint32;
-  shard_type: uint32;
+  shard_direction: tt.target.MeshShardDirection;
+  shard_type: tt.target.MeshShardType;
shard_shape: [int64];
}

12 changes: 6 additions & 6 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1058,16 +1058,16 @@ class StableHLOToTTIRAllReduceOpConversionPattern
}
}

-    // Algorithm here is to search for the first non-one working dimension
+    // Algorithm: search for first non-one working dimension from back
    auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
-    size_t dim = 0;
-    for (auto s : replicaGroupsShape) {
-      if (s != 1) {
+    size_t dim = replicaGroupsShape.size() - 1;
+    for (auto s = replicaGroupsShape.rbegin(); s != replicaGroupsShape.rend();
+         ++s, --dim) {
+      if (*s != 1) {
        break;
      }
-      ++dim;
    }
-    if (dim > replicaGroupsShape.size()) {
+    if (dim < 0) {
      // all one shape, then select the fastest dim
      dim = replicaGroupsShape.size();
    }
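For readers tracing the new loop, here is a minimal standalone sketch of the same backward search (illustrative only, not part of the commit; it assumes the replica-groups shape is available as a plain std::vector<int64_t>):

// Illustrative sketch, not part of the commit.
#include <cstdint>
#include <vector>

// Returns the index of the last dimension whose size is not 1, or shape.size()
// when every dimension is 1 (the caller then falls back to the fastest dim,
// as the conversion pattern above does).
static size_t findWorkingDimFromBack(const std::vector<int64_t> &shape) {
  for (size_t i = shape.size(); i > 0; --i) {
    if (shape[i - 1] != 1) {
      return i - 1;
    }
  }
  return shape.size();
}

// Example: a replica_groups shape of [1, 8] yields dim 1, [8, 1] yields dim 0,
// and [1, 1] falls through to the all-ones case.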
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses(
void createTTNNPipelineWorkaroundPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
TTNNWorkaroundsOptions workaroundOptions{
-      options.layouotWorkaroundsEnabled,
+      options.layoutWorkaroundsEnabled,
options.decompositionWorkaroundsEnabled};
pm.addPass(createTTNNWorkarounds(workaroundOptions));
pm.addPass(mlir::createCanonicalizerPass());
24 changes: 13 additions & 11 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -275,7 +275,8 @@ class TTNNLayoutDPSOperandsRewriter
LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const final {
// To layout op is a special case, we don't want to rewrite it
-    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
+    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation()) ||
+        mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
return failure();
}

@@ -330,15 +331,15 @@ class TTNNLayoutDPSOperandsRewriter
}
};

-// Updates the layout of the operands of a func::ReturnOp.
-// The intent is to move the result to host.
-class TTNNLayoutFuncReturnRewriter
-    : public OpRewritePattern<mlir::func::ReturnOp> {
+// Updates the layout of the operands of the SrcOp such that
+// the operands reside in host memory.
+template <typename SrcOp>
+class TTNNLayoutForceSystemMemory : public OpRewritePattern<SrcOp> {
public:
-  TTNNLayoutFuncReturnRewriter(MLIRContext *ctx)
-      : OpRewritePattern<mlir::func::ReturnOp>(ctx) {}
+  TTNNLayoutForceSystemMemory(MLIRContext *ctx)
+      : OpRewritePattern<SrcOp>(ctx) {}

-  LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
+  LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const final {
bool modified = false;
for (OpOperand &operand : op->getOpOperands()) {
@@ -355,8 +356,6 @@ class TTNNLayoutFuncReturnRewriter
}
return modified ? success() : failure();
}
-
-private:
};

class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
@@ -389,7 +388,10 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
patterns.add<TTNNLayoutDPSOperandsRewriter>(&getContext());
// Takes func::Return op and sets layout which will
// move it's operands to host
-    patterns.add<TTNNLayoutFuncReturnRewriter>(&getContext());
+    patterns.add<TTNNLayoutForceSystemMemory<ttir::MeshShardOp>>(
+        &getContext());
+    patterns.add<TTNNLayoutForceSystemMemory<mlir::func::ReturnOp>>(
+        &getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
GreedyRewriteConfig config = GreedyRewriteConfig();
config.useTopDownTraversal = true;
@@ -435,7 +435,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
}
-    if (layouotWorkaroundsEnabled) {
+    if (layoutWorkaroundsEnabled) {
RewritePatternSet patterns(&getContext());
patterns.add<TTNNOperandsWorkaroundsRewriter>(&getContext());

24 changes: 22 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -499,11 +499,31 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) {
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
auto device = getOperandThroughDPSOps(op.getDevice());
+  const mlir::tt::MeshShardDirection shardDirection = op.getShardDirection();
+  const mlir::tt::MeshShardType shardType = op.getShardType();
+  llvm::ArrayRef<int64_t> shardShape = op.getShardShape().getShape();
+
+  ::tt::target::MeshShardDirection meshShardDirection;
+  if (shardDirection == mlir::tt::MeshShardDirection::FullToShard) {
+    meshShardDirection = ::tt::target::MeshShardDirection::FullToShardShape;
+  } else if (shardDirection == mlir::tt::MeshShardDirection::ShardToFull) {
+    meshShardDirection = ::tt::target::MeshShardDirection::ShardToFullShape;
+  } else {
+    llvm_unreachable("unhandled mesh_shard direction");
+  }
+
+  ::tt::target::MeshShardType meshShardType;
+  if (shardType == mlir::tt::MeshShardType::Replicate) {
+    meshShardType = ::tt::target::MeshShardType::Replicate;
+  } else if (shardType == mlir::tt::MeshShardType::Devices) {
+    meshShardType = ::tt::target::MeshShardType::Devices;
+  } else {
+    llvm_unreachable("unhandled mesh_shard type");
+  }

  return ::tt::target::ttnn::CreateMeshShardOp(
      *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
-      static_cast<uint32_t>(op.getShardDirection()),
-      static_cast<uint32_t>(op.getShardType()),
+      meshShardDirection, meshShardType,
cache.fbb->CreateVector<int64_t>(shardShape));
}

2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -4,6 +4,8 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
18 changes: 14 additions & 4 deletions runtime/lib/ttnn/operations/ccl/all_gather.cpp
@@ -3,20 +3,30 @@
// SPDX-License-Identifier: Apache-2.0

#include "operations/ccl/all_gather.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/operations/ccl/ccl_host_types.hpp"

namespace tt::runtime::ttnn::operations::ccl {

void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
-  int32_t dim = op->dim();
-  int32_t num_links = op->num_links();
+  int32_t gatherDim = op->dim();
+  int32_t numLinks = op->num_links();
+  LOG_ASSERT(
+      input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE,
+      "Input of all_gather must be MULTIDEVICE. id:", op->in()->global_id());
::tt::tt_metal::MemoryConfig outputMemoryConfig =
::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
-  ::ttnn::Tensor out =
-      ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig);
+  ::ttnn::MeshDevice &meshDevice =
+      context.getSubMesh(op->device()->global_id());
+  ::ttnn::Tensor out = ::ttnn::all_gather(
+      input, gatherDim, 1, meshDevice, numLinks, outputMemoryConfig,
+      std::nullopt, std::nullopt, ::ttnn::ccl::Topology::Linear);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

} // namespace tt::runtime::ttnn::operations::ccl
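As a quick sanity check on the gather dimension used above: an all_gather across N devices multiplies the size of the gather dimension by N and leaves the other dimensions unchanged. A small illustrative helper (the function name and shapes are made up for this example; this is not ttnn API):

// Illustrative helper only, not part of the commit or the ttnn API.
#include <cstdint>
#include <vector>

// Expected output shape of an all_gather over numDevices along gatherDim,
// given the per-device input shape.
static std::vector<int64_t> allGatherShape(std::vector<int64_t> shape,
                                           size_t gatherDim,
                                           int64_t numDevices) {
  shape[gatherDim] *= numDevices;
  return shape;
}

// Example: allGatherShape({1, 1, 32, 128}, 3, 8) returns {1, 1, 32, 1024}.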
131 changes: 131 additions & 0 deletions runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
@@ -0,0 +1,131 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mesh_shard.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/distributed/distributed_tensor.hpp"
#include "ttnn/tensor/xtensor/partition.hpp"

namespace tt::runtime::ttnn::operations::ccl {

void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
::ttnn::MeshDevice &meshDevice,
const ::tt::target::MeshShardType &shardType,
const std::vector<int64_t> &shardShape) {
if (shardType == ::tt::target::MeshShardType::Replicate) {
out = ::ttnn::distributed::distribute_tensor(
input, meshDevice,
*::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice));
} else {
LOG_ASSERT(
input.get_shape().rank() > 1,
"Sharding requires higher than 2 dimensional tensor. Tensor rank=",
input.get_shape().rank());
auto rowMesh = static_cast<size_t>(shardShape[0]);
auto colMesh = static_cast<size_t>(shardShape[1]);
int lastDim = input.get_shape().rank() - 1;
LOG_ASSERT((rowMesh * colMesh) > 1,
"Sharding requires higher than 1 mesh. shardShape ", rowMesh,
colMesh);

::ttnn::distributed::Shard2dConfig shard2dConfig;
// last tile replicate
if (colMesh == 1) {
if (rowMesh == meshDevice.num_rows()) {
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = (lastDim - 1), .col_dim = std::nullopt};
} else {
// transpose
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = std::nullopt, .col_dim = (lastDim - 1)};
}
} else {
shard2dConfig = ::ttnn::distributed::Shard2dConfig{
.row_dim = (lastDim - 1), .col_dim = lastDim};
}

out = ::ttnn::distributed::distribute_tensor(
input, meshDevice,
*::ttnn::distributed::shard_tensor_to_2d_mesh_mapper(
meshDevice, meshDevice.shape(), shard2dConfig));
}
}

void ShardToFullShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
::ttnn::MeshDevice &meshDevice,
const ::tt::target::MeshShardType &shardType,
const std::vector<int64_t> &shardShape) {
std::vector<::ttnn::Tensor> input_tensors =
::ttnn::distributed::get_tensors_from_multi_device_storage(input);
if (shardType == ::tt::target::MeshShardType::Replicate) {
out = input_tensors[0];
} else {
auto rowMesh = static_cast<size_t>(shardShape[0]);
auto colMesh = static_cast<size_t>(shardShape[1]);
int lastDim = input.get_shape().rank() - 1;
if ((rowMesh * colMesh) ==
(meshDevice.num_rows() * meshDevice.num_cols())) {
// Full multi-device storage concatenation
if (shardShape[0] == 1 || shardShape[1] == 1) {
out = ::ttnn::distributed::aggregate_tensor(
input, *::ttnn::distributed::concat_mesh_to_tensor_composer(
(shardShape[1] == 1 ? (lastDim - 1) : lastDim)));
} else {
out = ::ttnn::distributed::aggregate_tensor(
input, *::ttnn::distributed::concat_2d_mesh_to_tensor_composer(
meshDevice, ::ttnn::distributed::Concat2dConfig{
.row_dim = static_cast<int>(lastDim - 1),
.col_dim = static_cast<int>(lastDim)}));
}
} else {
// Partial multi-device storage concatenation
// Current ttnn api does not support partial multi-device storage
// concatenation. Thus, xtensor APIs are being called directly from here.
std::vector<::ttnn::Tensor> target_tensors;
bool transpose = (rowMesh != meshDevice.num_rows());
size_t iteration = (transpose) ? colMesh : rowMesh;
size_t stride =
(transpose) ? meshDevice.num_rows() : meshDevice.num_cols();
for (size_t i = 0; i < iteration; ++i) {
target_tensors.push_back(input_tensors[i * stride]);
}
out = ::ttnn::experimental::xtensor::concat(target_tensors, lastDim - 1);
}
}
}

void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
const ::tt::target::MeshShardDirection shardDirection = op->shard_direction();
const ::tt::target::MeshShardType shardType = op->shard_type();
const auto *fbShardShape = op->shard_shape();
std::vector<int64_t> shardShape(fbShardShape->begin(), fbShardShape->end());

if (shardDirection != ::tt::target::MeshShardDirection::FullToShardShape &&
shardDirection != ::tt::target::MeshShardDirection::ShardToFullShape) {
throw std::runtime_error("Unsupported shard direction");
}

if (shardType != ::tt::target::MeshShardType::Replicate &&
shardType != ::tt::target::MeshShardType::Devices) {
throw std::runtime_error("Unsupported shard type");
}

::ttnn::MeshDevice &meshDevice =
context.getSubMesh(op->device()->global_id());

::ttnn::Tensor out;
if (shardDirection == ::tt::target::MeshShardDirection::FullToShardShape) {
FullToShardShape(input, out, meshDevice, shardType, shardShape);
} else {
ShardToFullShape(input, out, meshDevice, shardType, shardShape);
}
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

} // namespace tt::runtime::ttnn::operations::ccl
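To make the Shard2dConfig selection in FullToShardShape above easier to follow, here is a standalone sketch of the same decision (illustrative only; it returns a pair of optional dims instead of the ttnn struct, and the helper name is invented for this example):

// Illustrative sketch, not part of the commit.
#include <cstddef>
#include <optional>
#include <utility>

// Given a (rowMesh x colMesh) shard shape, the mesh row count, and the input
// tensor rank, pick which tensor dims are split across mesh rows and columns.
// Mirrors the branches in FullToShardShape above.
static std::pair<std::optional<int>, std::optional<int>>
pickShardDims(size_t rowMesh, size_t colMesh, size_t meshRows, int rank) {
  const int lastDim = rank - 1;
  if (colMesh == 1) {
    if (rowMesh == meshRows) {
      // Shard the second-to-last dim across mesh rows; replicate along columns.
      return {lastDim - 1, std::nullopt};
    }
    // Transposed placement: shard the second-to-last dim across mesh columns.
    return {std::nullopt, lastDim - 1};
  }
  // Full 2D sharding over the last two tensor dims.
  return {lastDim - 1, lastDim};
}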
15 changes: 15 additions & 0 deletions runtime/lib/ttnn/operations/ccl/mesh_shard.h
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_MESH_SHARD_H
#define TTNN_RUNTIME_MESH_SHARD_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ccl {
void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context);
} // namespace tt::runtime::ttnn::operations::ccl

#endif