Skip to content
Merged
95 changes: 93 additions & 2 deletions onnxruntime/core/optimizer/nchwc_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
void TransformConv(Node& node);
void TransformPool(Node& node);
void TransformBinary(Node& node, bool add_node);
void TransformMul(Node& node);
void TransformConcat(Node& node);
void TransformActivation(Node& node);
void TransformBatchNormalization(Node& node);
Expand Down Expand Up @@ -345,7 +346,7 @@
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_output_channels = (output_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 349 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
Expand Down Expand Up @@ -378,7 +379,7 @@
if ((input_channels % channel_alignment) != 0) {
return;
}
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 382 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size
}
}

Expand Down Expand Up @@ -734,6 +735,89 @@
}
}

void NchwcTransformerImpl::TransformMul(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();

if (input_defs.size() != 2 || output_defs.size() != 1) {
return;
}

auto* nchwc_input_0 = LookupNchwcArgument(input_defs[0]);
auto* nchwc_input_1 = LookupNchwcArgument(input_defs[1]);

// If both inputs are already NCHWc arguments, use the regular binary path.
if (nchwc_input_0 != nullptr && nchwc_input_1 != nullptr) {
TransformBinary(node, false);
return;
}

// Exactly one input must be NCHWc and the other must be a static scale tensor.
if ((nchwc_input_0 == nullptr) == (nchwc_input_1 == nullptr)) {
return;
}

const int nchwc_input_index = (nchwc_input_0 != nullptr) ? 0 : 1;
const int scale_input_index = nchwc_input_index ^ 1;
auto* nchwc_input = (nchwc_input_index == 0) ? nchwc_input_0 : nchwc_input_1;

const auto* mul_scale_tensor_proto = graph_utils::GetConstantInitializer(graph_, input_defs[scale_input_index]->Name());
if (mul_scale_tensor_proto == nullptr ||
mul_scale_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return;
}

const int64_t channels = nchwc_input->channels_;

Initializer mul_scale{graph_, *mul_scale_tensor_proto, graph_.ModelPath()};
const auto scale_dims = mul_scale.dims();

bool channel_broadcast_shape = false;
if (scale_dims.size() == 3) {
channel_broadcast_shape = (scale_dims[0] == channels && scale_dims[1] == 1 && scale_dims[2] == 1);
} else if (scale_dims.size() == 4) {
channel_broadcast_shape = (scale_dims[0] == 1 && scale_dims[1] == channels && scale_dims[2] == 1 && scale_dims[3] == 1);
}

if (!channel_broadcast_shape) {
return;
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + static_cast<int64_t>(nchwc_block_size) - 1) & ~static_cast<int64_t>(nchwc_block_size - 1);

InlinedVector<float> padded_scale(gsl::narrow<size_t>(nchwc_channels));
Comment thread
hariharans29 marked this conversation as resolved.
Outdated
std::copy_n(mul_scale.data<float>(), channels, padded_scale.data());

ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto;
nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("mul_scale"));
utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, padded_scale.data(),
gsl::narrow<size_t>(nchwc_channels) * sizeof(float));
nchwc_conv_W_tensor_proto.add_dims(nchwc_channels);
nchwc_conv_W_tensor_proto.add_dims(1);
nchwc_conv_W_tensor_proto.add_dims(1);
nchwc_conv_W_tensor_proto.add_dims(1);

auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithOrtValue(graph_, nchwc_conv_W_tensor_proto);

std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_mul_nchwc");
Node& nchwc_node = graph_.AddNode(nchwc_node_name,
"Conv",
nchwc_node_name,
std::array{nchwc_input->nchwc_arg_, nchwc_conv_W_arg},
output_defs,
nullptr,
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.AddAttribute("group", nchwc_channels);

nchwc_input->remaining_original_uses_--;

CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_);
removed_nodes_.push_front(node.Index());
}

