Skip to content

Commit

Permalink
#12253: Update files
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jan 10, 2025
1 parent 0b6d090 commit 84c1940
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 146 deletions.
33 changes: 19 additions & 14 deletions tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@
)


def data_gen_with_range_batch_norm(input_shapes, low, high, device, is_input=False, required_grad=False):
def data_gen_with_range_batch_norm(
input_shapes,
low,
high,
device,
is_input=False,
required_grad=False,
):
assert high > low, "Incorrect range provided"
torch.manual_seed(213919)
channels = input_shapes[1]
if is_input:
pt_tensor = torch.rand(input_shapes, requires_grad=required_grad).bfloat16() * (high - low) + low
tt_tensor = ttnn.from_torch(
pt_tensor,
device=device,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
else:
pt_tensor = torch.rand(channels, requires_grad=required_grad).bfloat16() * (high - low) + low
# pt_tensor = pt_tensor.view(1, channels, 1, 1) # to test each section of TT op
size = input_shapes if is_input else channels
pt_tensor = torch.rand(size, requires_grad=required_grad).bfloat16() * (high - low) + low
reshaped_tensor = pt_tensor
if not is_input:
reshaped_tensor = pt_tensor.view(1, channels, 1, 1).expand(1, channels, 32, 32)
tt_tensor = ttnn.Tensor(reshaped_tensor, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device)
tt_tensor = ttnn.from_torch(
reshaped_tensor,
device=device,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
return pt_tensor, tt_tensor


Expand Down
106 changes: 51 additions & 55 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,71 +9,35 @@
data_gen_with_range_batch_norm,
compare_results_batch_norm,
)
from itertools import product


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 2, 32, 32])),
(torch.Size([1, 3, 32, 32])),
(torch.Size([2, 1, 32, 32])),
(torch.Size([2, 2, 32, 32])),
(torch.Size([2, 3, 32, 32])),
(torch.Size([3, 1, 32, 32])),
(torch.Size([3, 2, 32, 32])),
(torch.Size([3, 3, 32, 32])),
(torch.Size([4, 1, 32, 32])),
(torch.Size([4, 2, 32, 32])),
(torch.Size([4, 3, 32, 32])),
(torch.Size([4, 4, 32, 32])),
(torch.Size([1, 1, 23, 23])),
(torch.Size([1, 2, 23, 23])),
(torch.Size([1, 3, 23, 23])),
(torch.Size([2, 1, 23, 23])),
(torch.Size([2, 2, 23, 23])),
(torch.Size([2, 3, 23, 23])),
(torch.Size([3, 1, 23, 23])),
(torch.Size([3, 2, 23, 23])),
(torch.Size([3, 3, 23, 23])),
(torch.Size([4, 1, 23, 23])),
(torch.Size([4, 2, 23, 23])),
(torch.Size([4, 3, 23, 23])),
(torch.Size([4, 4, 23, 23])),
(torch.Size([1, 1, 64, 120])),
(torch.Size([1, 2, 64, 120])),
(torch.Size([1, 3, 64, 120])),
(torch.Size([2, 1, 64, 120])),
(torch.Size([2, 2, 64, 120])),
(torch.Size([2, 3, 64, 120])),
(torch.Size([3, 1, 64, 120])),
(torch.Size([3, 2, 64, 120])),
),
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 32, 32]),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 23, 23]),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize("training", [False])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
def test_batch_norm(input_shapes, training, weight, bias, eps, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, True, False)
if not training:
mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device, False, False)
else:
mean_data = None
mean_tensor = None
var_data = None
var_tensor = None
if weight:
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False)
else:
weight_data = None
weight_tensor = None
if bias:
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False)
else:
bias_data = None
bias_tensor = None
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
)
var_data, var_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
)
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor,
Expand All @@ -100,3 +64,35 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device):
# print("Torch result : ",torch_result)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
[
torch.Size([3, 2, 32, 32]),
],
)
@pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.TensorMemoryLayout.HEIGHT_SHARDED])
def test_batch_norm_program_cache_and_default(input_shapes, mem_layout, device):
N, H, W, C = input_shapes
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device)

grid_size = ttnn.CoreGrid(y=1, x=8)
grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = N * H * W // grid_size.x, C // grid_size.y
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False)
sharded_mem_config = ttnn.MemoryConfig(mem_layout, ttnn.types.BufferType.L1, shard_spec)

