#16186: Running statistics updates in batch normalization

mouliraj-mcw committed Dec 26, 2024
1 parent 82b35a3 commit 772e099
Showing 4 changed files with 76 additions and 27 deletions.
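In short: this commit prototypes the batch-norm running-statistics update inside the binary_ng elementwise add path. Per tile it computes out = momentum * in0 + (1 - momentum) * in1, with momentum hardcoded to 0.34 and threaded to the device kernels as a runtime argument. Below is a minimal host-side reference of that update; it is a sketch, the running/batch naming follows the kernel comments (an assumption), and note that the momentum placement is the reverse of torch.nn.BatchNorm2d's convention, where momentum weights the new observation.

    import torch

    def running_stat_update(running: torch.Tensor, batch: torch.Tensor, momentum: float = 0.34) -> torch.Tensor:
        # Mirrors the compute kernel below: tmp1 = 1 - momentum, tmp2 = momentum * running,
        # tmp3 = tmp1 * batch, out = tmp2 + tmp3.
        return momentum * running + (1.0 - momentum) * batch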
37 changes: 26 additions & 11 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -35,22 +35,32 @@ def test_binary_scalar_ops(input_shapes, device):
 @pytest.mark.parametrize(
     "shapes",
     [
-        [[1, 71, 7, 7], [7, 7]],
-        [[920, 1, 256], [256]],
-        [[4, 12, 64, 64], [12, 1, 1]],
-        [[4, 16, 64, 64], [16, 1, 1]],
-        [[64, 3, 64, 64], [3, 1, 1]],
-        [[64, 4, 64, 64], [4, 1, 1]],
-        [[16, 6, 64, 64], [6, 1, 1]],
-        [[16, 8, 64, 64], [8, 1, 1]],
-        [[16, 1], [1, 1, 32]],
+        # [[1, 71, 7, 7], [7, 7]],
+        # [[920, 1, 256], [256]],
+        # [[4, 12, 64, 64], [12, 1, 1]],
+        # [[4, 16, 64, 64], [16, 1, 1]],
+        # [[64, 3, 64, 64], [3, 1, 1]],
+        # [[64, 4, 64, 64], [4, 1, 1]],
+        # [[16, 6, 64, 64], [6, 1, 1]],
+        # [[16, 8, 64, 64], [8, 1, 1]],
+        [[1, 1, 1, 1], [1, 1, 1, 1]],
+        [[1, 2, 1, 1], [1, 2, 1, 1]],
+        [[1, 3, 1, 1], [1, 3, 1, 1]],
+        [[1, 4, 1, 1], [1, 4, 1, 1]],
+        [[1, 10, 1, 1], [1, 10, 1, 1]],
+        [[1, 22, 1, 1], [1, 22, 1, 1]],
     ],
 )
 def test_unequal_ranks(device, shapes):
     torch.manual_seed(0)
     torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16)
     torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16)
-    torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b
+    torch_one = torch.ones(shapes[0], dtype=torch.bfloat16)
+    cb_tmp1 = torch_one - 0.34  # 1 - momentum
+    cb_tmp2 = torch_input_tensor_a * 0.34  # momentum * running stat
+    cb_tmp3 = cb_tmp1 * torch_input_tensor_b  # (1 - momentum) * batch stat
+    torch_output_tensor = cb_tmp2 + cb_tmp3
+    # torch_output_tensor = torch_output_tensor + torch_output_tensor
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
     )
@@ -59,7 +69,12 @@ def test_unequal_ranks(device, shapes):
     )
     output_tensor = ttnn.experimental.add(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG)
     output_tensor = ttnn.to_torch(output_tensor)
-
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=5, sci_mode=False, edgeitems=17)
+    # print("torch_input_tensor_a", torch_input_tensor_a)
+    # print("torch_input_tensor_b", torch_input_tensor_b)
+    print("torch_output_tensor: ", torch_output_tensor)
+    print("output_tensor: ", output_tensor)
+    print("Difference :", torch_output_tensor - output_tensor)
     assert output_tensor.shape == torch_output_tensor.shape
     assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988
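The reworked test builds the same expression on the torch side and accepts the device result when the Pearson correlation is at least 0.99988. It is equivalent to calling the reference sketch above with tensor a as the running stat and tensor b as the batch stat, for example (hypothetical values; shapes and momentum taken from the diff):

    a = torch.rand([1, 2, 1, 1], dtype=torch.bfloat16)
    b = torch.rand([1, 2, 1, 1], dtype=torch.bfloat16)
    expected = running_stat_update(a, b)  # == 0.34 * a + (1 - 0.34) * b

The three hunks that follow are the device-side changes: the program factory that allocates the extra circular buffers and passes the momentum, the compute kernel that evaluates the update, and the reader kernel that seeds the constant CBs.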
@@ -175,7 +175,12 @@ void set_or_update_runtime_arguments(
     const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
     const auto [bN, bC, bHt, bWt] = b.has_value() ? extract_shape_dims(*b) : std::tuple{1u, 1u, 1u, 1u};
     const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);
-
+    const auto momentum = 0.34f;  // hardcoded momentum for prototyping
+    union {
+        float f;
+        uint32_t u;
+    } param;
+    param.f = momentum;  // reinterpret the float's bit pattern as uint32_t for the runtime args
     uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();

     constexpr bool row_major = true;
@@ -196,7 +201,7 @@
         } else if (core_group_2.contains(core)) {
             num_tiles_per_core = num_tiles_per_core_group_2;
         } else {
-            handle_args(program, reader_kernel_id, core, std::array<uint32_t, 10>{0});
+            handle_args(program, reader_kernel_id, core, std::array<uint32_t, 11>{0});
             handle_args(program, writer_kernel_id, core, std::array<uint32_t, 11>{0});
             handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
             continue;
@@ -213,7 +218,8 @@
             cN,
             cC,
             cHt,
-            cWt};
+            cWt,
+            param.u};
         handle_args(program, reader_kernel_id, core, reader_runtime_args);

         if (b.has_value()) {
@@ -306,6 +312,22 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio
     auto [c_cb, c_cb_handle] =
         create_cb(tt::CBIndex::c_2, program, all_device_cores, c_single_tile_size, num_tiles_per_cb, c_data_format);

+    // Intermediate buffers: c_3 = constant 1.0, c_4 = momentum, c_5/c_6/c_16 = scratch for the update
+    auto [d_cb, d_cb_handle] =
+        create_cb(tt::CBIndex::c_3, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+
+    auto [e_cb, e_cb_handle] =
+        create_cb(tt::CBIndex::c_4, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+
+    auto [f_cb, f_cb_handle] =
+        create_cb(tt::CBIndex::c_5, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+
+    auto [g_cb, g_cb_handle] =
+        create_cb(tt::CBIndex::c_6, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+
+    auto [h_cb, h_cb_handle] =
+        create_cb(tt::CBIndex::c_16, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+
     // If b is a scalar, we only need one tile in the CB
     uint32_t b_num_tiles_per_cb = b_buffer != nullptr ? num_tiles_per_cb : 1;
     auto [b_cb, b_cb_handle] =
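The runtime-argument array only carries uint32_t values, so the host packs the momentum's raw IEEE-754 bit pattern into a uint32_t via the union above, and the reader kernel forwards those bits into a CB unchanged. A quick way to see the round-trip from Python (the helper name is ours; struct is just the standard library):

    import struct

    def float_bits_to_u32(x: float) -> int:
        # Reinterpret the float's bits as an unsigned 32-bit int; no value conversion,
        # matching the C++ union { float f; uint32_t u; } trick.
        return struct.unpack("<I", struct.pack("<f", x))[0]

    print(hex(float_bits_to_u32(0.34)))  # 0x3eae147b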
@@ -4,7 +4,9 @@

 #include <cstdint>
 #include "compute_kernel_api/eltwise_binary.h"
 #include "compute_kernel_api/tile_move_copy.h"
+#include "dprint.h"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"

 namespace NAMESPACE {
 void MAIN {
@@ -13,28 +15,25 @@ void MAIN {
     constexpr auto cb_in0 = tt::CBIndex::c_0;
     constexpr auto cb_in1 = tt::CBIndex::c_1;
     constexpr auto cb_out0 = tt::CBIndex::c_2;
+    constexpr auto cb_one = tt::CBIndex::c_3;    // constant 1.0, seeded by the reader
+    constexpr auto cb_param = tt::CBIndex::c_4;  // momentum, seeded by the reader
+    constexpr auto cb_tmp1 = tt::CBIndex::c_5;
+    constexpr auto cb_tmp2 = tt::CBIndex::c_6;
+    constexpr auto cb_tmp3 = tt::CBIndex::c_16;

     binary_op_init_common(cb_in0, cb_in1, cb_out0);
-    add_tiles_init();

     constexpr uint32_t onetile = 1;

-    for(uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
-        cb_wait_front(cb_in0, onetile);
-        cb_wait_front(cb_in1, onetile);
-        cb_reserve_back(cb_out0, onetile);
-
-        tile_regs_acquire();
-        add_tiles(cb_in0, cb_in1, 0, 0, 0);
-        tile_regs_commit();
-
-        tile_regs_wait();
-        pack_tile(0, cb_out0);
-        tile_regs_release();
-
-        cb_push_back(cb_out0, onetile);
-        cb_pop_front(cb_in0, onetile);
-        cb_pop_front(cb_in1, onetile);
+    for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
+        sub_tiles_to_cb(cb_one, cb_param, cb_tmp1, 0, tile_id, 1, 0);  // tmp1 = 1 - momentum
+        mul_tiles_to_cb(cb_param, cb_in0, cb_tmp2, 0, 0, 1, 1);        // tmp2 = momentum * running stat
+        mul_tiles_to_cb(cb_tmp1, cb_in1, cb_tmp3, tile_id, 0, 0, 1);   // tmp3 = (1 - momentum) * batch stat
+        add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_out0, 0, 0, 1, 1);        // out = tmp2 + tmp3
     }
 }
 }  // namespace NAMESPACE
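The sub/mul/add_tiles_to_cb helpers pulled in from moreh_common each consume one tile from two input CBs and produce one tile into an output CB, handling the CB wait/reserve, tile-register acquire/commit, pack, push, and pop steps internally, which is why the explicit tile_regs_* and pack_tile ceremony from the old loop disappears. Per tile, the dataflow through the scratch CBs is (a pure-Python sketch of the math, one scalar standing in for a 32x32 tile):

    def compute_tile(one, momentum, running_tile, batch_tile):
        tmp1 = one - momentum           # sub_tiles_to_cb(cb_one, cb_param, cb_tmp1, ...)
        tmp2 = momentum * running_tile  # mul_tiles_to_cb(cb_param, cb_in0, cb_tmp2, ...)
        tmp3 = tmp1 * batch_tile        # mul_tiles_to_cb(cb_tmp1, cb_in1, cb_tmp3, ...)
        return tmp2 + tmp3              # add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_out0, ...)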
@@ -5,6 +5,8 @@
 #include <stdint.h>

 #include "dataflow_api.h"
+#include "debug/dprint.h"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

 void kernel_main() {
     uint32_t src_addr = get_arg_val<uint32_t>(0);
@@ -15,10 +17,13 @@ void kernel_main() {
     uint32_t c_stride = get_arg_val<uint32_t>(5);
     uint32_t N = get_arg_val<uint32_t>(6);
     uint32_t C = get_arg_val<uint32_t>(7);
+    uint32_t param = get_arg_val<uint32_t>(10);  // momentum, passed as raw float bits

     constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

     constexpr auto cb_id_src = tt::CBIndex::c_0;
+    constexpr auto cb_id_one = tt::CBIndex::c_3;
+    constexpr auto cb_id_param = tt::CBIndex::c_4;
     constexpr uint32_t onetile = 1;

     const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
@@ -38,6 +43,14 @@
     uint32_t next_channel_shift = c_stride - HtWt;
     uint32_t next_batch_shift = n_stride - c_stride * C;

+    union {
+        float f;
+        uint32_t u;
+    } scalar;
+    scalar.f = 1.0f;
+    fill_cb_with_value(cb_id_one, scalar.u);  // seed CB c_3 with the constant 1.0
+    fill_cb_with_value(cb_id_param, param);   // seed CB c_4 with the momentum bits
+
     uint32_t num_tiles_read = 0;
     for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) {
         for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) {
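fill_cb_with_value (from the moreh_common dataflow helpers) writes a single tile filled with one scalar into a circular buffer, so the compute kernel can treat the constant 1.0 and the momentum as ordinary tile operands. Conceptually the two seeded CBs hold (a sketch; a real tile is 32x32 scalars in the device's data format):

    import torch

    cb_one = torch.full((32, 32), 1.0)     # CB c_3 after fill_cb_with_value(cb_id_one, scalar.u)
    cb_param = torch.full((32, 32), 0.34)  # CB c_4 after fill_cb_with_value(cb_id_param, param)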
