diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp index 71fa5b8438..96e16a57ac 100644 --- a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp @@ -17,8 +17,9 @@ void FullToShardShape(const ::ttnn::Tensor &input, ::ttnn::Tensor &out, const std::vector<int64_t> &shardShape) { if (shardType == ::tt::target::MeshShardType::Replicate) { out = ::ttnn::distributed::distribute_tensor( - input, meshDevice, - *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice)); + input, + *::ttnn::distributed::replicate_tensor_to_mesh_mapper(meshDevice), + meshDevice); } else { LOG_ASSERT( input.get_shape().rank() > 1,