Skip to content

Commit

Permalink
#0: Write updated running stats
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jan 10, 2025
1 parent e3e7457 commit ec8c29a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
module,
ttnn::batch_norm,
R"doc(
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved. Currently support is provided for inference mode only.
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.
Args:
Expand All @@ -24,8 +24,8 @@ void bind_batch_norm_operation(pybind11::module& module) {
Keyword args:
eps (float, optional): Epsilon value. Defaults to `1e-05`.
momentum (float, optional): Momentum value. Defaults to `0.1`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running mean value is stored in-place based on the inputs provided. Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running variance value is stored in-place based on the inputs provided. Defaults to `None`.
weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
a_data_format); // to store input - batch_mean
auto [temp_1_cb, temp_1_cb_handle] =
create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
auto [updated_m_cb, updated_m_cb_handle] = create_cb(
tt::CBIndex::c_27,
program,
all_device_cores,
g_single_tile_size,
b_num_tiles_per_cb,
g_data_format); // updated running mean
auto [updated_v_cb, updated_v_cb_handle] = create_cb(
tt::CBIndex::c_28,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // updated running var

auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ void MAIN {
constexpr auto cb_bias = tt::CBIndex::c_18; // bias tensor
constexpr auto cb_old_running_mean = tt::CBIndex::c_25; // old running mean tensor
constexpr auto cb_old_running_var = tt::CBIndex::c_26; // old running var tensor
constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor
constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor
constexpr auto cb_momentum = tt::CBIndex::c_24; // momentum

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand Down Expand Up @@ -122,7 +125,7 @@ void MAIN {
cb_push_back(cb_affine_or_out, onetile);

if constexpr (is_training_mode) {
// update running stats here
// updated running stats
if constexpr (old_running_mean_has_value) {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ void kernel_main() {

constexpr bool old_running_mean_has_value = get_compile_time_arg_val(10) == 1;
constexpr bool old_running_var_has_value = get_compile_time_arg_val(11) == 1;
constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
Expand Down Expand Up @@ -149,21 +151,37 @@ void kernel_main() {
// to read running stats value for updation
if constexpr (is_training_mode) {
if constexpr (old_running_mean_has_value) {
// read data
cb_reserve_back(cb_id_old_running_mean, onetile);
uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
noc_async_read_barrier();
fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
cb_push_back(cb_id_old_running_mean, onetile);

// write data
cb_wait_front(cb_id_updated_running_mean, onetile);
uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_updated_running_mean, onetile);
}

if constexpr (old_running_var_has_value) {
// read data
cb_reserve_back(cb_id_old_running_var, onetile);
uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
noc_async_read_barrier();
fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
cb_push_back(cb_id_old_running_var, onetile);

// write data
cb_wait_front(cb_id_updated_running_var, onetile);
uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_updated_running_var, onetile);
}
}

Expand Down

0 comments on commit ec8c29a

Please sign in to comment.