|
| 1 | +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +#include "mesh_shard.h" |
| 6 | +#include "tt/runtime/detail/logger.h" |
| 7 | +#include "tt/runtime/detail/ttnn.h" |
| 8 | +#include "tt/runtime/ttnn/operations/utils.h" |
| 9 | +#include "tt/runtime/ttnn/utils.h" |
| 10 | +#include "ttnn/distributed/distributed_tensor.hpp" |
| 11 | +#include "ttnn/tensor/xtensor/partition.hpp" |
| 12 | + |
| 13 | +namespace tt::runtime::ttnn::operations::ccl { |
| 14 | + |
| 15 | +void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, |
| 16 | + ::ttnn::MeshDevice &meshDevice, |
| 17 | + const ::tt::target::MeshShardType &shardType, |
| 18 | + const std::vector<int64_t> &shardShape) { |
| 19 | + if (shardType == ::tt::target::MeshShardType::Replicate) { |
| 20 | + out = ::ttnn::distributed::distribute_tensor( |
| 21 | + input, meshDevice, |
| 22 | + *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice)); |
| 23 | + } else { |
| 24 | + LOG_ASSERT( |
| 25 | + input.get_shape().rank() > 1, |
| 26 | + "Sharding requires higher than 2 dimensional tensor. Tensor rank=", |
| 27 | + input.get_shape().rank()); |
| 28 | + auto rowMesh = static_cast<size_t>(shardShape[0]); |
| 29 | + auto colMesh = static_cast<size_t>(shardShape[1]); |
| 30 | + int lastDim = input.get_shape().rank() - 1; |
| 31 | + LOG_ASSERT((rowMesh * colMesh) > 1, |
| 32 | + "Sharding requires higher than 1 mesh. shardShape ", rowMesh, |
| 33 | + colMesh); |
| 34 | + |
| 35 | + ::ttnn::distributed::Shard2dConfig shard2dConfig; |
| 36 | + // last tile replicate |
| 37 | + if (colMesh == 1) { |
| 38 | + if (rowMesh == meshDevice.num_rows()) { |
| 39 | + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ |
| 40 | + .row_dim = (lastDim - 1), .col_dim = std::nullopt}; |
| 41 | + } else { |
| 42 | + // transpose |
| 43 | + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ |
| 44 | + .row_dim = std::nullopt, .col_dim = (lastDim - 1)}; |
| 45 | + } |
| 46 | + } else { |
| 47 | + shard2dConfig = ::ttnn::distributed::Shard2dConfig{ |
| 48 | + .row_dim = (lastDim - 1), .col_dim = lastDim}; |
| 49 | + } |
| 50 | + |
| 51 | + out = ::ttnn::distributed::distribute_tensor( |
| 52 | + input, meshDevice, |
| 53 | + *::ttnn::distributed::shard_tensor_to_2d_mesh_mapper( |
| 54 | + meshDevice, meshDevice.shape(), shard2dConfig)); |
| 55 | + } |
| 56 | +} |
| 57 | + |
// Reassemble a multi-device (sharded or replicated) tensor back into a single
// full tensor in `out`.
//
// shardType == Replicate: every device holds a full copy, so any one shard
// (the first) is the result.
// shardType == Devices:   shardShape = {rowMesh, colMesh}. When the shard
// mesh covers every device, the ttnn composer APIs concatenate directly;
// otherwise only a subset of the per-device shards is stitched together via
// the xtensor concat helper (see the in-code note below).
void ShardToFullShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out,
                      ::ttnn::MeshDevice &meshDevice,
                      const ::tt::target::MeshShardType &shardType,
                      const std::vector<int64_t> &shardShape) {
  std::vector<::ttnn::Tensor> input_tensors =
      ::ttnn::distributed::get_tensors_from_multi_device_storage(input);
  if (shardType == ::tt::target::MeshShardType::Replicate) {
    // All shards are identical copies; take the first.
    out = input_tensors[0];
  } else {
    auto rowMesh = static_cast<size_t>(shardShape[0]);
    auto colMesh = static_cast<size_t>(shardShape[1]);
    int lastDim = input.get_shape().rank() - 1;
    if ((rowMesh * colMesh) ==
        (meshDevice.num_rows() * meshDevice.num_cols())) {
      // Full multi-device storage concatenation
      if (shardShape[0] == 1 || shardShape[1] == 1) {
        // 1D sharding: concat along the one sharded dimension. If colMesh is
        // 1 the shards were split along the second-to-last dim, otherwise
        // along the last dim.
        out = ::ttnn::distributed::aggregate_tensor(
            input, *::ttnn::distributed::concat_mesh_to_tensor_composer(
                       (shardShape[1] == 1 ? (lastDim - 1) : lastDim)));
      } else {
        // 2D sharding: rows concat along lastDim-1, columns along lastDim
        // (the inverse of the Shard2dConfig used when sharding).
        out = ::ttnn::distributed::aggregate_tensor(
            input, *::ttnn::distributed::concat_2d_mesh_to_tensor_composer(
                       meshDevice, ::ttnn::distributed::Concat2dConfig{
                                       .row_dim = static_cast<int>(lastDim - 1),
                                       .col_dim = static_cast<int>(lastDim)}));
      }
    } else {
      // Partial multi-device storage concatenation
      // Current ttnn api does not support partial multi-device storage
      // concatenation. Thus, xtensor APIs are being called directly from here.
      std::vector<::ttnn::Tensor> target_tensors;
      // If the shard's row extent doesn't match the physical mesh rows, the
      // sharding was laid out transposed (see FullToShardShape), so walk the
      // device list along the other mesh axis.
      bool transpose = (rowMesh != meshDevice.num_rows());
      size_t iteration = (transpose) ? colMesh : rowMesh;
      // Stride over the flattened (row-major) device list to pick one shard
      // per populated mesh position.
      size_t stride =
          (transpose) ? meshDevice.num_rows() : meshDevice.num_cols();
      for (size_t i = 0; i < iteration; ++i) {
        target_tensors.push_back(input_tensors[i * stride]);
      }
      out = ::ttnn::experimental::xtensor::concat(target_tensors, lastDim - 1);
    }
  }
}
| 100 | + |
| 101 | +void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) { |
| 102 | + ProgramTensorPool &tensorPool = context.getTensorPool(); |
| 103 | + const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id()); |
| 104 | + const ::tt::target::MeshShardDirection shardDirection = op->shard_direction(); |
| 105 | + const ::tt::target::MeshShardType shardType = op->shard_type(); |
| 106 | + const auto *fbShardShape = op->shard_shape(); |
| 107 | + std::vector<int64_t> shardShape(fbShardShape->begin(), fbShardShape->end()); |
| 108 | + |
| 109 | + if (shardDirection != ::tt::target::MeshShardDirection::FullToShardShape && |
| 110 | + shardDirection != ::tt::target::MeshShardDirection::ShardToFullShape) { |
| 111 | + throw std::runtime_error("Unsupported shard direction"); |
| 112 | + } |
| 113 | + |
| 114 | + if (shardType != ::tt::target::MeshShardType::Replicate && |
| 115 | + shardType != ::tt::target::MeshShardType::Devices) { |
| 116 | + throw std::runtime_error("Unsupported shard type"); |
| 117 | + } |
| 118 | + |
| 119 | + ::ttnn::MeshDevice &meshDevice = |
| 120 | + context.getSubMesh(op->device()->global_id()); |
| 121 | + |
| 122 | + ::ttnn::Tensor out; |
| 123 | + if (shardDirection == ::tt::target::MeshShardDirection::FullToShardShape) { |
| 124 | + FullToShardShape(input, out, meshDevice, shardType, shardShape); |
| 125 | + } else { |
| 126 | + ShardToFullShape(input, out, meshDevice, shardType, shardShape); |
| 127 | + } |
| 128 | + tensorPool.insert_or_assign(op->out()->global_id(), out); |
| 129 | +} |
| 130 | + |
| 131 | +} // namespace tt::runtime::ttnn::operations::ccl |
0 commit comments