Commit

#0: Updated return type with vector of tensors

VirdhatchaniKN committed Jan 9, 2025
1 parent d9bc623 commit b1bc734
Showing 6 changed files with 151 additions and 27 deletions.
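
For orientation before the file-by-file diff: the operation now returns a vector of optional tensors instead of a single tensor. The sketch below is not part of the commit; `input`, `running_mean`, and `running_var` are assumed to be existing device tensors, and only the `BatchNorm::invoke` signature introduced in batch_norm.hpp below is relied on.

#include "ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp"

// Minimal sketch (not from this commit): consuming the new three-slot return value.
auto results = ttnn::operations::normalization::BatchNorm::invoke(
    input,          // assumed existing device tensor, shape (N, C, H, W)
    running_mean,   // std::optional<Tensor>, shape (1, C, 1, 1)
    running_var,    // std::optional<Tensor>, shape (1, C, 1, 1)
    /*training=*/true,
    /*eps=*/1e-5f,
    /*momentum=*/0.1f);

const Tensor& output = results[0].value();   // normalized output, always present
if (results[1].has_value()) {
    // updated running mean, produced in training mode
}
if (results[2].has_value()) {
    // updated running variance, produced in training mode
}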
28 changes: 27 additions & 1 deletion ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -26,7 +26,7 @@ inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryCon
batch_mean, dims.front(), true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
}

Tensor BatchNorm::invoke(
std::vector<std::optional<Tensor>> BatchNorm::invoke(
const Tensor& input,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
@@ -36,6 +36,8 @@ Tensor BatchNorm::invoke(
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
std::optional<Tensor> updated_running_mean,
std::optional<Tensor> updated_running_var,
const std::optional<MemoryConfig>& memory_config) {
if (training) {
Tensor batch_mean = mean_NHW(input, memory_config);
@@ -54,6 +56,8 @@ Tensor BatchNorm::invoke(
running_mean,
running_var,
output,
updated_running_mean,
updated_running_var,
memory_config);
}
TT_FATAL(
@@ -71,6 +75,28 @@ Tensor BatchNorm::invoke(
std::nullopt,
std::nullopt,
output,
std::nullopt,
std::nullopt,
memory_config);
}

OptionalTensors BatchNorm::create_async_optional_output_tensors(
const Tensor& input,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
const bool training,
const float eps,
const float momentum,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
std::optional<Tensor> updated_running_mean,
std::optional<Tensor> updated_running_var,
const std::optional<MemoryConfig>& memory_config) {
return {
std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})),
training ? std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})) : std::nullopt,
training ? std::optional<Tensor>(operation::get_workers_for_op_output({input}, {weight, bias})) : std::nullopt};
}

} // namespace ttnn::operations::normalization
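
A short note on `create_async_optional_output_tensors` above: it tells the async launch machinery which of the three return slots need worker tensors. The comment block below is descriptive only, not code from the commit.

// Return-slot layout assumed by the helper above:
//   slot 0 -> normalized output          (always bound to a worker tensor)
//   slot 1 -> updated running mean       (bound only when training == true)
//   slot 2 -> updated running variance   (bound only when training == true)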
18 changes: 17 additions & 1 deletion ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
@@ -9,7 +9,7 @@
namespace ttnn {
namespace operations::normalization {
struct BatchNorm {
static Tensor invoke(
static std::vector<std::optional<Tensor>> invoke(
const Tensor& input,
std::optional<Tensor> running_mean = std::nullopt,
std::optional<Tensor> running_var = std::nullopt,
@@ -19,6 +19,22 @@ struct BatchNorm {
std::optional<Tensor> weight = std::nullopt,
std::optional<Tensor> bias = std::nullopt,
std::optional<Tensor> output = std::nullopt,
std::optional<Tensor> updated_running_mean = std::nullopt,
std::optional<Tensor> updated_running_var = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

static OptionalTensors create_async_optional_output_tensors(
const Tensor& input,
std::optional<Tensor> running_mean = std::nullopt,
std::optional<Tensor> running_var = std::nullopt,
const bool training = false,
const float eps = 1e-05,
const float momentum = 0.1,
std::optional<Tensor> weight = std::nullopt,
std::optional<Tensor> bias = std::nullopt,
std::optional<Tensor> output = std::nullopt,
std::optional<Tensor> updated_running_mean = std::nullopt,
std::optional<Tensor> updated_running_var = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt);
};
} // namespace operations::normalization
@@ -29,11 +29,14 @@ void bind_batch_norm_operation(pybind11::module& module) {
weight (ttnn.Tensor, optional): the weight or gamma value. Defaults to `None`.
bias (ttnn.Tensor, optional): the bias or beta value. Defaults to `None`.
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
output (ttnn.Tensor, optional): Preallocated output tensor to store the batch norm result. Defaults to `None`.
updated_running_mean (ttnn.Tensor, optional): Preallocated output tensor to store the updated running mean when the operation runs in training mode. Defaults to `None`.
updated_running_var (ttnn.Tensor, optional): Preallocated output tensor to store the updated running variance when the operation runs in training mode. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Returns:
ttnn.Tensor: the output tensor.
List of ttnn.Tensor: the output tensors.
)doc",
@@ -48,6 +51,8 @@
py::arg("weight") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("output") = std::nullopt,
py::arg("updated_running_mean") = std::nullopt,
py::arg("updated_running_var") = std::nullopt,
py::arg("memory_config") = std::nullopt

});
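
The two new keyword arguments mirror the existing `output` argument and let callers preallocate the running-statistics outputs. A hedged C++ sketch follows; `make_preallocated` is a placeholder for whatever allocation helper the caller uses, not an API from this commit, and the shapes follow the validation rules in the device operation further down (output matches `input`, running statistics are (1, C, 1, 1)).

#include "ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp"

// Sketch only: pass caller-owned tensors so batch norm writes into them
// instead of allocating fresh ones. make_preallocated is a hypothetical helper.
std::optional<Tensor> out      = make_preallocated(input.get_logical_shape()); // same shape as input
std::optional<Tensor> new_mean = make_preallocated(stat_shape);                // (1, C, 1, 1)
std::optional<Tensor> new_var  = make_preallocated(stat_shape);                // (1, C, 1, 1)

auto results = ttnn::operations::normalization::BatchNorm::invoke(
    input, running_mean, running_var,
    /*training=*/true, /*eps=*/1e-5f, /*momentum=*/0.1f,
    /*weight=*/std::nullopt, /*bias=*/std::nullopt,
    out, new_mean, new_var);

// Per create_output_tensors in the device operation below, the provided
// tensors come back in slots 0, 1, and 2 rather than newly allocated ones.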
@@ -10,7 +10,8 @@
namespace ttnn::operations::normalization {
void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output, updated_running_mean, updated_running_var] =
tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
@@ -20,6 +21,8 @@ void BatchNormOperation::validate_tensors(
check_tensor(output, "batch_norm", "output");
check_tensor(running_mean, "batch_norm", "running_mean");
check_tensor(running_var, "batch_norm", "running_var");
check_tensor(updated_running_mean, "batch_norm", "updated_running_mean");
check_tensor(updated_running_var, "batch_norm", "updated_running_var");

// input (N, C, H, W)
auto C = input.get_logical_shape()[1];
@@ -67,6 +70,19 @@ void BatchNormOperation::validate_tensors(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
}

// updated_running_mean (1, C, 1, 1)
if (updated_running_mean.has_value()) {
TT_FATAL(
updated_running_mean.value().get_logical_shape()[1] == C,
"updated_running_mean_shape[1] must be the same as input's channel size.");
}
// updated_running_var (1, C, 1, 1)
if (updated_running_var.has_value()) {
TT_FATAL(
updated_running_var.value().get_logical_shape()[1] == C,
"updated_running_var_shape[1] must be the same as input's channel size.");
}
}

BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -77,7 +93,8 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory
void BatchNormOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
// We don't support sharding for now
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output, updated_running_mean, updated_running_var] =
tensor_args;

TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be tilized");
TT_FATAL(
@@ -139,20 +156,72 @@ DataType BatchNormOperation::operation_attributes_t::get_dtype() const {
BatchNormOperation::spec_return_value_t BatchNormOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
using namespace tt::constants;
std::vector<std::optional<TensorSpec>> result;
result.reserve(3);

// output shape
const auto output_shape = tensor_args.input.get_logical_shape();
return TensorSpec(
result.push_back(TensorSpec(
output_shape,
TensorLayout(operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config));
TensorLayout(operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));

const auto C = output_shape[1];
SimpleShape mean_var_shape({1, C, 1, 1});

// updated running mean
if (tensor_args.updated_running_mean.has_value()) {
result.push_back(TensorSpec(
mean_var_shape,
TensorLayout(
operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));
} else {
result.push_back(std::nullopt);
}

// updated running var
if (tensor_args.updated_running_var.has_value()) {
result.push_back(TensorSpec(
mean_var_shape,
TensorLayout(
operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)));
} else {
result.push_back(std::nullopt);
}

return result;
}

BatchNormOperation::tensor_return_value_t BatchNormOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& output_tensor = tensor_args.output;
const auto output_specs = compute_output_specs(operation_attributes, tensor_args);
auto device = tensor_args.input.device();

std::vector<std::optional<Tensor>> result;
result.reserve(3);

// output
if (output_tensor.has_value()) {
return output_tensor.value();
result.push_back(output_tensor.value());
} else {
result.push_back(create_device_tensor(*output_specs[0], device));
}

// updated_running_mean
if (tensor_args.updated_running_mean.has_value()) {
result.push_back(tensor_args.updated_running_mean.value());
} else {
result.push_back(create_device_tensor(*output_specs[1], device));
}

// updated_running_var
if (tensor_args.updated_running_var.has_value()) {
result.push_back(tensor_args.updated_running_var.value());
} else {
result.push_back(create_device_tensor(*output_specs[2], device));
}

return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
return result;
}

std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tensor_args_t> BatchNormOperation::invoke(
@@ -167,6 +236,8 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
std::optional<Tensor> updated_running_mean,
std::optional<Tensor> updated_running_var,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args{
@@ -177,7 +248,9 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
std::move(bias),
std::move(running_mean),
std::move(running_var),
std::move(output)};
std::move(output),
std::move(updated_running_mean),
std::move(updated_running_var)};
return {operation_attributes, tensor_args};
}
} // namespace ttnn::operations::normalization
@@ -30,10 +30,12 @@ struct BatchNormOperation {
std::optional<Tensor> output;
std::optional<Tensor> running_mean;
std::optional<Tensor> running_var;
std::optional<Tensor> updated_running_mean;
std::optional<Tensor> updated_running_var;
};

using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;
using spec_return_value_t = std::vector<std::optional<TensorSpec>>;
using tensor_return_value_t = std::vector<std::optional<Tensor>>;

struct BatchNormFactory {
struct shared_variables_t {
@@ -77,6 +79,8 @@ struct BatchNormOperation {
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
std::optional<Tensor> updated_running_mean,
std::optional<Tensor> updated_running_var,
const std::optional<MemoryConfig>& memory_config);
};
} // namespace ttnn::operations::normalization
@@ -29,7 +29,7 @@ void set_or_update_runtime_arguments(
const BatchNormOperation::tensor_args_t& tensor_args,
BatchNormOperation::tensor_return_value_t& c,
F handle_args) {
const auto& [a, b, d, e, f, g, h, _] = tensor_args;
const auto& [a, b, d, e, f, g, h, _, updated_m, updated_v] = tensor_args;
const auto eps = operation_attributes.eps;
const auto momentum = operation_attributes.momentum;

@@ -41,13 +41,13 @@

const auto ashape = a.padded_shape();
const auto bshape = b.padded_shape();
const auto cshape = c.padded_shape();
const auto cshape = c[0]->padded_shape();

const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
const auto [bN, bC, bHt, bWt] = extract_shape_dims(b);
const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);
const auto [cN, cC, cHt, cWt] = extract_shape_dims(a); // output same as input

uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();
uint32_t num_output_tiles = c[0]->volume() / c[0]->tensor_spec().tile().get_tile_hw();

constexpr bool row_major = true;
uint32_t num_cores_x = compute_with_storage_grid_size.x;
@@ -99,13 +99,13 @@
const auto running_mean_addr = is_training_mode and running_mean_has_value ? g->buffer()->address() : 0;
const auto running_var_addr = is_training_mode and running_var_has_value ? h->buffer()->address() : 0;
std::array writer_runtime_args = {
b.buffer()->address(), // batch mean
d.buffer()->address(), // batch var
weight_addr, // weight
bias_addr, // bias
running_mean_addr, // old running mean
running_var_addr, // old running var
c.buffer()->address(), // output
b.buffer()->address(), // batch mean
d.buffer()->address(), // batch var
weight_addr, // weight
bias_addr, // bias
running_mean_addr, // old running mean
running_var_addr, // old running var
c[0]->buffer()->address(), // output
start_tile_id,
num_tiles_per_core,
cHtWt,
@@ -138,7 +138,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
using namespace tt;
using namespace tt::tt_metal;

const auto& [a, b, d, e, f, g, h, _] = tensor_args;
const auto& [a, b, d, e, f, g, h, _, updated_m, updated_v] = tensor_args;

auto program = CreateProgram();

@@ -152,7 +152,7 @@

auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
auto c_data_format = datatype_to_dataformat_converter(output[0]->get_dtype());
auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
@@ -170,7 +170,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
uint32_t g_single_tile_size = tt_metal::detail::TileSize(g_data_format);
uint32_t h_single_tile_size = tt_metal::detail::TileSize(h_data_format);

uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();
uint32_t num_output_tiles = output[0]->volume() / output[0]->tensor_spec().tile().get_tile_hw();

// we parallelize the computation across the output tiles
constexpr bool row_major = true;
@@ -250,7 +250,7 @@

auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto c_is_dram = static_cast<uint32_t>(output[0]->buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;