#12253: Batch Norm support for training mode #16592

Open · wants to merge 9 commits into base: main
@@ -133,7 +133,7 @@ def compare_results(tt_tensor, golden_tensor, pcc=0.99):
     return status


-def compare_results_batch_norm(tt_tensor, golden_tensor, pcc=0.99):
+def compare_results_batch_norm(tt_tensor, golden_tensor, pcc=0.99, stats=False):
     status = True
     for i in range(len(tt_tensor)):
         tt_out_tensor = tt_tensor[i]
@@ -144,7 +144,11 @@ def compare_results_batch_norm(tt_tensor, golden_tensor, pcc=0.99):
         logger.debug(comp_all)
         logger.debug(comp_out)
         logger.debug(comp_out_res)
-        status = status & comp_pass & comp_all
+        if stats:
+            status = status & comp_all
+        else:
+            status = status & comp_pass & comp_all

     return status
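A note on the new `stats` flag: when validating updated running statistics, the comparison keeps only the element-wise `comp_all` check and skips the PCC gate. A plausible motivation, shown in the standalone torch sketch below (the helper is illustrative, not the repo's): a `[1, C, 1, 1]` statistics tensor has only C elements, and Pearson correlation over a handful of near-equal values is a noisy pass/fail signal, while an allclose-style check stays meaningful.

```python
import torch

def pearson_cc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation coefficient over the flattened tensors.
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

# With C == 2, a tiny rounding difference can swing the correlation wildly.
ref = torch.tensor([5.000, 5.001]).view(1, 2, 1, 1)
out = torch.tensor([5.001, 5.000]).view(1, 2, 1, 1)

print(pearson_cc(ref, out))                 # -1.0: a PCC gate would fail hard
print(torch.allclose(ref, out, atol=1e-2))  # True: element-wise check passes
```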
72 changes: 56 additions & 16 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

 # SPDX-License-Identifier: Apache-2.0

@@ -15,43 +15,61 @@
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
-        torch.Size([4, 4, 32, 32]),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
-        torch.Size([4, 4, 23, 23]),
+        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
         *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
         torch.Size([3, 1, 64, 120]),
         torch.Size([3, 2, 64, 120]),
     ],
 )
-@pytest.mark.parametrize("training", [False])
+@pytest.mark.parametrize(
+    "training, check_mean, check_var",
+    [
+        (True, True, True),
+        (True, True, False),
+        (True, False, True),
+        (True, False, False),
+        (False, False, False),  # xfail case
+        (False, True, False),  # xfail case
+        (False, False, True),  # xfail case
+        (False, True, True),
+    ],
+)
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-def test_batch_norm(input_shapes, training, weight, bias, eps, device):
+@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
+def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
-    )
-    var_data, var_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
     )
+    var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
     weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
     bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

+    if (not training) and ((not check_mean) or (not check_var)):
+        pytest.xfail("running_mean and running_var must be defined in evaluation mode")
+
     tt_output_tensor_on_device = ttnn.batch_norm(
         input_tensor,
         running_mean=mean_tensor,
         running_var=var_tensor,
         training=training,
         eps=eps,
+        momentum=momentum,
         weight=weight_tensor,
         bias=bias_tensor,
     )
     tt_output = ttnn.to_torch(tt_output_tensor_on_device)
-    # ttnn.set_printoptions(profile="full")
-    # print("TT result : ", tt_output, tt_output.shape)
-    # torch.set_printoptions(precision=5, sci_mode=False)
+    tt_updated_mean = None
+    tt_updated_var = None
+    if training:
+        if check_mean:
+            tt_updated_mean = ttnn.to_torch(mean_tensor)
+        if check_var:
+            tt_updated_var = ttnn.to_torch(var_tensor)

     torch_result = torch.nn.functional.batch_norm(
         input=in_data,
         running_mean=mean_data,
@@ -60,9 +78,31 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device):
         bias=bias_data,
         training=training,
         eps=eps,
+        momentum=momentum,
     )
-    # print("Torch result : ",torch_result)
-    comp_pass = compare_results_batch_norm([tt_output], [torch_result])
+    comp_pass = compare_results_batch_norm([tt_output], [torch_result])  # Check BN Result
+    if training:
+        channels = input_shapes[1]
+        if check_mean:
+            comp_pass_1 = compare_results_batch_norm(
+                [tt_updated_mean], [mean_data.view(1, channels, 1, 1)], stats=True
+            )  # Check Updated running mean
+        else:
+            if tt_updated_mean is None:
+                comp_pass_1 = True
+            else:
+                comp_pass_1 = False
+        if check_var:
+            comp_pass_2 = compare_results_batch_norm(
+                [tt_updated_var], [var_data.view(1, channels, 1, 1)], stats=True
+            )  # Check Updated running var
+        else:
+            if tt_updated_var is None:
+                comp_pass_2 = True
+            else:
+                comp_pass_2 = False
+        comp_pass = comp_pass and comp_pass_1 and comp_pass_2

     assert comp_pass

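The golden values for `tt_updated_mean` / `tt_updated_var` come from `torch.nn.functional.batch_norm` mutating `mean_data` / `var_data` in place. For reference, this is a standalone sketch of PyTorch's documented update rule, which the device results are compared against; note that PyTorch folds the *unbiased* batch variance into the running estimate:

```python
import torch

def torch_running_stats_update(x, running_mean, running_var, momentum=0.1):
    # Per-channel batch statistics over N, H, W of an NCHW tensor.
    dims = (0, 2, 3)
    batch_mean = x.mean(dim=dims)
    n = x.numel() // x.shape[1]  # elements reduced per channel
    # PyTorch stores the unbiased batch variance in the running estimate.
    batch_var = x.var(dim=dims, unbiased=False) * n / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var

x = torch.randn(2, 3, 32, 32)
rm, rv = torch.zeros(3), torch.ones(3)
torch.nn.functional.batch_norm(x, rm, rv, training=True, momentum=0.1)  # mutates rm, rv
em, ev = torch_running_stats_update(x, torch.zeros(3), torch.ones(3), momentum=0.1)
assert torch.allclose(rm, em, atol=1e-5) and torch.allclose(rv, ev, atol=1e-5)
```

A momentum of 0.0 in the new parametrization therefore leaves the running estimates untouched, which is a useful degenerate case to pin down.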
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
@@ -304,6 +304,8 @@ set(TTNN_OP_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
40 changes: 29 additions & 11 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -1,31 +1,49 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

 #include "batch_norm.hpp"

 #include "device/batch_norm_device_operation.hpp"
+#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
+#include "ttnn/operations/eltwise/unary/device/unary_composite_op.hpp"
+#include "device/running_statistics_device_operation.hpp"

 using namespace tt::tt_metal;

 namespace ttnn::operations::normalization {

+inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
+    auto output_mem_config = memory_config.value_or(input_tensor.memory_config());
+    ttnn::SmallVector<int> dims = {2, 3};
+    Tensor mean_hw = ttnn::mean(input_tensor, dims, true);
+    return ttnn::mean(mean_hw, 0, true);
+}
+
 Tensor BatchNorm::invoke(
     const Tensor& input,
     std::optional<Tensor> running_mean,
     std::optional<Tensor> running_var,
     const bool training,
     const float eps,
-    std::optional<Tensor> weight,
-    std::optional<Tensor> bias,
-    std::optional<Tensor> output,
+    const float momentum,
+    const std::optional<Tensor>& weight,
+    const std::optional<Tensor>& bias,
+    const std::optional<Tensor>& output,
     const std::optional<MemoryConfig>& memory_config) {
-    // TODO: Implementation for training mode is in progress
-    TT_FATAL((!training), "Support currently provided for inference mode only");
-    TT_FATAL(
-        (running_mean.has_value() && running_var.has_value()),
-        "running_mean and running_var must be defined in evaluation mode");
-    return ttnn::prim::batch_norm(
-        input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config);
+    Tensor batch_mean = mean_NHW(input, memory_config);
+    Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
+    Tensor batch_var = ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
+    if (training) {
+        Tensor stats =
+            ttnn::prim::running_statistics(batch_mean, batch_var, momentum, running_mean, running_var, memory_config);
+    } else {
+        TT_FATAL(
+            (running_mean.has_value() && running_var.has_value()),
+            "running_mean and running_var must be defined in evaluation mode");
+        batch_mean = running_mean.value();
+        batch_var = running_var.value();
+    }
+    return ttnn::prim::batch_norm(input, batch_mean, batch_var, eps, weight, bias, output, memory_config);
 }
}  // namespace ttnn::operations::normalization
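Read as pseudocode, `invoke` now always computes the biased batch statistics as E[x] and E[x^2] - E[x]^2 via `mean_NHW` (reduce H and W, then N, keeping a `[1, C, 1, 1]` shape), then either folds them into the running estimates (training) or discards them in favor of the caller's running stats (inference), and normalizes with whichever pair survives. A rough Python rendering of the control flow, with torch standing in for the `ttnn` primitives; the in-place blend inside the training branch is an assumption about `ttnn::prim::running_statistics` (including whether any Bessel correction is applied), since only the in-place contract is stated in the pybind docs below:

```python
import torch

def batch_norm_invoke(x, running_mean=None, running_var=None, *, training=False,
                      eps=1e-5, momentum=0.1, weight=None, bias=None):
    # mean_NHW: reduce H, W (dims 2, 3) first, then N (dim 0), keeping [1, C, 1, 1].
    def mean_nhw(t):
        return t.mean(dim=(2, 3), keepdim=True).mean(dim=0, keepdim=True)

    batch_mean = mean_nhw(x)
    batch_var = mean_nhw(torch.square(x)) - torch.square(batch_mean)  # E[x^2] - E[x]^2

    if training:
        # Stand-in for ttnn::prim::running_statistics: blend batch stats into the
        # optional running tensors in place (assumed semantics, see lead-in).
        with torch.no_grad():
            if running_mean is not None:
                running_mean.mul_(1 - momentum).add_(momentum * batch_mean)
            if running_var is not None:
                running_var.mul_(1 - momentum).add_(momentum * batch_var)
    else:
        assert running_mean is not None and running_var is not None, \
            "running_mean and running_var must be defined in evaluation mode"
        batch_mean, batch_var = running_mean, running_var

    # ttnn::prim::batch_norm: normalize with the selected statistics.
    y = (x - batch_mean) / torch.sqrt(batch_var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y
```

The header and pybind changes below mirror the new `momentum` parameter and the const-ref optionals.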
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -15,9 +15,10 @@ struct BatchNorm {
         std::optional<Tensor> running_var = std::nullopt,
         const bool training = false,
         const float eps = 1e-05,
-        std::optional<Tensor> weight = std::nullopt,
-        std::optional<Tensor> bias = std::nullopt,
-        std::optional<Tensor> output = std::nullopt,
+        const float momentum = 0.1,
+        const std::optional<Tensor>& weight = std::nullopt,
+        const std::optional<Tensor>& bias = std::nullopt,
+        const std::optional<Tensor>& output = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt);
 };
}  // namespace operations::normalization
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
         module,
         ttnn::batch_norm,
         R"doc(
-            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
+            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.


            Args:
@@ -23,8 +23,9 @@ void bind_batch_norm_operation(pybind11::module& module) {

            Keyword args:
                eps (float, optional): Epsilon value. Defaults to `1e-05`.
-                running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
-                running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
+                momentum (float, optional): Momentum value. Defaults to `0.1`.
+                running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running mean value is stored in place based on the inputs provided. Defaults to `None`.
+                running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running variance value is stored in place based on the inputs provided. Defaults to `None`.
                weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
                bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
                training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
@@ -44,6 +45,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
        py::arg("running_var") = std::nullopt,
        py::arg("training") = false,
        py::arg("eps") = 1e-05,
+       py::arg("momentum") = 0.1,
        py::arg("weight") = std::nullopt,
        py::arg("bias") = std::nullopt,
        py::arg("output") = std::nullopt,
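Putting the new binding together, a call site might look like the sketch below. It is illustrative only: the shapes, dtype, and device setup are assumptions consistent with the docstring, not a tested snippet from this PR.

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_x = torch.randn(2, 3, 32, 32, dtype=torch.bfloat16)
x = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT, device=device)

# Running stats are [1, C, 1, 1] tensors, updated in place in training mode.
running_mean = ttnn.from_torch(torch.zeros(1, 3, 1, 1, dtype=torch.bfloat16),
                               layout=ttnn.TILE_LAYOUT, device=device)
running_var = ttnn.from_torch(torch.ones(1, 3, 1, 1, dtype=torch.bfloat16),
                              layout=ttnn.TILE_LAYOUT, device=device)

# Training mode: normalizes with batch statistics and folds them into
# running_mean / running_var using the new momentum argument.
y_train = ttnn.batch_norm(x, running_mean=running_mean, running_var=running_var,
                          training=True, momentum=0.1, eps=1e-5)

# Inference mode: running stats are required and used directly.
y_eval = ttnn.batch_norm(x, running_mean=running_mean, running_var=running_var,
                         training=False, eps=1e-5)

ttnn.close_device(device)
```

An `output` tensor can also be passed to write the normalized result into preallocated storage.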
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
