From efddf68406657236a42dfe64a21f5a74f399ddf9 Mon Sep 17 00:00:00 2001 From: Wooseok Lee Date: Fri, 20 Dec 2024 03:17:02 +0000 Subject: [PATCH] Enable multi-device computation in runtime * Allow ttnn runtime operations including reduce_scatter, mesh_shard, and all_gather * Force mesh_shard ops to use system memory because they are host-side operations * Use strongly-typed sharding options for mesh_shard ops * Add Silicon multi-device test cases * Fix bug in determining axis of all_reduce when converting from stableHLO * Fix typo in ttnn workaround pass --- .../Dialect/TTNN/Pipelines/TTNNPipelines.h | 2 +- .../ttmlir/Dialect/TTNN/Transforms/Passes.td | 2 +- include/ttmlir/Target/Common/types.fbs | 12 ++ include/ttmlir/Target/TTNN/program.fbs | 4 +- .../StableHLOToTTIRPatterns.cpp | 12 +- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 2 +- lib/Dialect/TTNN/Transforms/TTNNLayout.cpp | 28 ++-- .../Workarounds/TTNNWorkarounds.cpp | 2 +- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 24 +++- runtime/lib/ttnn/operations/CMakeLists.txt | 2 + .../lib/ttnn/operations/ccl/all_gather.cpp | 18 ++- .../lib/ttnn/operations/ccl/mesh_shard.cpp | 131 ++++++++++++++++++ runtime/lib/ttnn/operations/ccl/mesh_shard.h | 15 ++ .../ttnn/operations/ccl/reduce_scatter.cpp | 40 ++++++ .../lib/ttnn/operations/ccl/reduce_scatter.h | 16 +++ .../ttnn/operations/context/get_device.cpp | 12 +- .../lib/ttnn/operations/layout/to_device.cpp | 1 - runtime/lib/ttnn/program.cpp | 8 ++ test/lit.cfg.py | 3 + test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir | 4 +- test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir | 12 -- test/ttmlir/Silicon/TTNN/ccl/ccl_x2.mlir | 31 +++++ test/ttmlir/Silicon/TTNN/ccl/ccl_x8.mlir | 31 +++++ test/ttmlir/Silicon/TTNN/multi_device.mlir | 27 ++-- third_party/CMakeLists.txt | 3 + 25 files changed, 385 insertions(+), 57 deletions(-) create mode 100644 runtime/lib/ttnn/operations/ccl/mesh_shard.cpp create mode 100644 runtime/lib/ttnn/operations/ccl/mesh_shard.h create mode 100644 runtime/lib/ttnn/operations/ccl/reduce_scatter.cpp create mode 100644 runtime/lib/ttnn/operations/ccl/reduce_scatter.h delete mode 100644 test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir create mode 100644 test/ttmlir/Silicon/TTNN/ccl/ccl_x2.mlir create mode 100644 test/ttmlir/Silicon/TTNN/ccl/ccl_x8.mlir diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index a65f95c6b..3e8e71de8 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions // Option to enable/disable the workaround pass. // - Option layouotWorkaroundsEnabled{ + Option layoutWorkaroundsEnabled{ *this, "enable-layout-workaround-pass", llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)}; diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index a5f83290d..8476964d6 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> { }]; let options = [ - Option<"layouotWorkaroundsEnabled", + Option<"layoutWorkaroundsEnabled", "ttnn-enable-layout-workaround-pass", "bool", /*default=*/"true", "TTNN Layout Workarounds Pass">, diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 3e7ed425f..d503e4df3 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -74,6 +74,18 @@ enum BufferType: ushort { Trace, } +enum MeshShardDirection: uint32 { + FullToShardShape, + ShardToFullShape, +} + +enum MeshShardType: uint32 { + Manual, + Replicate, + Maximal, + Devices, +} + // TODO (#620): Add other fields like core_ranges, shard orientation etc. table ShardSpec { shard_shape: [int64]; diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index b56cdb39a..0838c629d 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -322,8 +322,8 @@ table MeshShardOp { in: tt.target.TensorRef; out: tt.target.TensorRef; device: tt.target.DeviceRef; - shard_direction: uint32; - shard_type: uint32; + shard_direction: tt.target.MeshShardDirection; + shard_type: tt.target.MeshShardType; shard_shape: [int64]; } diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 8a5a92e31..40574741d 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1058,16 +1058,16 @@ class StableHLOToTTIRAllReduceOpConversionPattern } } - // Algorithm here is to search for the first non-one working dimension + // Algorithm: search for first non-one working dimension from back auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape(); - size_t dim = 0; - for (auto s : replicaGroupsShape) { - if (s != 1) { + size_t dim = replicaGroupsShape.size() - 1; + for (auto s = replicaGroupsShape.rbegin(); s != replicaGroupsShape.rend(); + ++s, --dim) { + if (*s != 1) { break; } - ++dim; } - if (dim > replicaGroupsShape.size()) { + if (dim < 0) { // all one shape, then select the fastest dim dim = replicaGroupsShape.size(); } diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index f1ec29999..0551ca42f 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses( void createTTNNPipelineWorkaroundPass( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { TTNNWorkaroundsOptions workaroundOptions{ - options.layouotWorkaroundsEnabled, + options.layoutWorkaroundsEnabled, options.decompositionWorkaroundsEnabled}; pm.addPass(createTTNNWorkarounds(workaroundOptions)); pm.addPass(mlir::createCanonicalizerPass()); diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index e148b575f..c9e9c0481 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -274,8 +274,9 @@ class TTNNLayoutDPSOperandsRewriter LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const final { - // To layout op is a special case, we don't want to rewrite it - if (mlir::isa(op.getOperation())) { + // To layout and mesh_shard op are special cases not to rewrite them + if (mlir::isa(op.getOperation()) || + mlir::isa(op.getOperation())) { return failure(); } @@ -330,15 +331,15 @@ class TTNNLayoutDPSOperandsRewriter } }; -// Updates the layout of the operands of a func::ReturnOp. -// The intent is to move the result to host. -class TTNNLayoutFuncReturnRewriter - : public OpRewritePattern { +// Updates the layout of the operands of the SrcOp such that +// the operands reside in host memory. +template +class TTNNLayoutForceSystemMemoryRewriter : public OpRewritePattern { public: - TTNNLayoutFuncReturnRewriter(MLIRContext *ctx) - : OpRewritePattern(ctx) {} + TTNNLayoutForceSystemMemoryRewriter(MLIRContext *ctx) + : OpRewritePattern(ctx) {} - LogicalResult matchAndRewrite(mlir::func::ReturnOp op, + LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { bool modified = false; for (OpOperand &operand : op->getOpOperands()) { @@ -355,8 +356,6 @@ class TTNNLayoutFuncReturnRewriter } return modified ? success() : failure(); } - -private: }; class TTNNLayout : public impl::TTNNLayoutBase { @@ -387,9 +386,12 @@ class TTNNLayout : public impl::TTNNLayoutBase { // and rewrites its operands and result to have the correct layout // with respect to operand constraints. patterns.add(&getContext()); - // Takes func::Return op and sets layout which will + // Takes func::Return and ttir::MeshShard ops and set layout which will // move it's operands to host - patterns.add(&getContext()); + patterns.add>( + &getContext()); + patterns.add>( + &getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); GreedyRewriteConfig config = GreedyRewriteConfig(); config.useTopDownTraversal = true; diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp index 74d527c42..a51166537 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp @@ -435,7 +435,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { runRewritePatterns(std::move(patterns), GreedyRewriteConfig::kNoLimit /*maxIterations*/); } - if (layouotWorkaroundsEnabled) { + if (layoutWorkaroundsEnabled) { RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 055566c24..170944311 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -499,11 +499,31 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) { auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); auto device = getOperandThroughDPSOps(op.getDevice()); + const mlir::tt::MeshShardDirection shardDirection = op.getShardDirection(); + const mlir::tt::MeshShardType shardType = op.getShardType(); llvm::ArrayRef shardShape = op.getShardShape().getShape(); + + ::tt::target::MeshShardDirection meshShardDirection; + if (shardDirection == mlir::tt::MeshShardDirection::FullToShard) { + meshShardDirection = ::tt::target::MeshShardDirection::FullToShardShape; + } else if (shardDirection == mlir::tt::MeshShardDirection::ShardToFull) { + meshShardDirection = ::tt::target::MeshShardDirection::ShardToFullShape; + } else { + llvm_unreachable("unhandled mesh_shard direction"); + } + + ::tt::target::MeshShardType meshShardType; + if (shardType == mlir::tt::MeshShardType::Replicate) { + meshShardType = ::tt::target::MeshShardType::Replicate; + } else if (shardType == mlir::tt::MeshShardType::Devices) { + meshShardType = ::tt::target::MeshShardType::Devices; + } else { + llvm_unreachable("unhandled mesh_shard type"); + } + return ::tt::target::ttnn::CreateMeshShardOp( *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device), - static_cast(op.getShardDirection()), - static_cast(op.getShardType()), + meshShardDirection, meshShardType, cache.fbb->CreateVector(shardShape)); } diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index fa5cd3c06..d65c01b21 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -4,6 +4,8 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/mesh_shard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.cpp b/runtime/lib/ttnn/operations/ccl/all_gather.cpp index 8c9e7e00c..816d7bbdc 100644 --- a/runtime/lib/ttnn/operations/ccl/all_gather.cpp +++ b/runtime/lib/ttnn/operations/ccl/all_gather.cpp @@ -3,20 +3,30 @@ // SPDX-License-Identifier: Apache-2.0 #include "operations/ccl/all_gather.h" +#include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" +#include "ttnn/operations/ccl/ccl_host_types.hpp" namespace tt::runtime::ttnn::operations::ccl { + void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id()); - int32_t dim = op->dim(); - int32_t num_links = op->num_links(); + int32_t gatherDim = op->dim(); + int32_t numLinks = op->num_links(); + LOG_ASSERT( + input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE, + "Input of all_gather must be MULTIDEVICE. id:", op->in()->global_id()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); - ::ttnn::Tensor out = - ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig); + ::ttnn::MeshDevice &meshDevice = + context.getSubMesh(op->device()->global_id()); + ::ttnn::Tensor out = ::ttnn::all_gather( + input, gatherDim, 1, meshDevice, numLinks, outputMemoryConfig, + std::nullopt, std::nullopt, ::ttnn::ccl::Topology::Linear); tensorPool.insert_or_assign(op->out()->global_id(), out); } + } // namespace tt::runtime::ttnn::operations::ccl diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp new file mode 100644 index 000000000..3eeba0f90 --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mesh_shard.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include "ttnn/distributed/distributed_tensor.hpp" +#include "ttnn/tensor/xtensor/partition.hpp" + +namespace tt::runtime::ttnn::operations::ccl { + +void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, + ::ttnn::MeshDevice &meshDevice, + const ::tt::target::MeshShardType &shardType, + const std::vector &shardShape) { + if (shardType == ::tt::target::MeshShardType::Replicate) { + out = ::ttnn::distributed::distribute_tensor( + input, meshDevice, + *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice)); + } else { + LOG_ASSERT( + input.get_shape().rank() > 1, + "Sharding requires higher than 2 dimensional tensor. Tensor rank=", + input.get_shape().rank()); + auto rowMesh = static_cast(shardShape[0]); + auto colMesh = static_cast(shardShape[1]); + int lastDim = input.get_shape().rank() - 1; + LOG_ASSERT((rowMesh * colMesh) > 1, + "Sharding requires higher than 1 mesh. shardShape ", rowMesh, + colMesh); + + ::ttnn::distributed::Shard2dConfig shard2dConfig; + // last tile replicate + if (colMesh == 1) { + if (rowMesh == meshDevice.num_rows()) { + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ + .row_dim = (lastDim - 1), .col_dim = std::nullopt}; + } else { + // transpose + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ + .row_dim = std::nullopt, .col_dim = (lastDim - 1)}; + } + } else { + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ + .row_dim = (lastDim - 1), .col_dim = lastDim}; + } + + out = ::ttnn::distributed::distribute_tensor( + input, meshDevice, + *::ttnn::distributed::shard_tensor_to_2d_mesh_mapper( + meshDevice, meshDevice.shape(), shard2dConfig)); + } +} + +void ShardToFullShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, + ::ttnn::MeshDevice &meshDevice, + const ::tt::target::MeshShardType &shardType, + const std::vector &shardShape) { + std::vector<::ttnn::Tensor> input_tensors = + ::ttnn::distributed::get_tensors_from_multi_device_storage(input); + if (shardType == ::tt::target::MeshShardType::Replicate) { + out = input_tensors[0]; + } else { + auto rowMesh = static_cast(shardShape[0]); + auto colMesh = static_cast(shardShape[1]); + int lastDim = input.get_shape().rank() - 1; + if ((rowMesh * colMesh) == + (meshDevice.num_rows() * meshDevice.num_cols())) { + // Full multi-device storage concatenation + if (shardShape[0] == 1 || shardShape[1] == 1) { + out = ::ttnn::distributed::aggregate_tensor( + input, *::ttnn::distributed::concat_mesh_to_tensor_composer( + (shardShape[1] == 1 ? (lastDim - 1) : lastDim))); + } else { + out = ::ttnn::distributed::aggregate_tensor( + input, *::ttnn::distributed::concat_2d_mesh_to_tensor_composer( + meshDevice, ::ttnn::distributed::Concat2dConfig{ + .row_dim = static_cast(lastDim - 1), + .col_dim = static_cast(lastDim)})); + } + } else { + // Partial multi-device storage concatenation + // Current ttnn api does not support partial multi-device storage + // concatenation. Thus, xtensor APIs are being called directly from here. + std::vector<::ttnn::Tensor> target_tensors; + bool transpose = (rowMesh != meshDevice.num_rows()); + size_t iteration = (transpose) ? colMesh : rowMesh; + size_t stride = + (transpose) ? meshDevice.num_rows() : meshDevice.num_cols(); + for (size_t i = 0; i < iteration; ++i) { + target_tensors.push_back(input_tensors[i * stride]); + } + out = ::ttnn::experimental::xtensor::concat(target_tensors, lastDim - 1); + } + } +} + +void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id()); + const ::tt::target::MeshShardDirection shardDirection = op->shard_direction(); + const ::tt::target::MeshShardType shardType = op->shard_type(); + const auto *fbShardShape = op->shard_shape(); + std::vector shardShape(fbShardShape->begin(), fbShardShape->end()); + + if (shardDirection != ::tt::target::MeshShardDirection::FullToShardShape && + shardDirection != ::tt::target::MeshShardDirection::ShardToFullShape) { + throw std::runtime_error("Unsupported shard direction"); + } + + if (shardType != ::tt::target::MeshShardType::Replicate && + shardType != ::tt::target::MeshShardType::Devices) { + throw std::runtime_error("Unsupported shard type"); + } + + ::ttnn::MeshDevice &meshDevice = + context.getSubMesh(op->device()->global_id()); + + ::ttnn::Tensor out; + if (shardDirection == ::tt::target::MeshShardDirection::FullToShardShape) { + FullToShardShape(input, out, meshDevice, shardType, shardShape); + } else { + ShardToFullShape(input, out, meshDevice, shardType, shardShape); + } + tensorPool.insert_or_assign(op->out()->global_id(), out); +} + +} // namespace tt::runtime::ttnn::operations::ccl diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.h b/runtime/lib/ttnn/operations/ccl/mesh_shard.h new file mode 100644 index 000000000..840c2bf85 --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.h @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_MESH_SHARD_H +#define TTNN_RUNTIME_MESH_SHARD_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::ccl { +void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::ccl + +#endif diff --git a/runtime/lib/ttnn/operations/ccl/reduce_scatter.cpp b/runtime/lib/ttnn/operations/ccl/reduce_scatter.cpp new file mode 100644 index 000000000..236434e19 --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/reduce_scatter.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "reduce_scatter.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include "ttnn/operations/ccl/ccl_host_types.hpp" +#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp" + +namespace tt::runtime::ttnn::operations::ccl { + +void run(const ::tt::target::ttnn::ReduceScatterOp *op, + ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id()); + int32_t scatterSplitDim = op->scatter_split_dim(); + int32_t numLinks = op->num_links(); + auto mathOp = + static_cast<::ttnn::operations::reduction::ReduceType>(op->math_op()); + // Reduction in horizontal direction (x-dimension) in linear computation + // config: e.g., For 2x4 mesh, clusterAxis (1) means reduction in horizontal + // direction such as 0,1,2,3 and 4,5,6,7. + int32_t clusterAxis = 1; + LOG_ASSERT(input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE, + "Input of reduce_scatter must be MULTIDEVICE. id:", + op->in()->global_id()); + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); + ::ttnn::MeshDevice &meshDevice = + context.getSubMesh(op->device()->global_id()); + ::ttnn::Tensor out = ::ttnn::reduce_scatter( + input, scatterSplitDim, clusterAxis, meshDevice, mathOp, numLinks, + outputMemoryConfig, ::ttnn::ccl::Topology::Linear); + tensorPool.insert_or_assign(op->out()->global_id(), out); +} + +} // namespace tt::runtime::ttnn::operations::ccl diff --git a/runtime/lib/ttnn/operations/ccl/reduce_scatter.h b/runtime/lib/ttnn/operations/ccl/reduce_scatter.h new file mode 100644 index 000000000..a67e67344 --- /dev/null +++ b/runtime/lib/ttnn/operations/ccl/reduce_scatter.h @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_REDUCE_SCATTER_H +#define TTNN_RUNTIME_REDUCE_SCATTER_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::ccl { +void run(const ::tt::target::ttnn::ReduceScatterOp *op, + ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::ccl + +#endif diff --git a/runtime/lib/ttnn/operations/context/get_device.cpp b/runtime/lib/ttnn/operations/context/get_device.cpp index 213c810a3..6c96125f9 100644 --- a/runtime/lib/ttnn/operations/context/get_device.cpp +++ b/runtime/lib/ttnn/operations/context/get_device.cpp @@ -46,11 +46,17 @@ void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) { const ::flatbuffers::Vector *deviceIds = op->chip_ids(); std::unordered_set desiredDeviceIds(deviceIds->begin(), deviceIds->end()); - LOG_ASSERT( - subMeshShape->y() == 1, - "Expected mesh row = 1 for get device op, got: ", subMeshShape->y()); LOG_ASSERT(desiredDeviceIds.size() == deviceIds->size(), "Duplicate device ids in get device op"); + + // Re-map mesh if subMeshShape cannot be a submesh of current shape + MeshShape meshShape = meshDevice.shape(); + if (subMeshShape->y() > static_cast(meshShape.num_rows) || + subMeshShape->x() > static_cast(meshShape.num_cols)) { + meshDevice.reshape(MeshShape(subMeshShape->y(), subMeshShape->x())); + LOG_INFO("remapped mesh device shape [", meshDevice.num_rows(), ", ", + meshDevice.num_cols(), "]"); + } std::shared_ptr<::ttnn::MeshDevice> subMesh = createSubMesh(meshDevice, desiredDeviceIds, subMeshShape); context.addSubMesh(op->out()->global_id(), subMesh); diff --git a/runtime/lib/ttnn/operations/layout/to_device.cpp b/runtime/lib/ttnn/operations/layout/to_device.cpp index c885ea530..cac9ada0e 100644 --- a/runtime/lib/ttnn/operations/layout/to_device.cpp +++ b/runtime/lib/ttnn/operations/layout/to_device.cpp @@ -32,5 +32,4 @@ void run(const ::tt::target::ttnn::ToDeviceOp *op, ProgramContext &context) { targetDevice); tensorPool.insert_or_assign(op->out()->global_id(), out); } - } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index a1176253a..8c47bfb20 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 #include "operations/ccl/all_gather.h" +#include "operations/ccl/mesh_shard.h" +#include "operations/ccl/reduce_scatter.h" #include "operations/context/get_device.h" #include "operations/conv/conv2d.h" #include "operations/creation/arange.h" @@ -228,6 +230,12 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::AllGatherOp: { return operations::ccl::run(op->type_as_AllGatherOp(), context); } + case ::tt::target::ttnn::OpType::ReduceScatterOp: { + return operations::ccl::run(op->type_as_ReduceScatterOp(), context); + } + case ::tt::target::ttnn::OpType::MeshShardOp: { + return operations::ccl::run(op->type_as_MeshShardOp(), context); + } case ::tt::target::ttnn::OpType::ArangeOp: { return operations::creation::run(op->type_as_ArangeOp(), context); } diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 74204a8f3..761ababab 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -26,6 +26,9 @@ def set_system_desc_features(system_desc): config.available_features.add(system_desc["chip_descs"][0]["arch"]) if len(system_desc["chip_desc_indices"]) > 1: config.available_features.add("multi-chip") + config.available_features.add( + "multi-chip-x" + str(len(system_desc["chip_desc_indices"])) + ) # name: The name of this test suite. diff --git a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir index 2f488b6c2..78455da0e 100644 --- a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir +++ b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir @@ -3,7 +3,9 @@ module attributes {} { func.func @forward(%arg0: tensor<8192x784xf32>) -> tensor<4096x196xf32> { %0 = tensor.empty() : tensor<4096x196xf32> %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<4096x196xf32>) -> tensor<4096x196xf32> - // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] return %1 : tensor<4096x196xf32> } } + +// CHECK: %[[C:.*]] = "ttnn.get_device"[[C:.*]] +// CHECK-NEXT: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir b/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir deleted file mode 100644 index 9e5972e13..000000000 --- a/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=4,1,1" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -// UNSUPPORTED: true -// REQUIRES: multi-chip -func.func @forward(%arg0: tensor<1x1x32x32xf32>) -> tensor<1x1x32x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x1x32x128xf32> - // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] - %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32}> : (tensor<1x1x32x32xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> - return %1 : tensor<1x1x32x128xf32> -} diff --git a/test/ttmlir/Silicon/TTNN/ccl/ccl_x2.mlir b/test/ttmlir/Silicon/TTNN/ccl/ccl_x2.mlir new file mode 100644 index 000000000..bb29cdd49 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/ccl/ccl_x2.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=1,2" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// REQUIRES: multi-chip-x2 + +func.func @forward(%arg0: tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> { + %0 = tensor.empty() : tensor<1x1x32x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1x2>, shard_type = #tt.shard_type}> : (tensor<1x1x32x128xf32>, tensor<1x1x32x64xf32>) -> tensor<1x1x32x64xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %2 = tensor.empty() : tensor<1x1x32x128xf32> + %3 = "ttir.all_gather"(%1, %2) <{dim = 3 : si32}> : (tensor<1x1x32x64xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> + // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] + %4 = tensor.empty() : tensor<1x1x32x128xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1>, shard_type = #tt.shard_type}> : (tensor<1x1x32x128xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + return %5 : tensor<1x1x32x128xf32> +} + +func.func @forward2(%arg0: tensor<1x1x32x128xf32>) -> tensor<1x1x32x64xf32> { + %0 = tensor.empty() : tensor<1x1x32x64xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1x2>, shard_type = #tt.shard_type}> : (tensor<1x1x32x128xf32>, tensor<1x1x32x64xf32>) -> tensor<1x1x32x64xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %2 = tensor.empty() : tensor<1x1x32x64xf32> + %3 = "ttir.all_reduce"(%1, %2) <{channel_handle = 1 : si32, dim = 3 : si32, reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<1x1x32x64xf32>, tensor<1x1x32x64xf32>) -> tensor<1x1x32x64xf32> + // CHECK: %[[C:.*]] = "ttnn.reduce_scatter"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] + %4 = tensor.empty() : tensor<1x1x32x64xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1>, shard_type = #tt.shard_type}> : (tensor<1x1x32x64xf32>, tensor<1x1x32x64xf32>) -> tensor<1x1x32x64xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + return %5 : tensor<1x1x32x64xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/ccl/ccl_x8.mlir b/test/ttmlir/Silicon/TTNN/ccl/ccl_x8.mlir new file mode 100644 index 000000000..999d922e5 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/ccl/ccl_x8.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=2,4" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// REQUIRES: multi-chip-x8 + +func.func @forward(%arg0: tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> { + %0 = tensor.empty() : tensor<1x1x128x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type}> : (tensor<1x1x256x512xf32>, tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %2 = tensor.empty() : tensor<1x1x128x512xf32> + %3 = "ttir.all_gather"(%1, %2) <{dim = 3 : si32}> : (tensor<1x1x128x128xf32>, tensor<1x1x128x512xf32>) -> tensor<1x1x128x512xf32> + // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] + %4 = tensor.empty() : tensor<1x1x256x512xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x1>, shard_type = #tt.shard_type}> : (tensor<1x1x128x512xf32>, tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + return %5 : tensor<1x1x256x512xf32> +} + +func.func @forward2(%arg0: tensor<1x1x256x512xf32>) -> tensor<1x1x256x128xf32> { + %0 = tensor.empty() : tensor<1x1x128x128xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type}> : (tensor<1x1x256x512xf32>, tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %2 = tensor.empty() : tensor<1x1x128x128xf32> + %3 = "ttir.all_reduce"(%1, %2) <{channel_handle = 1 : si32, dim = 3 : si32, reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<1x1x128x128xf32>, tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> + // CHECK: %[[C:.*]] = "ttnn.reduce_scatter"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] + %4 = tensor.empty() : tensor<1x1x256x128xf32> + %5 = "ttir.mesh_shard"(%3, %4) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x1>, shard_type = #tt.shard_type}> : (tensor<1x1x128x128xf32>, tensor<1x1x256x128xf32>) -> tensor<1x1x256x128xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + return %5 : tensor<1x1x256x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/multi_device.mlir b/test/ttmlir/Silicon/TTNN/multi_device.mlir index c927c0d2b..c7b3fcdb0 100644 --- a/test/ttmlir/Silicon/TTNN/multi_device.mlir +++ b/test/ttmlir/Silicon/TTNN/multi_device.mlir @@ -1,12 +1,21 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=2,1,1" %s > %t.mlir +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=1,2" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -// UNSUPPORTED: true -// REQUIRES: multi-chip -func.func @multiply(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<8x64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32> - return %1 : tensor<8x64x128xf32> +// REQUIRES: multi-chip-x2 + +func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = tensor.empty() : tensor<8192x392xf32> + %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1x2>, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<8192x392xf32>) -> tensor<8192x392xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %2 = tensor.empty() : tensor<392x16384xf32> + %3 = "ttir.mesh_shard"(%arg1, %2) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x1>, shard_type = #tt.shard_type}> : (tensor<784x16384xf32>, tensor<392x16384xf32>) -> tensor<392x16384xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + %4 = tensor.empty() : tensor<8192x16384xf32> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<8192x392xf32>, tensor<392x16384xf32>, tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %6 = tensor.empty() : tensor<8192x16384xf32> + %7 = "ttir.all_reduce"(%5, %6) <{channel_handle = 1 : si32, dim = 1 : si32, reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<8192x16384xf32>, tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %8 = tensor.empty() : tensor<8192x16384xf32> + %9 = "ttir.mesh_shard"(%7, %8) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<1>, shard_type = #tt.shard_type}> : (tensor<8192x16384xf32>, tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]] + return %9 : tensor<8192x16384xf32> } diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 0fe2eef79..e481f2d72 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -38,6 +38,9 @@ set(TTMETAL_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/magic_enum/4d76fe0a5b27a0e62d6c15976d02b33c54207096/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/boost_core/e679bef5c160cf29d0f37d549881dc5f5a58c332/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/json/230202b6f5267cbf0c8e5a2f17301964d95f83ff/include + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/xtensor/4a957e26c765b48cbec4a4235fe9e518d5a85d3d/include + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/xtensor-blas/190c3a4314355b67291a7d78b20a2100de3f8f54/include + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/xtl/0918808959d33a292c551b9f014a0e808bc4a95c/include PARENT_SCOPE )