if mem_layout is not ttnn.TensorMemoryLayout.INTERLEAVED:
pytest.xfail("Input tensors to batch norm must be interleaved")

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor, running_mean=mean_tensor, running_var=var_tensor, memory_config=sharded_mem_config
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
module,
ttnn::batch_norm,
R"doc(
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`.Currently support is provided for inference mode only.
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
Args:
Expand All @@ -28,6 +28,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
weight (ttnn.Tensor, optional): the weight or gamma value. Defaults to `None`.
bias (ttnn.Tensor, optional): the bias or beta value. Defaults to `None`.
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
output (ttnn.Tensor, optional): Preallocated output tensor to store batch norm result. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,7 @@
namespace ttnn::operations::normalization {
void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& input = tensor_args.input;
const auto& batch_mean = tensor_args.batch_mean;
const auto& batch_var = tensor_args.batch_var;
const auto& eps = operation_attributes.eps;
const auto& weight = tensor_args.weight;
const auto& bias = tensor_args.bias;

auto& output = tensor_args.output;
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
Expand Down Expand Up @@ -49,8 +42,8 @@ void BatchNormOperation::validate_tensors(

// bias (1, C, 1, 1)
if (bias.has_value()) {
TT_FATAL(bias.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
}
}

Expand All @@ -61,13 +54,7 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory

void BatchNormOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
// We don't support sharding for now
const auto& input = tensor_args.input;
const auto& batch_mean = tensor_args.batch_mean;
const auto& batch_var = tensor_args.batch_var;
const auto& weight = tensor_args.weight;
const auto& bias = tensor_args.bias;
const auto& output = tensor_args.output;
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized");
TT_FATAL(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ void set_or_update_runtime_arguments(
const BatchNormOperation::tensor_args_t& tensor_args,
BatchNormOperation::tensor_return_value_t& c,
F handle_args) {
const auto& a = tensor_args.input;
const auto& b = tensor_args.batch_mean;
const auto& d = tensor_args.batch_var;
const auto& e = tensor_args.weight;
const auto& f = tensor_args.bias;
const auto& [a, b, d, e, f, _] = tensor_args;
const auto eps = operation_attributes.eps;

const bool weight_has_value = e.has_value();
Expand Down Expand Up @@ -68,7 +64,7 @@ void set_or_update_runtime_arguments(
num_tiles_per_core = num_tiles_per_core_group_2;
} else {
handle_args(program, reader_kernel_id, core, std::array<uint32_t, 11>{0});
handle_args(program, writer_kernel_id, core, std::array<uint32_t, 16>{0});
handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
continue;
}
Expand All @@ -90,14 +86,12 @@ void set_or_update_runtime_arguments(
cWt};
handle_args(program, reader_kernel_id, core, reader_runtime_args);

const auto weight_addr = weight_has_value ? e.value().buffer()->address() : 0;
const auto bias_addr = bias_has_value ? f.value().buffer()->address() : 0;
const auto weight_addr = weight_has_value ? e->buffer()->address() : 0;
const auto bias_addr = bias_has_value ? f->buffer()->address() : 0;
std::array writer_runtime_args = {
b.buffer()->address(), // batch mean
d.buffer()->address(), // batch var
static_cast<uint32_t>(weight_has_value),
weight_addr, // weight
static_cast<uint32_t>(bias_has_value),
weight_addr, // weight
bias_addr, // bias
c.buffer()->address(), // output
start_tile_id,
Expand Down Expand Up @@ -132,12 +126,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
using namespace tt;
using namespace tt::tt_metal;

const auto& a = tensor_args.input;
const auto& b = tensor_args.batch_mean;
const auto& d = tensor_args.batch_var;
const auto& eps = operation_attributes.eps;
const auto& e = tensor_args.weight;
const auto& f = tensor_args.bias;
const auto& [a, b, d, e, f, _] = tensor_args;

auto program = CreateProgram();

Expand Down Expand Up @@ -169,13 +158,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
uint32_t num_cores_y = compute_with_storage_grid_size.y;
auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});

Buffer* a_buffer = a.buffer();
Buffer* b_buffer = b.buffer();
Buffer* c_buffer = output.buffer();
Buffer* d_buffer = d.buffer();
Buffer* e_buffer = nullptr;
Buffer* f_buffer = nullptr;

// Number of tiles to store per input CB (double buffer)
constexpr uint32_t num_tiles_per_cb = 2;
uint32_t b_num_tiles_per_cb = num_tiles_per_cb;
Expand Down Expand Up @@ -224,24 +206,12 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
auto [temp_1_cb, temp_1_cb_handle] =
create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);

auto a_is_dram = static_cast<uint32_t>(a_buffer->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b_buffer->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(c_buffer->buffer_type() == tt_metal::BufferType::DRAM);
auto d_is_dram = static_cast<uint32_t>(d_buffer->buffer_type() == tt_metal::BufferType::DRAM);
bool e_is_dram = false;
bool f_is_dram = false;

// weight
if (weight_has_value) {
e_buffer = e->buffer();
e_is_dram = static_cast<uint32_t>(e_buffer->buffer_type() == tt_metal::BufferType::DRAM);
}

// bias
if (bias_has_value) {
f_buffer = f->buffer();
f_is_dram = static_cast<uint32_t>(f_buffer->buffer_type() == tt_metal::BufferType::DRAM);
}
auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

// READER KERNEL
auto reader_kernel_id = tt_metal::CreateKernel(
Expand All @@ -255,7 +225,14 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
all_device_cores,
tt_metal::WriterDataMovementConfig({b_is_dram, c_is_dram, d_is_dram, e_is_dram, f_is_dram}));
tt_metal::WriterDataMovementConfig(
{b_is_dram,
c_is_dram,
d_is_dram,
e_is_dram,
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value)}));

// COMPUTE KERNEL
bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void MAIN {
cb_pop_front(cb_den, 1);
cb_push_back(cb_affine_or_out, onetile);

if (weight_has_value) { // result = result * weight
if constexpr (weight_has_value) { // result = result * weight
cb_reserve_back(cb_scaled_output, onetile);
cb_wait_front(cb_affine_or_out, 1);
cb_wait_front(cb_weight, 1);
Expand All @@ -134,7 +134,7 @@ void MAIN {
cb_pop_front(cb_weight, 1);
cb_push_back(cb_scaled_output, onetile);
}
if (bias_has_value) { // result = result + bias
if constexpr (bias_has_value) { // result = result + bias
cb_reserve_back(cb_output_0, 1);
cb_wait_front(cb_tmp_1, 1);
cb_wait_front(cb_bias, 1);
Expand Down
Loading

0 comments on commit 84c1940

Please sign in to comment.