diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp index 96e16a57ac..87c38c7749 100644 --- a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp @@ -49,9 +49,10 @@ void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, } out = ::ttnn::distributed::distribute_tensor( - input, meshDevice, + input, *::ttnn::distributed::shard_tensor_to_2d_mesh_mapper( - meshDevice, meshDevice.shape(), shard2dConfig)); + meshDevice, meshDevice.shape(), shard2dConfig), + meshDevice); } }