100 changes: 97 additions & 3 deletions onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -122,6 +122,7 @@
void TransformConv(Node& node);
void TransformPool(Node& node);
void TransformBinary(Node& node, bool add_node);
void TransformMul(Node& node);
void TransformConcat(Node& node);
void TransformActivation(Node& node);
void TransformBatchNormalization(Node& node);
@@ -345,7 +346,7 @@
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_output_channels = (output_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

[GitHub Actions / build_x86_release] Check warning on line 349: '~': zero extending 'size_t' to 'int64_t' of greater size
bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
@@ -378,7 +379,7 @@
if ((input_channels % channel_alignment) != 0) {
return;
}
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

[GitHub Actions / build_x86_release] Check warning on line 382: '~': zero extending 'size_t' to 'int64_t' of greater size
}
}

@@ -734,6 +735,89 @@
}
}

void NchwcTransformerImpl::TransformMul(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();

if (input_defs.size() != 2 || output_defs.size() != 1) {
return;
}

auto* nchwc_input_0 = LookupNchwcArgument(input_defs[0]);
auto* nchwc_input_1 = LookupNchwcArgument(input_defs[1]);

// If both inputs are already NCHWc arguments, use the regular binary path.
if (nchwc_input_0 != nullptr && nchwc_input_1 != nullptr) {
TransformBinary(node, false);
return;
}

// Exactly one input must be NCHWc and the other must be a static scale tensor.
if ((nchwc_input_0 == nullptr) == (nchwc_input_1 == nullptr)) {
return;
}

const int nchwc_input_index = (nchwc_input_0 != nullptr) ? 0 : 1;
const int scale_input_index = nchwc_input_index ^ 1;
auto* nchwc_input = (nchwc_input_index == 0) ? nchwc_input_0 : nchwc_input_1;

const auto* mul_scale_tensor_proto = graph_utils::GetConstantInitializer(graph_, input_defs[scale_input_index]->Name());
if (mul_scale_tensor_proto == nullptr ||
mul_scale_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return;
}

const int64_t channels = nchwc_input->channels_;

Initializer mul_scale{graph_, *mul_scale_tensor_proto, graph_.ModelPath()};
const auto scale_dims = mul_scale.dims();

bool channel_broadcast_shape = false;
if (scale_dims.size() == 3) {
channel_broadcast_shape = (scale_dims[0] == channels && scale_dims[1] == 1 && scale_dims[2] == 1);
} else if (scale_dims.size() == 4) {
channel_broadcast_shape = (scale_dims[0] == 1 && scale_dims[1] == channels && scale_dims[2] == 1 && scale_dims[3] == 1);
}

if (!channel_broadcast_shape) {
return;
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + static_cast<int64_t>(nchwc_block_size) - 1) & ~static_cast<int64_t>(nchwc_block_size - 1);
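// For example, with nchwc_block_size == 16 and channels == 33 this rounds
// 33 up to (33 + 15) & ~15 == 48, leaving 15 padding lanes to fill below.
// The static_casts keep the mask computation in int64_t; applying '~' to
// the raw size_t would zero-extend the mask, which is what the
// build_x86_release warnings on the older call sites above flag.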

InlinedVector<float> padded_scale(gsl::narrow<size_t>(nchwc_channels), 1.0f);
std::copy_n(mul_scale.data<float>(), channels, padded_scale.data());

ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto;
nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("mul_scale"));
utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, padded_scale.data(),
gsl::narrow<size_t>(nchwc_channels) * sizeof(float));
nchwc_conv_W_tensor_proto.add_dims(nchwc_channels);
nchwc_conv_W_tensor_proto.add_dims(1);
nchwc_conv_W_tensor_proto.add_dims(1);
nchwc_conv_W_tensor_proto.add_dims(1);

auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithOrtValue(graph_, nchwc_conv_W_tensor_proto);
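// The scale initializer becomes the weight of a depthwise 1x1 NCHWc Conv:
// with group == nchwc_channels and a [nchwc_channels, 1, 1, 1] weight, the
// Conv computes Y[n][c][h][w] = scale[c] * X[n][c][h][w], which matches
// the original Mul.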

std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_mul_nchwc");
Node& nchwc_node = graph_.AddNode(nchwc_node_name,
"Conv",
nchwc_node_name,
std::array{nchwc_input->nchwc_arg_, nchwc_conv_W_arg},
output_defs,
nullptr,
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.AddAttribute("group", nchwc_channels);

nchwc_input->remaining_original_uses_--;

CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_);
removed_nodes_.push_front(node.Index());
}

void NchwcTransformerImpl::TransformConcat(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();
@@ -794,7 +878,12 @@
// Check if this is a single use NCHWc convolution that hasn't already
// been fused with another activation.
auto& nchwc_node = nchwc_input->output_node_;

const bool can_fuse_activation = (node.OpType() == "Relu") ||
(node.OpType() == "Sigmoid") ||
(node.OpType() == "Tanh");
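// The fused "activation" attribute is limited to the post-ops that the
// NCHWc Conv kernel implements. Gelu/QuickGelu go through
// TransformActivation so their inputs stay in NCHWc form, but they are
// never folded into the Conv here.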
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
can_fuse_activation &&
(nchwc_input->starting_original_uses_ == 1) &&
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
nchwc_node.AddAttribute("activation", node.OpType());
@@ -880,7 +969,7 @@
bn_B.sub(bn_mean);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

[GitHub Actions / build_x86_release] Check warning on line 972: '~': zero extending 'size_t' to 'int64_t' of greater size

InlinedVector<float> padded_buffer(gsl::narrow<size_t>(nchwc_channels));

@@ -1159,6 +1248,11 @@
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10, 11, 12, 22}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "AveragePool", {1, 7, 10, 11, 19, 22})) {
TransformPool(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) {
// Mul can be converted when exactly one input is already NCHWc and the
// other is a static channel-scale initializer, so it does not need to
// wait for all input edges to be removed; TransformMul itself falls back
// to the regular binary path when both inputs are NCHWc.
TransformMul(node);
} else if (node.GetInputEdgesCount() == 0 && node.InputDefs().size() != 0) {
// The following transforms only run when the input edge count has already
// been decremented to zero by earlier transforms. This is a hint that the
@@ -1168,13 +1262,13 @@
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) {
TransformBinary(node, true);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) {
TransformBinary(node, false);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
TransformConcat(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) {
TransformActivation(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) {
TransformBatchNormalization(node);
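For reference, the TransformMul rewrite above rests on a simple identity: an elementwise Mul against a [C, 1, 1] (or [1, C, 1, 1]) scale tensor is exactly a grouped 1x1 convolution with group == C whose weight holds the scales. A minimal standalone C++ sketch of that identity, using plain NCHW arrays and hypothetical sizes rather than the onnxruntime APIs, checking that both paths agree:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int C = 3, H = 2, W = 2;
  const std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};  // [1, C, H, W]
  const std::vector<float> scale = {0.5f, 2.0f, -1.0f};                  // [C, 1, 1]

  // Elementwise Mul with channel broadcast.
  std::vector<float> mul_out(x.size());
  for (int c = 0; c < C; ++c)
    for (int i = 0; i < H * W; ++i)
      mul_out[c * H * W + i] = x[c * H * W + i] * scale[c];

  // Grouped 1x1 Conv with group == C and weight shape [C, 1, 1, 1]: each
  // group reads a single input channel and applies a single weight.
  std::vector<float> conv_out(x.size());
  for (int c = 0; c < C; ++c)
    for (int i = 0; i < H * W; ++i) {
      float acc = 0.0f;                    // no bias
      acc += scale[c] * x[c * H * W + i];  // 1x1 kernel, one channel per group
      conv_out[c * H * W + i] = acc;
    }

  for (std::size_t i = 0; i < x.size(); ++i) assert(mul_out[i] == conv_out[i]);
  return 0;
}

The transformer additionally rounds C up to the NCHWc block multiple and pads the weight with 1.0f, which is what the padded_scale buffer in TransformMul handles.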
135 changes: 123 additions & 12 deletions onnxruntime/test/optimizer/nchwc_optimizer_test.cc
@@ -85,12 +85,15 @@ struct NchwcTestHelper

Node& AddNode(const std::string& op_type,
const std::vector<NodeArg*>& input_args,
const std::vector<NodeArg*>& output_args) {
const std::vector<NodeArg*>& output_args,
const std::string& domain = kOnnxDomain) {
return graph_.AddNode(graph_.GenerateNodeName("node"),
op_type,
"description",
input_args,
output_args);
output_args,
nullptr,
domain);
}

Node& AddConvNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector<int64_t>& weights_shape, bool no_bias = false) {
@@ -168,7 +171,8 @@

void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_nchwc_graph,
int opset_version = 13) {
int opset_version = 13,
const std::function<void(const Graph& graph)>& check_pre_optimization_graph = nullptr) {
// Ignore the test if NCHWc is not supported by the platform.
if (MlasNchwcGetBlockSize() <= 1) {
return;
@@ -177,12 +181,17 @@ void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& bu
// Build the model for this test.
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = opset_version;
domain_to_version[kMSDomain] = 1;
Model model("nchwc", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
NchwcTestHelper helper(model.MainGraph());
build_test_case(helper);
ASSERT_STATUS_OK(model.MainGraph().Resolve());

if (check_pre_optimization_graph) {
check_pre_optimization_graph(model.MainGraph());
}

// Serialize the model to a string.
std::string model_data;
model.ToProto().SerializeToString(&model_data);
Expand Down Expand Up @@ -706,6 +715,62 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) {
}
}

TEST(NchwcOptimizerTests, ConvMulChannelScale) {
const int64_t input_channels = static_cast<int64_t>(MlasNchwcGetBlockSize()) * 2;

auto test_case = [&](int64_t output_channels, bool use_explicit_batch_dim, bool scale_first) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, input_channels, 25, 21});
auto* conv_output_arg = helper.MakeIntermediate();
auto* quickgelu_output_arg = helper.MakeIntermediate();
auto* output_arg = helper.MakeOutput();

helper.AddConvNode(input_arg, conv_output_arg, {output_channels, input_channels, 3, 3});
helper.AddNode("QuickGelu", {conv_output_arg}, {quickgelu_output_arg}, kMSDomain);
const std::vector<int64_t> scale_shape = use_explicit_batch_dim
? std::vector<int64_t>{1, output_channels, 1, 1}
: std::vector<int64_t>{output_channels, 1, 1};
auto* scale_arg = helper.MakeInitializer<float>(scale_shape, helper.FillRandomData<float>(scale_shape));
if (scale_first) {
helper.AddNode("Mul", {scale_arg, quickgelu_output_arg}, {output_arg});
} else {
helper.AddNode("Mul", {quickgelu_output_arg, scale_arg}, {output_arg});
}
};

auto check_pre_optimization_graph = [&](const Graph& graph) {
auto op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Conv"], 1);
EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1);
EXPECT_EQ(op_to_count["Mul"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0);
};

auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
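// Two NCHWc Convs are expected: the original Conv plus the Mul that the
// transformer rewrote into a grouped 1x1 NCHWc Conv.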
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2);
EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count["Mul"], 0);
};

NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph);
};

// Keep Conv input channels aligned so the initial NCHWc transform still
// applies, and vary only the logical output channel count to cover both the
// aligned path and the padded scale path in the Mul rewrite.
for (int64_t output_channels : {input_channels, input_channels + 1}) {
test_case(output_channels, false, false);
test_case(output_channels, false, true);
test_case(output_channels, true, false);
test_case(output_channels, true, true);
}
}

TEST(NchwcOptimizerTests, ConvConcat) {
auto test_case = [&](int axis, int channel_count, int reorder_output_count) {
auto build_test_case = [&](NchwcTestHelper& helper) {
@@ -1342,7 +1407,7 @@ TEST(NchwcOptimizerTests, UpsampleLinear) {
}

TEST(NchwcOptimizerTests, Activation) {
auto test_case = [&](const std::string& activation_op_type) {
auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, 48, 11, 15});
auto* conv1_output_arg = helper.MakeIntermediate();
@@ -1351,29 +1416,75 @@ auto* output_arg = helper.MakeOutput();
auto* output_arg = helper.MakeOutput();

helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3});
helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg});
helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain);
helper.AddNode("Add", {conv1_output_arg, activation_output_arg}, {mul_output_arg});
helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1});
};

auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const std::string activation_key = domain.empty() ? activation_op_type : domain + "." + activation_op_type;
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count[activation_op_type], 1);
EXPECT_EQ(op_to_count[activation_key], 1);
EXPECT_EQ(op_to_count["Add"], 1);
};

NchwcOptimizerTester(build_test_case, check_nchwc_graph);
};

// Verify that the optimizer doesn't add reorders for these activations that
// cannot be fused with a convolution.
std::vector<std::string> activation_op_types{"Relu", "Sigmoid", "Tanh"};
for (auto& activation_op_type : activation_op_types) {
test_case(activation_op_type);
}
// Verify that the optimizer doesn't add reorders for these activations in
// this pattern. Relu/Sigmoid/Tanh are generally fusable with a
// preceding convolution, but not here because the Conv output is consumed
// both by the activation node and directly by the Add node. Gelu/QuickGelu
// are also expected to remain as separate nodes.
test_case("Relu");
test_case("Sigmoid");
test_case("Tanh");
test_case("Gelu", kMSDomain);
test_case("QuickGelu", kMSDomain);
}

TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) {
auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, 48, 11, 15});
auto* conv1_output_arg = helper.MakeIntermediate();
auto* activation_output_arg = helper.MakeIntermediate();
auto* output_arg = helper.MakeOutput();

helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3});
helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain);
helper.AddConvNode(activation_output_arg, output_arg, {16, 32, 1, 1});
};

auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
auto& graph = session.GetGraph();
auto op_to_count = CountOpsInGraph(graph);
const std::string activation_key = domain.empty() ? activation_op_type : domain + "." + activation_op_type;

EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count[activation_key], 1);

for (const auto& node : graph.Nodes()) {
if (node.OpType() == "Conv" && node.Domain() == kMSNchwcDomain) {
EXPECT_EQ(node.GetAttributes().count("activation"), 0U)
<< activation_op_type << " should not fuse into a single-consumer NCHWc Conv";
}
}
};

NchwcOptimizerTester(build_test_case, check_nchwc_graph);
};

// Gelu/QuickGelu must remain separate even with a single-consumer Conv input,
// because the NCHWc Conv activation fuse guard only allows a fixed subset of
// activations.
test_case("Gelu", kMSDomain);
test_case("QuickGelu", kMSDomain);
}

TEST(NchwcOptimizerTests, MaxPoolTypeCheck) {