Skip to content

Commit

Permalink
#12253: Update test file
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jan 12, 2025
1 parent 49ca7f7 commit 258ea9b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 62 deletions.
79 changes: 43 additions & 36 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,42 @@
@pytest.mark.parametrize(
"input_shapes",
[
# *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
# torch.Size([4, 4, 32, 32]),
# *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
# torch.Size([4, 4, 23, 23]),
# *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
# torch.Size([3, 1, 64, 120]),
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 32, 32]),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3]) if not (n == 3 and c == 3)),
torch.Size([4, 4, 23, 23]),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize(
"training, check_mean, check_var",
[
(True, True, True),
# # (True, True, False),
# # (True, False, True),
# (True, False, False),
# (False, False, False), # xfail case
# (False, True, False), # xfail case
# (False, False, True), # xfail case
# (False, True, True),
(True, True, False),
(True, False, True),
(True, False, False),
(False, False, False), # xfail case
(False, True, False), # xfail case
(False, False, True), # xfail case
(False, True, True),
],
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.1, 0.0])
@pytest.mark.parametrize("momentum", [0.1, 0.0, 1.0, 2.3])
def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
print("mean_tensor", mean_tensor)
print("var_tensor", var_tensor)
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

if (not check_mean) or (not check_var):
if (not training) and ((not check_mean) or (not check_var)):
pytest.xfail("running_mean and running_var must be defined in evaluation mode")

tt_output_tensor_on_device = ttnn.batch_norm(
Expand All @@ -66,12 +64,14 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
bias=bias_tensor,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
tt_updated_mean = None
tt_updated_var = None
if training:
if check_mean:
tt_updated_mean = ttnn.to_torch(mean_tensor)
if check_var:
tt_updated_var = ttnn.to_torch(var_tensor)

tt_updated_mean = ttnn.to_torch(mean_tensor)
tt_updated_var = ttnn.to_torch(var_tensor)
# ttnn.set_printoptions(profile="full")
# print("TT result : ", tt_output, tt_output.shape)
# torch.set_printoptions(precision=5, sci_mode=False)
torch_result = torch.nn.functional.batch_norm(
input=in_data,
running_mean=mean_data,
Expand All @@ -82,21 +82,28 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
eps=eps,
momentum=momentum,
)
batch_mean = in_data.mean(dim=(0, 2, 3))
batch_var = in_data.var(dim=(0, 2, 3), unbiased=False)
print("Batch mean:", batch_mean)
print("Batch variance:", batch_var)
print("mean_data", mean_data)
print("tt_updated_mean", tt_updated_mean)
print("var_data", var_data)
print("tt_updated_var", tt_updated_var)
# print("Torch result : ",torch_result)
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) # Check BN Result
# if training :
# channels = input_shapes[1]
# comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)]) # Check Updated running mean
# comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)]) # Check Updated running var
# comp_pass = comp_pass and comp_pass_1 and comp_pass_2
if training:
channels = input_shapes[1]
if check_mean:
comp_pass_1 = compare_results_batch_norm(
[tt_updated_mean], [mean_data.view(1, channels, 1, 1)]
) # Check Updated running mean
else:
if tt_updated_mean is not None:
comp_pass_1 = True
else:
comp_pass_1 = False
if check_var:
comp_pass_2 = compare_results_batch_norm(
[tt_updated_var], [var_data.view(1, channels, 1, 1)]
) # Check Updated running var
else:
if tt_updated_var is not None:
comp_pass_2 = True
else:
comp_pass_2 = False
comp_pass = comp_pass and comp_pass_1 and comp_pass_2

assert comp_pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
b_num_tiles_per_cb,
h_data_format); // updated running var

// Intermediate buffer
// Intermediate buffers required for uodation of running stats
auto [one_cb, one_cb_handle] = create_cb(
tt::CBIndex::c_19,
program,
Expand All @@ -271,29 +271,14 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
b_num_tiles_per_cb,
h_data_format); // to store 1

auto [tmp1_cb, tmp1_cb_handle] = create_cb(
tt::CBIndex::c_29,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp
auto [tmp1_cb, tmp1_cb_handle] =
create_cb(tt::CBIndex::c_29, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format);

auto [tmp2_cb, tmp2_cb_handle] = create_cb(
tt::CBIndex::c_30,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp
auto [tmp2_cb, tmp2_cb_handle] =
create_cb(tt::CBIndex::c_30, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format);

auto [tmp3_cb, tmp3_cb_handle] = create_cb(
tt::CBIndex::c_31,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // to store tmp
auto [tmp3_cb, tmp3_cb_handle] =
create_cb(tt::CBIndex::c_31, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format);

auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ void MAIN {
cb_pop_front(cb_den, 1);
cb_push_back(cb_affine_or_out, onetile);

// Updation of running stats
if constexpr (is_training_mode) {
// updated running stats
// updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat
if constexpr (old_running_mean_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, tile_id, 0, 0); // momentum * running stats
Expand All @@ -141,7 +142,6 @@ void MAIN {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, tile_id, 0, 0); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, tile_id, 0, 1); // cb_tmp1 * running stats
DPRINT << TSLICE(tt::CBIndex::c_26, 0, SliceRange::hw0_32_16()) << ENDL();
add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,19 @@ void kernel_main() {

constexpr auto cb_id_eps = tt::CBIndex::c_4;
constexpr auto cb_id_one = tt::CBIndex::c_19;
constexpr auto cb_id_momentum = tt::CBIndex::c_24;

cb_reserve_back(cb_id_eps, onetile);
fill_with_val_bfloat16(cb_id_eps, eps);
cb_push_back(cb_id_eps, onetile);

constexpr auto cb_id_momentum = tt::CBIndex::c_24;
union {
float f;
uint32_t u;
} scalar;
scalar.f = 1.0f;
fill_cb_with_value(cb_id_one, scalar.u);

cb_reserve_back(cb_id_momentum, onetile);
fill_with_val_bfloat16(cb_id_momentum, momentum);
cb_push_back(cb_id_momentum, onetile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void kernel_main() {
cb_push_back(cb_id_bias, onetile);
}

// to read running stats value for updation
// Updation of running stats
if constexpr (is_training_mode) {
if constexpr (old_running_mean_has_value) {
// read data
Expand Down

0 comments on commit 258ea9b

Please sign in to comment.