Skip to content

Commit

Permalink
#16186: Update running statistics in batch normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw authored and VirdhatchaniKN committed Jan 23, 2025
1 parent 942024b commit 3d207a9
Show file tree
Hide file tree
Showing 9 changed files with 783 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.1, 0.0, 1.0, 2.3])
@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
Expand Down
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "device/batch_norm_device_operation.hpp"
#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
#include "ttnn/operations/eltwise/unary/device/unary_composite_op.hpp"
#include "device/running_statistics_device_operation.hpp"

using namespace tt::tt_metal;

Expand Down Expand Up @@ -42,6 +43,8 @@ Tensor BatchNorm::invoke(
Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
Tensor batch_var =
ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
Tensor stats =
ttnn::prim::running_statistics(batch_mean, batch_var, momentum, running_mean, running_var, memory_config);
return ttnn::prim::batch_norm(input, batch_mean, batch_var, eps, weight, bias, output, memory_config);
}
TT_FATAL(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "dprint.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"

namespace NAMESPACE {
void MAIN {
uint32_t num_tiles = get_arg_val<uint32_t>(0);
constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(0) == 1;
constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(1) == 1;

constexpr auto cb_batch_mean = tt::CBIndex::c_0; // batch mean
constexpr auto cb_batch_var = tt::CBIndex::c_1; // batch var
constexpr auto cb_out0 = tt::CBIndex::c_2;
constexpr auto cb_old_running_mean = tt::CBIndex::c_3; // old running mean tensor
constexpr auto cb_old_running_var = tt::CBIndex::c_4; // old running var tensor
constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor
constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor
constexpr auto cb_momentum = tt::CBIndex::c_5; // momentum
constexpr auto cb_one = tt::CBIndex::c_6; // stores 1
constexpr auto cb_tmp1 = tt::CBIndex::c_21; // tmp 1
constexpr auto cb_tmp2 = tt::CBIndex::c_22; // tmp 2
constexpr auto cb_tmp3 = tt::CBIndex::c_23; // tmp 3

binary_op_init_common(cb_batch_mean, cb_batch_var, cb_out0);
constexpr uint32_t onetile = 1;

for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
tile_regs_acquire();
// updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat
cb_wait_front(cb_one, 1);
cb_wait_front(cb_momentum, 1);

if constexpr (old_running_mean_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats
add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3
}
if constexpr (old_running_var_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats
add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3
}
cb_pop_front(cb_one, 1);
cb_pop_front(cb_momentum, 1);
tile_regs_commit();
tile_regs_wait();
pack_tile(0, cb_out0);
tile_regs_release();
cb_push_back(cb_out0, 1);
}
}
} // namespace NAMESPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"
#include "debug/dprint.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"
#include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp"

void kernel_main() {
const auto momentum = get_arg_val<uint32_t>(0);
uint32_t src_addr = get_arg_val<uint32_t>(1); // input tensor
uint32_t start_tile_id = get_arg_val<uint32_t>(2);
uint32_t num_tiles = get_arg_val<uint32_t>(3);
uint32_t HtWt = get_arg_val<uint32_t>(4);
uint32_t n_stride = get_arg_val<uint32_t>(5);
uint32_t c_stride = get_arg_val<uint32_t>(6);
uint32_t N = get_arg_val<uint32_t>(7);
uint32_t C = get_arg_val<uint32_t>(8);

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr auto cb_id_momentum = tt::CBIndex::c_5;
constexpr auto cb_id_one = tt::CBIndex::c_6;
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);
const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
uint32_t start_remaining = start_tile_id % tiles_per_batch;
uint32_t start_c = start_remaining / HtWt;
uint32_t start_t = start_remaining % HtWt;

// this is the INPUT tile offset
uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;

uint32_t next_channel_shift = c_stride - HtWt;
uint32_t next_batch_shift = n_stride - c_stride * C;

union {
float f;
uint32_t u;
} scalar;
scalar.f = 1.0f;
fill_cb_with_value(cb_id_one, scalar.u);

cb_reserve_back(cb_id_momentum, onetile);
fill_with_val_bfloat16(cb_id_momentum, momentum);
cb_push_back(cb_id_momentum, onetile);

