Commit

#12253: Update files
VirdhatchaniKN committed Jan 27, 2025
1 parent 2acd48f commit 3ca0581
Showing 2 changed files with 73 additions and 63 deletions.
File 1 of 2 (BatchNormOperation program factory):
@@ -29,18 +29,18 @@ void set_or_update_runtime_arguments(
const BatchNormOperation::tensor_args_t& tensor_args,
BatchNormOperation::tensor_return_value_t& c,
F handle_args) {
- const auto& [a, b, d, e, f, _] = tensor_args;
+ const auto& [input_tensor, batch_mean_tensor, batch_var_tensor, weight_tensor, bias_tensor, _] = tensor_args;
const auto eps = operation_attributes.eps;

- const bool weight_has_value = e.has_value();
- const bool bias_has_value = f.has_value();
+ const bool weight_has_value = weight_tensor.has_value();
+ const bool bias_has_value = bias_tensor.has_value();

- const auto ashape = a.padded_shape();
- const auto bshape = b.padded_shape();
+ const auto ashape = input_tensor.padded_shape();
+ const auto bshape = batch_mean_tensor.padded_shape();
const auto cshape = c.padded_shape();

- const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
- const auto [bN, bC, bHt, bWt] = extract_shape_dims(b);
+ const auto [aN, aC, aHt, aWt] = extract_shape_dims(input_tensor);
+ const auto [bN, bC, bHt, bWt] = extract_shape_dims(batch_mean_tensor);
const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);

uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();
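
Aside: the central change in this hunk is renaming the positional structured-binding aliases a, b, d, e, f to descriptive names. A minimal standalone sketch of the idiom (hypothetical struct, not the real tensor_args_t):

#include <cstdio>

struct TensorArgs {
    int input, batch_mean, batch_var;
};

int main() {
    const TensorArgs args{1, 2, 3};
    // Descriptive binding names document what each positional member is at the use site.
    const auto& [input_tensor, batch_mean_tensor, batch_var_tensor] = args;
    std::printf("%d %d %d\n", input_tensor, batch_mean_tensor, batch_var_tensor);
    return 0;
}
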
@@ -54,6 +54,9 @@ void set_or_update_runtime_arguments(
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
+ constexpr size_t num_reader_args = 11;
+ constexpr size_t num_writer_args = 14;
+ constexpr size_t num_kernel_args = 3;
for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) {
const auto& core = cores[i];

@@ -63,9 +66,9 @@
} else if (core_group_2.contains(core)) {
num_tiles_per_core = num_tiles_per_core_group_2;
} else {
- handle_args(program, reader_kernel_id, core, std::array<uint32_t, 11>{0});
- handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
- handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
+ handle_args(program, reader_kernel_id, core, std::array<uint32_t, num_reader_args>{0});
+ handle_args(program, writer_kernel_id, core, std::array<uint32_t, num_writer_args>{0});
+ handle_args(program, compute_kernel_id, core, std::array<uint32_t, num_kernel_args>{0});
continue;
}

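Aside: the constexpr counts introduced above replace the magic array sizes 11, 14, and 3 in the zero-argument fallback, keeping each length tied to a single definition. A minimal sketch of the pattern (hypothetical, not tt-metal code):

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr std::size_t num_reader_args = 11;  // single source of truth for the count

int main() {
    // Zero-filled runtime-argument block whose length tracks the named constant.
    std::array<uint32_t, num_reader_args> reader_args{0};
    std::printf("%zu args, first = %u\n", reader_args.size(), reader_args[0]);
    return 0;
}
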
@@ -74,7 +77,7 @@
uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps});
std::array reader_runtime_args = {
packed_scalar_eps,
- a.buffer()->address(),
+ input_tensor.buffer()->address(),
start_tile_id,
num_tiles_per_core,
cHtWt,
Expand All @@ -86,14 +89,14 @@ void set_or_update_runtime_arguments(
cWt};
handle_args(program, reader_kernel_id, core, reader_runtime_args);

- const auto weight_addr = weight_has_value ? e->buffer()->address() : 0;
- const auto bias_addr = bias_has_value ? f->buffer()->address() : 0;
+ const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0;
+ const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0;
std::array writer_runtime_args = {
- b.buffer()->address(),  // batch mean
- d.buffer()->address(),  // batch var
- weight_addr,            // weight
- bias_addr,              // bias
- c.buffer()->address(),  // output
+ batch_mean_tensor.buffer()->address(),  // batch mean
+ batch_var_tensor.buffer()->address(),   // batch var
+ weight_addr,                            // weight
+ bias_addr,                              // bias
+ c.buffer()->address(),                  // output
start_tile_id,
num_tiles_per_core,
cHtWt,
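
Aside: packed_scalar_eps in the hunk above carries the bfloat16 epsilon duplicated into both halves of a single 32-bit runtime argument. A sketch under the usual bfloat16-as-truncated-float32 convention (hypothetical helpers, not the tt-metal implementation):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float to bfloat16: keep the upper 16 bits of the IEEE-754 pattern.
static uint16_t to_bfloat16(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof bits);
    return static_cast<uint16_t>(bits >> 16);
}

// Pack two bfloat16 values into one uint32_t (first value in the high half).
static uint32_t pack_two_bfloat16(uint16_t hi, uint16_t lo) {
    return (static_cast<uint32_t>(hi) << 16) | lo;
}

int main() {
    const uint16_t eps = to_bfloat16(1e-5f);
    std::printf("packed eps = 0x%08x\n", pack_two_bfloat16(eps, eps));
    return 0;
}
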
@@ -126,21 +129,23 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
using namespace tt;
using namespace tt::tt_metal;

- const auto& [a, b, d, e, f, _] = tensor_args;
+ const auto& [input_tensor, batch_mean_tensor, batch_var_tensor, weight_tensor, bias_tensor, _] = tensor_args;

auto program = CreateProgram();

- auto* device = a.device();
+ auto* device = input_tensor.device();

- const bool weight_has_value = e.has_value();
- const bool bias_has_value = f.has_value();
+ const bool weight_has_value = weight_tensor.has_value();
+ const bool bias_has_value = bias_tensor.has_value();

- auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
- auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
+ auto a_data_format = datatype_to_dataformat_converter(input_tensor.get_dtype());
+ auto b_data_format = datatype_to_dataformat_converter(batch_mean_tensor.get_dtype());
auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
- auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
- auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
- auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
+ auto d_data_format = datatype_to_dataformat_converter(batch_var_tensor.get_dtype());
+ auto e_data_format =
+     weight_has_value ? datatype_to_dataformat_converter(weight_tensor->get_dtype()) : DataFormat::Float16_b;
+ auto f_data_format =
+     bias_has_value ? datatype_to_dataformat_converter(bias_tensor->get_dtype()) : DataFormat::Float16_b;

uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
@@ -206,12 +211,12 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
auto [temp_1_cb, temp_1_cb_handle] =
create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);

- auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
- auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+ auto a_is_dram = static_cast<uint32_t>(input_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+ auto b_is_dram = static_cast<uint32_t>(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
- auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
- const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
- const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+ auto d_is_dram = static_cast<uint32_t>(batch_var_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+ const auto e_is_dram = weight_has_value and weight_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+ const auto f_is_dram = bias_has_value and bias_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

// READER KERNEL
auto reader_kernel_id = tt_metal::CreateKernel(
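
Aside: weight and bias are optional inputs throughout this file, so addresses and DRAM flags fall back to 0/false when they are absent. A self-contained sketch of the std::optional pattern (hypothetical types, not the real Tensor API):

#include <cstdint>
#include <cstdio>
#include <optional>

struct Buffer {
    uint32_t address() const { return 0x1000; }
};
struct Tensor {
    Buffer buf;
    const Buffer* buffer() const { return &buf; }
};

int main() {
    std::optional<Tensor> weight;  // empty: the op was created without a weight tensor
    const uint32_t weight_addr = weight.has_value() ? weight->buffer()->address() : 0;
    std::printf("weight addr = 0x%x\n", weight_addr);  // prints 0x0
    return 0;
}
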
File 2 of 2 (RunningStatistics program factory):
@@ -30,18 +30,18 @@ void set_or_update_runtime_arguments(
const RunningStatistics::tensor_args_t& tensor_args,
RunningStatistics::tensor_return_value_t& c,
F handle_args) {
- const auto& [a, b, d, e] = tensor_args;
+ const auto& [batch_mean_tensor, batch_var_tensor, running_mean_tensor, running_var_tensor] = tensor_args;
const auto momentum = operation_attributes.momentum;

- const bool running_mean_has_value = d.has_value();
- const bool running_var_has_value = e.has_value();
+ const bool running_mean_has_value = running_mean_tensor.has_value();
+ const bool running_var_has_value = running_var_tensor.has_value();

- const auto ashape = a.padded_shape();
- const auto bshape = b.padded_shape();
+ const auto ashape = batch_mean_tensor.padded_shape();
+ const auto bshape = batch_var_tensor.padded_shape();
const auto cshape = c.padded_shape();

- const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
- const auto [bN, bC, bHt, bWt] = extract_shape_dims(b);
+ const auto [aN, aC, aHt, aWt] = extract_shape_dims(batch_mean_tensor);
+ const auto [bN, bC, bHt, bWt] = extract_shape_dims(batch_var_tensor);
const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);

uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();
@@ -55,6 +55,9 @@ void set_or_update_runtime_arguments(
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
+ constexpr size_t num_reader_args = 11;
+ constexpr size_t num_writer_args = 13;
+ constexpr size_t num_kernel_args = 3;
for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) {
const auto& core = cores[i];

@@ -64,9 +67,9 @@
} else if (core_group_2.contains(core)) {
num_tiles_per_core = num_tiles_per_core_group_2;
} else {
- handle_args(program, reader_kernel_id, core, std::array<uint32_t, 11>{0});
- handle_args(program, writer_kernel_id, core, std::array<uint32_t, 13>{0});
- handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
+ handle_args(program, reader_kernel_id, core, std::array<uint32_t, num_reader_args>{0});
+ handle_args(program, writer_kernel_id, core, std::array<uint32_t, num_writer_args>{0});
+ handle_args(program, compute_kernel_id, core, std::array<uint32_t, num_kernel_args>{0});
continue;
}

@@ -76,7 +79,7 @@
pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum});
std::array reader_runtime_args = {
packed_scalar_momentum,
- a.buffer()->address(),
+ batch_mean_tensor.buffer()->address(),
start_tile_id,
num_tiles_per_core,
cHtWt,
@@ -88,13 +91,13 @@
cWt};
handle_args(program, reader_kernel_id, core, reader_runtime_args);

- const auto running_mean_addr = running_mean_has_value ? d->buffer()->address() : 0;
- const auto running_var_addr = running_var_has_value ? e->buffer()->address() : 0;
+ const auto running_mean_addr = running_mean_has_value ? running_mean_tensor->buffer()->address() : 0;
+ const auto running_var_addr = running_var_has_value ? running_var_tensor->buffer()->address() : 0;
std::array writer_runtime_args = {
- b.buffer()->address(),  // batch var
- running_mean_addr,      // old running mean
- running_var_addr,       // old running var
- c.buffer()->address(),  // output
+ batch_var_tensor.buffer()->address(),  // batch var
+ running_mean_addr,                     // old running mean
+ running_var_addr,                      // old running var
+ c.buffer()->address(),                 // output
start_tile_id,
num_tiles_per_core,
cHtWt,
@@ -128,22 +131,22 @@ RunningStatistics::RunningStatisticsProgramFactory::create(
using namespace tt;
using namespace tt::tt_metal;

- const auto& [a, b, d, e] = tensor_args;
+ const auto& [batch_mean_tensor, batch_var_tensor, running_mean_tensor, running_var_tensor] = tensor_args;

auto program = CreateProgram();

- auto* device = a.device();
+ auto* device = batch_mean_tensor.device();

- const bool running_mean_has_value = d.has_value();
- const bool running_var_has_value = e.has_value();
+ const bool running_mean_has_value = running_mean_tensor.has_value();
+ const bool running_var_has_value = running_var_tensor.has_value();

- auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
- auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
+ auto a_data_format = datatype_to_dataformat_converter(batch_mean_tensor.get_dtype());
+ auto b_data_format = datatype_to_dataformat_converter(batch_var_tensor.get_dtype());
auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
- auto d_data_format =
-     running_mean_has_value ? datatype_to_dataformat_converter(d->get_dtype()) : DataFormat::Float16_b;
- auto e_data_format =
-     running_var_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
+ auto d_data_format = running_mean_has_value ? datatype_to_dataformat_converter(running_mean_tensor->get_dtype())
+                                             : DataFormat::Float16_b;
+ auto e_data_format = running_var_has_value ? datatype_to_dataformat_converter(running_var_tensor->get_dtype())
+                                            : DataFormat::Float16_b;

uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
@@ -235,11 +238,13 @@ RunningStatistics::RunningStatisticsProgramFactory::create(
auto [tmp3_cb, tmp3_cb_handle] =
create_cb(tt::CBIndex::c_23, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);

- auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
- auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+ auto a_is_dram = static_cast<uint32_t>(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+ auto b_is_dram = static_cast<uint32_t>(batch_var_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
- const auto d_is_dram = running_mean_has_value and d->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
- const auto e_is_dram = running_var_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+ const auto d_is_dram =
+     running_mean_has_value and running_mean_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+ const auto e_is_dram =
+     running_var_has_value and running_var_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

// READER KERNEL
auto reader_kernel_id = tt_metal::CreateKernel(
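
Aside: given the momentum attribute threaded through this file, the conventional running-statistics update such an op computes is new = (1 - momentum) * old + momentum * batch_stat, the standard batch-norm convention (stated as background; the exact kernel math is not shown in this diff):

#include <cstdio>

int main() {
    const float momentum = 0.1f;
    float running_mean = 0.0f;
    const float batch_mean = 2.0f;
    // Exponential moving average used for batch-norm running statistics.
    running_mean = (1.0f - momentum) * running_mean + momentum * batch_mean;
    std::printf("updated running mean = %f\n", running_mean);  // 0.200000
    return 0;
}
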
