Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions onnxruntime/core/optimizer/conv_activation_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
return data_type == actual_data_type;
}

// Returns true when the Conv node's input data type is one that the fusion
// kernels can actually execute on the node's assigned execution provider.
// TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation
// and the Conv+Add+Relu fusions.
// Assess the support level for the other compatible EPs and if they also
// only support float, remove the EP check altogether.
bool ConvFusionDataTypeCheck(const Node& conv_node) {
  const std::string_view node_ep = conv_node.GetExecutionProviderType();
  const bool ep_needs_float = node_ep == kCpuExecutionProvider || node_ep == kCudaExecutionProvider;
  // EPs other than CPU/CUDA are not restricted here; for CPU/CUDA the first
  // input (X) must be fp32 for a fused kernel to exist.
  return !ep_needs_float ||
         HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
}

class ConvActivation : public NodeSelector {
public:
ConvActivation() = default;
Expand Down Expand Up @@ -74,12 +89,12 @@ class ConvActivation : public NodeSelector {
return false;
};

if (!ConvFusionDataTypeCheck(node)) {
return std::nullopt;
}

// check EP type and activation
if (node_ep == kCudaExecutionProvider) {
if (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return std::nullopt;
}

if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
return std::nullopt;
}
Expand Down Expand Up @@ -112,6 +127,10 @@ class ConvAddRelu : public NodeSelector {
return std::nullopt;
}

if (!ConvFusionDataTypeCheck(node)) {
return std::nullopt;
}

const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
if (!add_node ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) ||
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,37 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph
}

// Currently the ConvAddRelu fusion is only backed by a float kernel for
// the CUDA EP.

// When we see the corresponding pattern for the fp16 data type, the fusion
// should not be triggered as there is no kernel to back the fused pattern.

// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of
// data type support in other compatible EPs is still yet to be ascertained.

// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove
// this test.
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) {
  // Load an fp16 Conv+Add+Relu model. There is no fp16 kernel backing the
  // fused pattern, so the transformer must leave the graph unchanged.
  auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Pin every node to the CUDA EP so the fusion's EP/data-type check is hit.
  for (auto& node : graph.Nodes()) {
    node.SetExecutionProviderType(kCudaExecutionProvider);
  }

  // Sanity-check the pattern is present before running the transformer.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);
  ASSERT_EQ(op_to_count["Relu"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
      std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  // Fusion must not have been triggered: both nodes survive.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);
  ASSERT_EQ(op_to_count["Relu"], 1);
}

// Conv->Add->Relu will be left intact since there is Identity depend on Add
TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
Expand Down
43 changes: 43 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
:�

X
W
BC"Conv

SY"Relu

C
AS"AddgraphZ
X





Z
W





Z
B



Z
A





b
Y





B