Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,14 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
}

if (ConvertNodeLayout(*node)) {
// domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv.
// So, change the OpType to "Conv" for FusedConv.
std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();

// if already transformed then change the domain to kMSInternalNHWCDomain this way the EP
// knows this op is in the expected format.
if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain);
// Changing the domain for the node requires creating a new node and replacing the old one
// therefore set the modified flag.
modified = true;
Expand All @@ -175,7 +179,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
// Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation
// for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout
// transformer only converts layout for 0th input, weights should be handled by every EP.
if (node->OpType() == "Resize") {
if (op_type == "Resize") {
// Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case,
// we need to jump a few extra hoops to make sure its inputs are correctly handled.
//
Expand Down Expand Up @@ -205,7 +209,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
}

SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain);
modified = true;
}
}
Expand Down
Loading