diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 6f1d9d9c6611d..2754eebf75421 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -145,10 +145,14 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid } if (ConvertNodeLayout(*node)) { + // domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv. + // So, change the OpType to "Conv" for FusedConv. + std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType(); + // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { - SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain); // Changing the domain for the node requires creating a new node and replacing the old one // therefore set the modified flag. modified = true; @@ -175,7 +179,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout // transformer only converts layout for 0th input, weights should be handled by every EP. - if (node->OpType() == "Resize") { + if (op_type == "Resize") { // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case, // we need to jump a few extra hoops to make sure its inputs are correctly handled. // @@ -205,7 +209,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain); modified = true; } }