uint32_t num_tiles_read = 0;
for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) {
for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) {
for (uint32_t t = start_t; t < HtWt && num_tiles_read < num_tiles; ++t, ++num_tiles_read, ++tile_offset) {
cb_reserve_back(cb_id_src, onetile);
uint32_t l1_write_addr_src = get_write_ptr(cb_id_src);
noc_async_read_tile(tile_offset, src, l1_write_addr_src);
noc_async_read_barrier();
cb_push_back(cb_id_src, onetile);
}
tile_offset += next_channel_shift;
}
tile_offset += next_batch_shift;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"
#include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp"

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0); // batch_var
uint32_t old_running_mean_addr = get_arg_val<uint32_t>(1); // old running_mean
uint32_t old_running_var_addr = get_arg_val<uint32_t>(2); // ols running_var
uint32_t dst_addr = get_arg_val<uint32_t>(3); // output
uint32_t start_tile_id = get_arg_val<uint32_t>(4);
uint32_t num_tiles = get_arg_val<uint32_t>(5);
uint32_t HtWt = get_arg_val<uint32_t>(6);
uint32_t n_stride = get_arg_val<uint32_t>(7);
uint32_t c_stride = get_arg_val<uint32_t>(8);
uint32_t N = get_arg_val<uint32_t>(9);
uint32_t C = get_arg_val<uint32_t>(10);

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);

const InterleavedAddrGenFast<dst_is_dram> dst = {
.bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format};

// old running mean
constexpr auto cb_id_old_running_mean = tt::CBIndex::c_3;
constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(2) == 1;
const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean);
const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean);

const InterleavedAddrGenFast<old_running_mean_is_dram> old_running_mean = {
.bank_base_address = old_running_mean_addr,
.page_size = old_running_mean_tile_bytes,
.data_format = old_running_mean_data_format};

// old running var
constexpr auto cb_id_old_running_var = tt::CBIndex::c_4;
constexpr bool old_running_var_is_dram = get_compile_time_arg_val(3) == 1;
const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var);
const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var);

const InterleavedAddrGenFast<old_running_var_is_dram> old_running_var = {
.bank_base_address = old_running_var_addr,
.page_size = old_running_var_tile_bytes,
.data_format = old_running_var_data_format};

constexpr bool old_running_mean_has_value = get_compile_time_arg_val(4) == 1;
constexpr bool old_running_var_has_value = get_compile_time_arg_val(5) == 1;
constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
uint32_t start_remaining = start_tile_id % tiles_per_batch;
uint32_t start_c = start_remaining / HtWt;
uint32_t start_t = start_remaining % HtWt;

// this is the INPUT tile offset
uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;
uint32_t next_channel_shift = c_stride - HtWt;
uint32_t next_batch_shift = n_stride - c_stride * C;

uint32_t num_tiles_written = 0;
for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) {
for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) {
for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
// read a tile from src
cb_reserve_back(cb_id_src, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_src);
noc_async_read_tile(tile_offset, src, l1_write_addr);
noc_async_read_barrier();
cb_push_back(cb_id_src, onetile);

if constexpr (old_running_mean_has_value) {
// read data
cb_reserve_back(cb_id_old_running_mean, onetile);
uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
noc_async_read_barrier();
fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
cb_push_back(cb_id_old_running_mean, onetile);

// write data
cb_wait_front(cb_id_updated_running_mean, onetile);
uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_updated_running_mean, onetile);
}

if constexpr (old_running_var_has_value) {
// read data
cb_reserve_back(cb_id_old_running_var, onetile);
uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
noc_async_read_barrier();
fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
cb_push_back(cb_id_old_running_var, onetile);

// write data
cb_wait_front(cb_id_updated_running_var, onetile);
uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_updated_running_var, onetile);
}
++tile_offset;

// write a tile to dst, since the dst shape is full, the tile offset simply grows linearly
cb_wait_front(cb_id_dst, onetile);
uint32_t l1_read_addr = get_read_ptr(cb_id_dst);
noc_async_write_tile(start_tile_id + num_tiles_written, dst, l1_read_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_dst, onetile);
}
tile_offset += next_channel_shift;
}
tile_offset += next_batch_shift;
}
}
Loading

0 comments on commit 3d207a9

Please sign in to comment.