void NchwcTransformerImpl::TransformConcat(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();
Expand Down Expand Up @@ -794,7 +878,12 @@
// Check if this is a single use NCHWc convolution that hasn't already
// been fused with another activation.
auto& nchwc_node = nchwc_input->output_node_;

const bool can_fuse_activation = (node.OpType() == "Relu") ||
(node.OpType() == "Sigmoid") ||
(node.OpType() == "Tanh");
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
can_fuse_activation &&
(nchwc_input->starting_original_uses_ == 1) &&
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
nchwc_node.AddAttribute("activation", node.OpType());
Expand Down Expand Up @@ -880,7 +969,7 @@
bn_B.sub(bn_mean);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 972 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

InlinedVector<float> padded_buffer(gsl::narrow<size_t>(nchwc_channels));

Expand Down Expand Up @@ -1169,12 +1258,14 @@
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
TransformBinary(node, true);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) {
TransformBinary(node, false);
TransformMul(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
TransformConcat(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) {
TransformActivation(node);
Comment thread
hariharans29 marked this conversation as resolved.
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) {
TransformBatchNormalization(node);
Expand Down
57 changes: 47 additions & 10 deletions onnxruntime/test/optimizer/nchwc_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,15 @@ struct NchwcTestHelper {

Node& AddNode(const std::string& op_type,
const std::vector<NodeArg*>& input_args,
const std::vector<NodeArg*>& output_args) {
const std::vector<NodeArg*>& output_args,
const std::string& domain = kOnnxDomain) {
return graph_.AddNode(graph_.GenerateNodeName("node"),
op_type,
"description",
input_args,
output_args);
output_args,
nullptr,
domain);
}

Node& AddConvNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector<int64_t>& weights_shape, bool no_bias = false) {
Expand Down Expand Up @@ -706,6 +709,36 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) {
}
}

TEST(NchwcOptimizerTests, ConvMulChannelScale) {
auto test_case = [&](const std::vector<int64_t>& scale_shape) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, 32, 25, 21});
auto* conv_output_arg = helper.MakeIntermediate();
auto* mul_output_arg = helper.MakeIntermediate();
auto* output_arg = helper.MakeOutput();

helper.AddConvNode(input_arg, conv_output_arg, {32, 32, 3, 3});
auto* scale_arg = helper.MakeInitializer<float>(scale_shape, helper.FillRandomData<float>(scale_shape));
helper.AddNode("Mul", {conv_output_arg, scale_arg}, {mul_output_arg});
helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1});
Comment thread
hariharans29 marked this conversation as resolved.
Outdated
};

auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 3);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count["Mul"], 0);
};

NchwcOptimizerTester(build_test_case, check_nchwc_graph);
};

// Valid ONNX channel broadcast forms for NCHW tensors.
test_case({32, 1, 1});
test_case({1, 32, 1, 1});
}

TEST(NchwcOptimizerTests, ConvConcat) {
auto test_case = [&](int axis, int channel_count, int reorder_output_count) {
auto build_test_case = [&](NchwcTestHelper& helper) {
Expand Down Expand Up @@ -1342,7 +1375,7 @@ TEST(NchwcOptimizerTests, UpsampleLinear) {
}

TEST(NchwcOptimizerTests, Activation) {
auto test_case = [&](const std::string& activation_op_type) {
auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, 48, 11, 15});
auto* conv1_output_arg = helper.MakeIntermediate();
Expand All @@ -1351,7 +1384,7 @@ TEST(NchwcOptimizerTests, Activation) {
auto* output_arg = helper.MakeOutput();

helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3});
helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg});
helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain);
helper.AddNode("Add", {conv1_output_arg, activation_output_arg}, {mul_output_arg});
helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1});
};
Expand All @@ -1368,12 +1401,16 @@ TEST(NchwcOptimizerTests, Activation) {
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
};

// Verify that the optimizer doesn't add reorders for these activations that
// cannot be fused with a convolution.
std::vector<std::string> activation_op_types{"Relu", "Sigmoid", "Tanh"};
for (auto& activation_op_type : activation_op_types) {
test_case(activation_op_type);
}
// Verify that the optimizer doesn't add reorders for these activations in
// this pattern. Relu/Sigmoid/Tanh are generally fusable with a preceding
// convolution, but not here because the Conv output is consumed both by the
// activation node and directly by the Add node. Gelu/QuickGelu are also
// expected to remain as separate nodes.
test_case("Relu");
test_case("Sigmoid");
test_case("Tanh");
test_case("Gelu", kMSDomain);
test_case("QuickGelu", kMSDomain);
Comment thread
hariharans29 marked this conversation as resolved.
Outdated
Comment thread
hariharans29 marked this conversation as resolved.
}

TEST(NchwcOptimizerTests, MaxPoolTypeCheck) {
Expand Down
Loading