Skip to content

Commit

Permalink
#16186: update running statistics in batch norm
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw authored and VirdhatchaniKN committed Jan 22, 2025
1 parent dc511a3 commit d38dba5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 21 deletions.
46 changes: 28 additions & 18 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,40 @@
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 32, 32]),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 23, 23]),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
# *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
# torch.Size([4, 4, 32, 32]),
# *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
# torch.Size([4, 4, 23, 23]),
# *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
# torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize(
"training, check_mean, check_var",
[
# (True, True, True),
# (True, True, False),
# (True, False, True),
(True, False, False),
(False, False, False), # xfail case
(False, True, False), # xfail case
(False, False, True), # xfail case
(False, True, True),
(True, True, True),
# # (True, True, False),
# # (True, False, True),
# (True, False, False),
# (False, False, False), # xfail case
# (False, True, False), # xfail case
# (False, False, True), # xfail case
# (False, True, True),
],
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.1, 0.0, 2.3])
@pytest.mark.parametrize("momentum", [0.1, 0.0])
def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
print("mean_tensor", mean_tensor)
print("var_tensor", var_tensor)
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

Expand All @@ -65,9 +67,8 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)

# tt_updated_mean = ttnn.to_torch(mean_tensor)
# tt_updated_var = ttnn.to_torch(var_tensor)

tt_updated_mean = ttnn.to_torch(mean_tensor)
tt_updated_var = ttnn.to_torch(var_tensor)
# ttnn.set_printoptions(profile="full")
# print("TT result : ", tt_output, tt_output.shape)
# torch.set_printoptions(precision=5, sci_mode=False)
Expand All @@ -81,13 +82,22 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
eps=eps,
momentum=momentum,
)
batch_mean = in_data.mean(dim=(0, 2, 3))
batch_var = in_data.var(dim=(0, 2, 3), unbiased=False)
print("Batch mean:", batch_mean)
print("Batch variance:", batch_var)
print("mean_data", mean_data)
print("tt_updated_mean", tt_updated_mean)
print("var_data", var_data)
print("tt_updated_var", tt_updated_var)
# print("Torch result : ",torch_result)
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) # Check BN Result
# if training :
# channels = input_shapes[1]
# comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)]) # Check Updated running mean
# comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)]) # Check Updated running var
# comp_pass = comp_pass and comp_pass_1 and comp_pass_2

assert comp_pass


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,39 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
b_num_tiles_per_cb,
h_data_format); // updated running var

// Intermediate buffer
auto [one_cb, one_cb_handle] = create_cb(
tt::CBIndex::c_19,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store 1

auto [tmp1_cb, tmp1_cb_handle] = create_cb(
tt::CBIndex::c_29,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp

auto [tmp2_cb, tmp2_cb_handle] = create_cb(
tt::CBIndex::c_30,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp

auto [tmp3_cb, tmp3_cb_handle] = create_cb(
tt::CBIndex::c_31,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp

auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ALWI void subtract_bcast_tiles(
cb_push_back(cb_out, onetile);
cb_pop_front(cb_other, onetile);
}
cb_pop_front(cb_bcast, onetile);
// cb_pop_front(cb_bcast, onetile);
}

void MAIN {
Expand Down Expand Up @@ -63,6 +63,10 @@ void MAIN {
constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor
constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor
constexpr auto cb_momentum = tt::CBIndex::c_24; // momentum
constexpr auto cb_one = tt::CBIndex::c_19; // stores 1
constexpr auto cb_tmp1 = tt::CBIndex::c_29; // tmp 1
constexpr auto cb_tmp2 = tt::CBIndex::c_30; // tmp 2
constexpr auto cb_tmp3 = tt::CBIndex::c_31; // tmp 3

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand Down Expand Up @@ -102,7 +106,7 @@ void MAIN {
pack_tile_with_dt(dst0, cb_den);
tile_regs_release();

cb_pop_front(cb_batch_var, 1);
// cb_pop_front(cb_batch_var, 1);
cb_pop_front(cb_eps, 1);
cb_push_back(cb_den, onetile);

Expand All @@ -127,9 +131,18 @@ void MAIN {
if constexpr (is_training_mode) {
// updated running stats
if constexpr (old_running_mean_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, tile_id, 0, 0); // momentum * running stats
mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, tile_id, 1, 0); // cb_tmp1 * batch stat
add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1);
}

if constexpr (old_running_var_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, tile_id, 0, 0); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, tile_id, 0, 1); // cb_tmp1 * running stats
DPRINT << TSLICE(tt::CBIndex::c_26, 0, SliceRange::hw0_32_16()) << ENDL();
add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,19 @@ void kernel_main() {
uint32_t start_t = start_remaining % HtWt;

constexpr auto cb_id_eps = tt::CBIndex::c_4;
constexpr auto cb_id_one = tt::CBIndex::c_19;

cb_reserve_back(cb_id_eps, onetile);
fill_with_val_bfloat16(cb_id_eps, eps);
cb_push_back(cb_id_eps, onetile);

constexpr auto cb_id_momentum = tt::CBIndex::c_24;

union {
float f;
uint32_t u;
} scalar;
scalar.f = 1.0f;
fill_cb_with_value(cb_id_one, scalar.u);
cb_reserve_back(cb_id_momentum, onetile);
fill_with_val_bfloat16(cb_id_momentum, momentum);
cb_push_back(cb_id_momentum, onetile);
Expand Down

0 comments on commit d38dba5

Please sign in to comment.