Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion onnxruntime/core/optimizer/nchwc_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_output_channels = (output_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 349 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
Expand Down Expand Up @@ -379,7 +379,7 @@
if ((input_channels % channel_alignment) != 0) {
return;
}
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 382 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size
}
}

Expand Down Expand Up @@ -881,12 +881,21 @@

const bool can_fuse_activation = (node.OpType() == "Relu") ||
(node.OpType() == "Sigmoid") ||
(node.OpType() == "Tanh");
(node.OpType() == "Tanh") ||
(node.OpType() == "HardSigmoid");
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
can_fuse_activation &&
(nchwc_input->starting_original_uses_ == 1) &&
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
nchwc_node.AddAttribute("activation", node.OpType());
if (node.OpType() == "HardSigmoid") {
const auto* alpha_attr = graph_utils::GetNodeAttribute(node, "alpha");
const auto* beta_attr = graph_utils::GetNodeAttribute(node, "beta");
InlinedVector<float> activation_params{
alpha_attr == nullptr ? 0.2f : alpha_attr->f(),
beta_attr == nullptr ? 0.5f : beta_attr->f()};
nchwc_node.AddAttribute("activation_params", activation_params);
}
FuseNchwcArgument(node, *nchwc_input);
removed_nodes_.push_front(node.Index());
} else {
Expand Down Expand Up @@ -969,7 +978,7 @@
bn_B.sub(bn_mean);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 981 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

InlinedVector<float> padded_buffer(gsl::narrow<size_t>(nchwc_channels));

Expand Down Expand Up @@ -1265,6 +1274,7 @@
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) {
TransformConcat(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 22}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) ||
Expand Down
71 changes: 69 additions & 2 deletions onnxruntime/test/optimizer/nchwc_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1435,18 +1435,85 @@ TEST(NchwcOptimizerTests, Activation) {
};

// Verify that the optimizer doesn't add reorders for these activations in
// this pattern. Relu/Sigmoid/Tanh are generally fusable with a
// this pattern. Relu/Sigmoid/Tanh/HardSigmoid are generally fusable with a
// preceding convolution, but not here because the Conv output is consumed
// both by the activation node and directly by the Add node. Gelu/QuickGelu
// are also expected to remain as separate nodes.
test_case("Relu");
test_case("Sigmoid");
test_case("Tanh");
test_case("HardSigmoid");
test_case("Gelu", kMSDomain);
test_case("QuickGelu", kMSDomain);
}

TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) {
// Builds Conv -> <activation> -> Conv graphs, where the activation is the only
// consumer of the first Conv's output, and checks that the optimizer removes
// the standalone activation node and records it on an NCHWc Conv node through
// the "activation" attribute. For HardSigmoid the alpha/beta values set on the
// original node must reappear, in that order, in "activation_params"; for the
// other activations no "activation_params" attribute is expected.
TEST(NchwcOptimizerTests, ActivationSingleConsumerConvFusion) {
  // Explicit HardSigmoid parameters set on the node and expected back on the
  // fused Conv node.
  constexpr float kHardSigmoidAlpha = 0.125f;
  constexpr float kHardSigmoidBeta = 0.625f;

  auto test_case = [&](const std::string& activation_op_type) {
    const bool is_hard_sigmoid = (activation_op_type == "HardSigmoid");

    // Graph under test: input -> Conv -> activation -> Conv -> output.
    auto build_test_case = [&](NchwcTestHelper& helper) {
      auto* graph_input = helper.MakeInput<float>({1, 48, 11, 15});
      auto* first_conv_output = helper.MakeIntermediate();
      auto* activation_output = helper.MakeIntermediate();
      auto* graph_output = helper.MakeOutput();

      helper.AddConvNode(graph_input, first_conv_output, {32, 48, 3, 3});

      auto& activation = helper.AddNode(activation_op_type, {first_conv_output}, {activation_output});
      if (is_hard_sigmoid) {
        activation.AddAttribute("alpha", kHardSigmoidAlpha);
        activation.AddAttribute("beta", kHardSigmoidBeta);
      }

      helper.AddConvNode(activation_output, graph_output, {16, 32, 1, 1});
    };

    auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
      auto& graph = session.GetGraph();
      auto op_to_count = CountOpsInGraph(graph);

      // Both convolutions become NCHWc convolutions with a single reorder on
      // each side, and the activation no longer exists as its own node.
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2);
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
      EXPECT_EQ(op_to_count[activation_op_type], 0);

      // Count the NCHWc Conv nodes that carry an "activation" attribute and
      // validate the attributes of each one found.
      size_t fused_conv_count = 0;
      for (const auto& node : graph.Nodes()) {
        const bool is_nchwc_conv = (node.OpType() == "Conv") && (node.Domain() == kMSNchwcDomain);
        if (!is_nchwc_conv) {
          continue;
        }

        const auto& attributes = node.GetAttributes();
        auto activation_it = attributes.find("activation");
        if (activation_it == attributes.end()) {
          // An NCHWc Conv without a fused activation; nothing to verify.
          continue;
        }

        ++fused_conv_count;
        EXPECT_EQ(activation_it->second.s(), activation_op_type);

        auto activation_params_it = attributes.find("activation_params");
        if (!is_hard_sigmoid) {
          EXPECT_EQ(activation_params_it, attributes.end());
        } else {
          ASSERT_NE(activation_params_it, attributes.end());
          ASSERT_EQ(activation_params_it->second.floats_size(), 2);
          EXPECT_FLOAT_EQ(activation_params_it->second.floats(0), kHardSigmoidAlpha);
          EXPECT_FLOAT_EQ(activation_params_it->second.floats(1), kHardSigmoidBeta);
        }
      }

      // Exactly one of the two convolutions should have absorbed the activation.
      EXPECT_EQ(fused_conv_count, 1U);
    };

    NchwcOptimizerTester(build_test_case, check_nchwc_graph);
  };

  test_case("Relu");
  test_case("Sigmoid");
  test_case("Tanh");
  test_case("HardSigmoid");
}

TEST(NchwcOptimizerTests, ActivationSingleConsumerConvNoFusion) {
auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({1, 48, 11, 15});
Expand Down
Loading