Skip to content

Commit

Permalink
Enable Pad->Conv(no pads) fusion (#22001)
Browse files Browse the repository at this point in the history
### Description


### Motivation and Context
Some models contain a Pad -> Conv pattern. If the Conv node does not have a pads
attribute, the Pad node can be fused into the Conv node.
  • Loading branch information
yihonglyu authored Sep 11, 2024
1 parent 20d9464 commit e91ff94
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
12 changes: 6 additions & 6 deletions onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ bool VerifyNotCastChild(const Node& child_node) {
return false;
}

// This pass currently assumed that this attribute already exists on the child node
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
return false;
}

return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
std::vector<int64_t> pads(pads_size - 4, 0);
child_node.AddAttribute("pads", pads);
}

auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

Expand Down Expand Up @@ -162,4 +162,4 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime
48 changes: 48 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,54 @@ TEST_F(GraphTransformationTests, FusePadWithConv) {
}
}

// Verifies the Pad -> Conv fusion for the case where the Conv node has NO
// "pads" attribute of its own: after the PadFusion rewrite rule runs, the Pad
// node must be gone and its padding values must appear as a freshly created
// "pads" attribute on the Conv node.
TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-nopadsconv.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

// Before running the transformer, compute the pads the fused Conv should end
// up with, derived from the Pad node's constant "pads" input tensor.
std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
// InputDefs()[1] is the Pad node's "pads" input; it must be a constant
// initializer for the fusion to be computable ahead of time.
const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
Initializer pads{*pads_proto, graph.ModelPath()};
gsl::span<const int64_t> pads_values = pads.DataAsSpan<int64_t>();
// ONNX Pad lays pads out as [x1_begin, ..., xn_begin, x1_end, ..., xn_end].
// Conv's "pads" attribute covers only the spatial dims, so drop the first
// two dims (begin + end => 4 entries) — presumably batch and channel, which
// are expected to be zero-padded for this fusion; TODO confirm in pad_fusion.cc.
expected_pads.resize(pads_values.size() - 4);

// Copy the spatial "begin" values into the first half of expected_pads and
// the matching "end" values into the second half, skipping the first two
// dims (pads_index starts at 2).
for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
}
}

// Register PadFusion as a Level-1 rule-based transformer and apply it
// (manager allows up to 5 graph-transformation steps).
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

// The Pad node must have been fused away, leaving a single Conv node.
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["Conv"], 1);

// The fused Conv must now carry a "pads" attribute equal to the spatial
// padding values extracted from the original Pad node above.
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Conv") {
auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints();
ASSERT_EQ(child_pads->size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the Conv node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads->Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

TEST_F(GraphTransformationTests, FusePadWithMaxPool) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx";

Expand Down
Binary file not shown.

0 comments on commit e91ff94

Please sign in to comment.