diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
index dba53e2c5ac6..bffcc4ca0d11 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -26,7 +26,7 @@ inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
-Tensor BatchNorm::invoke(
+std::vector<std::optional<Tensor>> BatchNorm::invoke(
     const Tensor& input,
     std::optional<Tensor> running_mean,
     std::optional<Tensor> running_var,
@@ -36,6 +36,8 @@ Tensor BatchNorm::invoke(
     std::optional<Tensor> weight,
     std::optional<Tensor> bias,
     std::optional<Tensor> output,
+    std::optional<Tensor> updated_running_mean,
+    std::optional<Tensor> updated_running_var,
     const std::optional<MemoryConfig>& memory_config) {
     if (training) {
         Tensor batch_mean = mean_NHW(input, memory_config);
@@ -54,6 +56,8 @@ Tensor BatchNorm::invoke(
             running_mean,
             running_var,
             output,
+            updated_running_mean,
+            updated_running_var,
             memory_config);
     }
     TT_FATAL(
@@ -71,6 +75,28 @@ Tensor BatchNorm::invoke(
         std::nullopt,
         std::nullopt,
         output,
+        std::nullopt,
+        std::nullopt,
         memory_config);
 }
+
+OptionalTensors BatchNorm::create_async_optional_output_tensors(
+    const Tensor& input,
+    std::optional<Tensor> running_mean,
+    std::optional<Tensor> running_var,
+    const bool training,
+    const float eps,
+    const float momentum,
+    std::optional<Tensor> weight,
+    std::optional<Tensor> bias,
+    std::optional<Tensor> output,
+    std::optional<Tensor> updated_running_mean,
+    std::optional<Tensor> updated_running_var,
+    const std::optional<MemoryConfig>& memory_config) {
+    return {
+        std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})),
+        training ? std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})) : std::nullopt,
+        training ? std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})) : std::nullopt};
+}
+
 }  // namespace ttnn::operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
index 5e17388b67d8..fbbd3aebf076 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
@@ -9,7 +9,7 @@ namespace ttnn {
 namespace operations::normalization {

 struct BatchNorm {
-    static Tensor invoke(
+    static std::vector<std::optional<Tensor>> invoke(
         const Tensor& input,
         std::optional<Tensor> running_mean = std::nullopt,
         std::optional<Tensor> running_var = std::nullopt,
@@ -19,6 +19,22 @@ struct BatchNorm {
         std::optional<Tensor> weight = std::nullopt,
         std::optional<Tensor> bias = std::nullopt,
         std::optional<Tensor> output = std::nullopt,
+        std::optional<Tensor> updated_running_mean = std::nullopt,
+        std::optional<Tensor> updated_running_var = std::nullopt,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
+    static OptionalTensors create_async_optional_output_tensors(
+        const Tensor& input,
+        std::optional<Tensor> running_mean = std::nullopt,
+        std::optional<Tensor> running_var = std::nullopt,
+        const bool training = false,
+        const float eps = 1e-05,
+        const float momentum = 0.1,
+        std::optional<Tensor> weight = std::nullopt,
+        std::optional<Tensor> bias = std::nullopt,
+        std::optional<Tensor> output = std::nullopt,
+        std::optional<Tensor> updated_running_mean = std::nullopt,
+        std::optional<Tensor> updated_running_var = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt);
 };
 }  // namespace operations::normalization
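The host API above now returns three optional tensors instead of one. A minimal call-site sketch of the new contract (tensor names are illustrative, not from the patch; training mode shown):

```cpp
// Slot 0: normalized output. Slots 1 and 2: updated running mean/var, which
// create_async_optional_output_tensors leaves as std::nullopt when
// training == false.
std::vector<std::optional<Tensor>> results = BatchNorm::invoke(
    input, running_mean, running_var,
    /*training=*/true, /*eps=*/1e-05f, /*momentum=*/0.1f,
    weight, bias);
const Tensor& output = results[0].value();
const Tensor& new_running_mean = results[1].value();
const Tensor& new_running_var = results[2].value();
```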
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
index dd1beebfeea8..41479d7a3179 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
@@ -29,11 +29,14 @@ void bind_batch_norm_operation(pybind11::module& module) {
             weight (ttnn.Tensor, optional): the weight or gamma value. Defaults to `None`.
             bias (ttnn.Tensor, optional): the bias or beta value. Defaults to `None`.
             training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
+            output (ttnn.Tensor, optional): Preallocated output tensor to store the batch norm result. Defaults to `None`.
+            updated_running_mean (ttnn.Tensor, optional): Preallocated output tensor to store the updated running mean value in training mode. Defaults to `None`.
+            updated_running_var (ttnn.Tensor, optional): Preallocated output tensor to store the updated running variance value in training mode. Defaults to `None`.
             memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.

         Returns:
-            ttnn.Tensor: the output tensor.
+            List of ttnn.Tensor: the output tensor, followed by the updated running mean and the updated running variance (populated in training mode only).

         )doc",
@@ -48,6 +51,8 @@ void bind_batch_norm_operation(pybind11::module& module) {
             py::arg("weight") = std::nullopt,
             py::arg("bias") = std::nullopt,
             py::arg("output") = std::nullopt,
+            py::arg("updated_running_mean") = std::nullopt,
+            py::arg("updated_running_var") = std::nullopt,
             py::arg("memory_config") = std::nullopt});
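For context on what `updated_running_mean` and `updated_running_var` receive: running statistics follow the usual momentum update. The device kernels that perform the per-channel arithmetic are not part of this patch, so the sketch below states the conventional formula (PyTorch convention, which the `momentum = 0.1` default suggests) rather than quoting the implementation:

```cpp
// Conventional running-stats update controlled by `momentum`.
// Illustrative scalar helper only; the op applies this per channel on device.
inline float update_running_stat(float running, float batch_stat, float momentum = 0.1f) {
    return (1.0f - momentum) * running + momentum * batch_stat;
}
```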
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
index dd3874b5ac58..3144d6bf4f9f 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
@@ -10,7 +10,8 @@ namespace ttnn::operations::normalization {
 void BatchNormOperation::validate_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output, updated_running_mean, updated_running_var] =
+        tensor_args;

     check_tensor(input, "batch_norm", "input");
     check_tensor(batch_mean, "batch_norm", "batch_mean");
@@ -20,6 +21,8 @@ void BatchNormOperation::validate_tensors(
     check_tensor(output, "batch_norm", "output");
     check_tensor(running_mean, "batch_norm", "running_mean");
     check_tensor(running_var, "batch_norm", "running_var");
+    check_tensor(updated_running_mean, "batch_norm", "updated_running_mean");
+    check_tensor(updated_running_var, "batch_norm", "updated_running_var");

     // input (N, C, H, W)
     auto C = input.get_logical_shape()[1];
@@ -67,6 +70,19 @@ void BatchNormOperation::validate_tensors(
             running_var.value().get_logical_shape()[1] == C,
             "running_var_shape[1] must be the same as input's channel size.");
     }
+
+    // updated_running_mean (1, C, 1, 1)
+    if (updated_running_mean.has_value()) {
+        TT_FATAL(
+            updated_running_mean.value().get_logical_shape()[1] == C,
+            "updated_running_mean_shape[1] must be the same as input's channel size.");
+    }
+    // updated_running_var (1, C, 1, 1)
+    if (updated_running_var.has_value()) {
+        TT_FATAL(
+            updated_running_var.value().get_logical_shape()[1] == C,
+            "updated_running_var_shape[1] must be the same as input's channel size.");
+    }
 }

 BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -77,7 +93,8 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
 void BatchNormOperation::validate_on_program_cache_miss(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     // We don't support sharding for now
-    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output, updated_running_mean, updated_running_var] =
+        tensor_args;

     TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized");
     TT_FATAL(
@@ -139,20 +156,72 @@ DataType BatchNormOperation::operation_attributes_t::get_dtype() const {
 BatchNormOperation::spec_return_value_t BatchNormOperation::compute_output_specs(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     using namespace tt::constants;
+    std::vector<std::optional<TensorSpec>> result;
+    result.reserve(3);
+
+    // output shape
     const auto output_shape = tensor_args.input.get_logical_shape();
-    return TensorSpec(
+    result.push_back(TensorSpec(
         output_shape,
-        TensorLayout(operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config));
+        TensorLayout(operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));
+
+    const auto C = output_shape[1];
+    SimpleShape mean_var_shape({1, C, 1, 1});
+
+    // updated running mean
+    if (tensor_args.updated_running_mean.has_value()) {
+        result.push_back(TensorSpec(
+            mean_var_shape,
+            TensorLayout(
+                operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));
+    } else {
+        result.push_back(std::nullopt);
+    }
+
+    // updated running var
+    if (tensor_args.updated_running_var.has_value()) {
+        result.push_back(TensorSpec(
+            mean_var_shape,
+            TensorLayout(
+                operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));
+    } else {
+        result.push_back(std::nullopt);
+    }
+
+    return result;
 }

 BatchNormOperation::tensor_return_value_t BatchNormOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     const auto& output_tensor = tensor_args.output;
+    const auto output_specs = compute_output_specs(operation_attributes, tensor_args);
+    auto device = tensor_args.input.device();
+
+    std::vector<std::optional<Tensor>> result;
+    result.reserve(3);
+
+    // output
     if (output_tensor.has_value()) {
-        return output_tensor.value();
+        result.push_back(output_tensor.value());
+    } else {
+        result.push_back(create_device_tensor(*output_specs[0], device));
+    }
+
+    // updated_running_mean
+    if (tensor_args.updated_running_mean.has_value()) {
+        result.push_back(tensor_args.updated_running_mean.value());
+    } else {
+        result.push_back(create_device_tensor(*output_specs[1], device));
+    }
+
+    // updated_running_var
+    if (tensor_args.updated_running_var.has_value()) {
+        result.push_back(tensor_args.updated_running_var.value());
+    } else {
+        result.push_back(create_device_tensor(*output_specs[2], device));
     }
-    return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
+
+    return result;
 }

 std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tensor_args_t> BatchNormOperation::invoke(
@@ -167,6 +236,8 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tensor_args_t> BatchNormOperation::invoke(
     std::optional<Tensor> running_mean,
     std::optional<Tensor> running_var,
     std::optional<Tensor> output,
+    std::optional<Tensor> updated_running_mean,
+    std::optional<Tensor> updated_running_var,
     const std::optional<MemoryConfig>& memory_config) {
     operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
     tensor_args_t tensor_args{
@@ -177,7 +248,9 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tensor_args_t> BatchNormOperation::invoke(
-        input, batch_mean, batch_var, weight, bias, running_mean, running_var, output};
+        input, batch_mean, batch_var, weight, bias, running_mean, running_var,
+        output, updated_running_mean, updated_running_var};
     return {operation_attributes, tensor_args};
 }
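`create_output_tensors` above applies the same preallocated-or-create rule to each of the three output slots. Condensed as a standalone helper to make the rule explicit (hypothetical; no such helper exists in the patch):

```cpp
// Reuse the caller's preallocated tensor when one was passed in; otherwise
// allocate a fresh device tensor from the computed spec.
template <typename MakeTensorFn>
std::optional<Tensor> fill_output_slot(const std::optional<Tensor>& preallocated, MakeTensorFn make_tensor) {
    return preallocated.has_value() ? preallocated : std::optional<Tensor>(make_tensor());
}
```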
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
@@ ... @@ struct BatchNormOperation {
         std::optional<Tensor> output;
         std::optional<Tensor> running_mean;
         std::optional<Tensor> running_var;
+        std::optional<Tensor> updated_running_mean;
+        std::optional<Tensor> updated_running_var;
     };

-    using spec_return_value_t = TensorSpec;
-    using tensor_return_value_t = Tensor;
+    using spec_return_value_t = std::vector<std::optional<TensorSpec>>;
+    using tensor_return_value_t = std::vector<std::optional<Tensor>>;

     struct BatchNormFactory {
         struct shared_variables_t {
@@ -77,6 +79,8 @@ struct BatchNormOperation {
         std::optional<Tensor> running_mean,
         std::optional<Tensor> running_var,
         std::optional<Tensor> output,
+        std::optional<Tensor> updated_running_mean,
+        std::optional<Tensor> updated_running_var,
         const std::optional<MemoryConfig>& memory_config);
 };
 }  // namespace ttnn::operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
index 07cc6840a1c1..e722010c1d48 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
@@ -29,7 +29,7 @@ void set_or_update_runtime_arguments(
     const BatchNormOperation::tensor_args_t& tensor_args,
     BatchNormOperation::tensor_return_value_t& c,
     F handle_args) {
-    const auto& [a, b, d, e, f, g, h, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _, updated_m, updated_v] = tensor_args;
     const auto eps = operation_attributes.eps;
     const auto momentum = operation_attributes.momentum;
@@ -41,13 +41,13 @@ void set_or_update_runtime_arguments(
     const auto ashape = a.padded_shape();
     const auto bshape = b.padded_shape();
-    const auto cshape = c.padded_shape();
+    const auto cshape = c[0]->padded_shape();

     const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
     const auto [bN, bC, bHt, bWt] = extract_shape_dims(b);
-    const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);
+    const auto [cN, cC, cHt, cWt] = extract_shape_dims(a);  // output same as input

-    uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();
+    uint32_t num_output_tiles = c[0]->volume() / c[0]->tensor_spec().tile().get_tile_hw();

     constexpr bool row_major = true;
     uint32_t num_cores_x = compute_with_storage_grid_size.x;
@@ -99,13 +99,13 @@ void set_or_update_runtime_arguments(
         const auto running_mean_addr = is_training_mode and running_mean_has_value ? g->buffer()->address() : 0;
         const auto running_var_addr = is_training_mode and running_var_has_value ? h->buffer()->address() : 0;
         std::array writer_runtime_args = {
-            b.buffer()->address(),  // batch mean
-            d.buffer()->address(),  // batch var
-            weight_addr,            // weight
-            bias_addr,              // bias
-            running_mean_addr,      // old running mean
-            running_var_addr,       // old running var
-            c.buffer()->address(),  // output
+            b.buffer()->address(),      // batch mean
+            d.buffer()->address(),      // batch var
+            weight_addr,                // weight
+            bias_addr,                  // bias
+            running_mean_addr,          // old running mean
+            running_var_addr,           // old running var
+            c[0]->buffer()->address(),  // output
             start_tile_id,
             num_tiles_per_core,
             cHtWt,
@@ -138,7 +138,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::BatchNormFactory::create(
     using namespace tt;
     using namespace tt::tt_metal;

-    const auto& [a, b, d, e, f, g, h, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _, updated_m, updated_v] = tensor_args;

     auto program = CreateProgram();
@@ -152,7 +152,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::BatchNormFactory::create(
     auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
     auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
-    auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
+    auto c_data_format = datatype_to_dataformat_converter(output[0]->get_dtype());
     auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
     auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
     auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
@@ -170,7 +170,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::BatchNormFactory::create(
     uint32_t g_single_tile_size = tt_metal::detail::TileSize(g_data_format);
     uint32_t h_single_tile_size = tt_metal::detail::TileSize(h_data_format);

-    uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();
+    uint32_t num_output_tiles = output[0]->volume() / output[0]->tensor_spec().tile().get_tile_hw();

     // we parallelize the computation across the output tiles
     constexpr bool row_major = true;
@@ -250,7 +250,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::BatchNormFactory::create(
     auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
-    auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+    auto c_is_dram = static_cast<uint32_t>(output[0]->buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
     const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;