diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..9a440a65d87dd 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -177,10 +177,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC - // CPU EP layout transformation happens later when level 3 transformers are run. - if (params.mode != GraphPartitioner::Mode::kAssignOnly && - current_ep.GetPreferredLayout() == DataLayout::NHWC) { + // Run layout transformation for all EPs. + // For an EP that wants NHWC this will wrap layout sensitive nodes with Transpose nodes first. + // In both NCHW and NHWC EPs the EP specific transpose optimization is run last to optimize + // transposes for nodes assigned to the EP or unassigned nodes. This allows things like the + // EP aware Resize handling to be run. + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get()) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); } diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 59ae4bdf9745d..63b06fdf60536 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -92,7 +92,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + {"FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool"}; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,6 +34,24 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } +const std::unordered_set GetEPLayoutSensitiveOps(const IExecutionProvider& execution_provider) { + std::unordered_set layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + const auto& ep = execution_provider.Type(); + + // EPs where the Resize implementation only handles one layout - either NCHW or NHWC. The ONNX spec for Resize is + // not layout specific. We assume if the EP has a layout sensitive Resize it only handles its preferred layout, + // so when doing layout transformation we consider the Resize to be layout sensitive and can wrap the Resize + // in Transpose nodes to convert to the preferred layout, but we can't push any Transpose operations through the + // Resize in the general transpose optimization. + const auto& layout_sensitive_eps = EPsWithLayoutSensitiveResize(); + if (layout_sensitive_eps.find(ep) != layout_sensitive_eps.end()) { + layout_sensitive_ops.insert("Resize"); + } + + return layout_sensitive_ops; +} + // Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. static CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, @@ -64,93 +74,104 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. + // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + const auto& layout_sensitive_ops = GetEPLayoutSensitiveOps(execution_provider); - // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. - for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } + CostCheckFn cost_check; - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + // if converting to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. + if (execution_provider.GetPreferredLayout() == DataLayout::NHWC) { + for (auto& node : api_graph->Nodes()) { + if (layout_sensitive_ops.count(node->OpType())) { + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } - // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP - // knows this op is in the expected format. - if (node->GetAttributeIntDefault("channels_last", 0) == 1) { - SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); - // Changing the domain for the node requires creating a new node and replacing the old one - // therefore set the modified flag. - modified = true; - continue; - } + auto domain = node->Domain(); + // Skip if domain is incorrect + if (domain != kOnnxDomain && domain != kMSDomain) { + continue; + } - // Skip if unknown rank - auto shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape(); - if (!shape.has_value()) { - continue; - } + // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP + // knows this op is in the expected format. + if (node->GetAttributeIntDefault("channels_last", 0) == 1) { + SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + // Changing the domain for the node requires creating a new node and replacing the old one + // therefore set the modified flag. + modified = true; + continue; + } - // Convert to channels last - size_t rank = shape->size(); + // Skip if unknown rank + auto shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape(); + if (!shape.has_value()) { + continue; + } - bool has_channel_last_attr = node->GetAttributeInt("channels_last").has_value() ? true : false; - if (has_channel_last_attr) { - node->SetAttributeInt("channels_last", 1); - } + // Convert to channels last + size_t rank = shape->size(); - auto input_perm = onnx_transpose_optimization::ChannelFirstToLastPerm(rank); - auto output_perm = onnx_transpose_optimization::ChannelLastToFirstPerm(rank); - - // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation - // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout - // transformer only converts layout for 0th input, weights should be handled by every EP. - if (node->OpType() == "Resize") { - // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case, - // we need to jump a few extra hoops to make sure its inputs are correctly handled. - // - // Current code skips layout conversion for ROI because it needs special handling as ROI size is 2*rank. - // Enable passing in ROI for layout conversion when an EP which supports ROI starts using layout transformer. - // NNAPI which currently uses layout transformer does not support it. - std::vector*> input_perms{&input_perm, nullptr}; - for (size_t i = 2; i < node->Inputs().size(); i++) { - auto constant = api_graph->GetConstant(node->Inputs()[i]); - if (constant != nullptr && constant->Data().size() > 0) { - input_perms.push_back(&input_perm); - } else { - // TODO: Fix inconsistency. We should Transpose the non-const inputs so that the result of our changes - // is consistent - all layout specific inputs are in NHWC format when we're done. - // This may need to check the opset to see if it's safe so that an empty non-constant input doesn't - // have an invalid Transpose added to it. - // Caveat: Typically `scales` and `sizes` are constants so this may not happen in a production model. - input_perms.push_back(nullptr); + bool has_channel_last_attr = node->GetAttributeInt("channels_last").has_value() ? true : false; + if (has_channel_last_attr) { + node->SetAttributeInt("channels_last", 1); + } + + auto input_perm = onnx_transpose_optimization::ChannelFirstToLastPerm(rank); + auto output_perm = onnx_transpose_optimization::ChannelLastToFirstPerm(rank); + + // Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation + // for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout + // transformer only converts layout for 0th input, weights should be handled by every EP. + if (node->OpType() == "Resize") { + // Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case, + // we need to jump a few extra hoops to make sure its inputs are correctly handled. + // + // Current code skips layout conversion for ROI because it needs special handling as ROI size is 2*rank. + // Enable passing in ROI for layout conversion when an EP which supports ROI starts using layout transformer. + // NNAPI which currently uses layout transformer does not support it. + std::vector*> input_perms{&input_perm, nullptr}; + for (size_t i = 2; i < node->Inputs().size(); i++) { + auto constant = api_graph->GetConstant(node->Inputs()[i]); + if (constant != nullptr && constant->Data().size() > 0) { + input_perms.push_back(&input_perm); + } else { + // TODO: Fix inconsistency. We should Transpose the non-const inputs so that the result of our changes + // is consistent - all layout specific inputs are in NHWC format when we're done. + // This may need to check the opset to see if it's safe so that an empty non-constant input doesn't + // have an invalid Transpose added to it. + // Caveat: Typically `scales` and `sizes` are constants so this may not happen in a production model. + input_perms.push_back(nullptr); + } } + WrapTransposesAroundNode(*api_graph, *node, input_perms, {&output_perm}); + } else { + WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - WrapTransposesAroundNode(*api_graph, *node, input_perms, {&output_perm}); - } else { - WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); + + // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. + SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); + modified = true; } + } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. - SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); - modified = true; + cost_check = PostLayoutTransformCostCheck; + + // debug the changes made inserting Transpose nodes around layout sensitive ops. + if (debug_graph_fn) { + debug_graph_fn(graph); } - } - // debug the changes made inserting Transpose nodes around layout sensitive ops. - if (debug_graph_fn) { - debug_graph_fn(graph); + } else { + // layout is fine for the EP but we still want to run the transpose optimizer one more time for the EP specific + // Transpose -> Resize logic. + cost_check = OrtEPCostCheck; } const auto max_node_idx = graph.MaxNodeIndex(); - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, execution_provider.Type(), - PostLayoutTransformCostCheck); + OptimizeResult result = + onnx_transpose_optimization::Optimize(*api_graph, execution_provider.Type(), cost_check, OrtHandlers()); if (result.error_msg) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Layout/Transpose optimization for ", execution_provider.Type(), diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index a54903a036840..f657f2950a657 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -926,18 +926,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -963,7 +952,6 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 131ff6c6ef0c6..59e5101bba691 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -67,6 +67,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index b30c94d7b3e40..2accfb0091112 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -838,5 +838,4 @@ onnxruntime::Graph& GraphFromApiGraph(onnx_transpose_optimization::api::GraphRef onnxruntime::Node& NodeFromApiNode(onnx_transpose_optimization::api::NodeRef& node) { return static_cast(node).Node(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index 8378d7b22e537..12dd1728550dd 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,29 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, some execution providers implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's not being handled + // by an EP that only supports a single layout. + const auto& layout_sensitive_eps = EPsWithLayoutSensitiveResize(); + + const auto& provider = args.ctx.provider_type; + if (provider.empty() || layout_sensitive_eps.find(provider) != layout_sensitive_eps.end()) { + return false; + } + + return HandleResize(args); +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -86,9 +103,17 @@ static bool HandleMaxPool(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; +const HandlerMap& OrtHandlers() { + static const HandlerMap extended_handler_map{ + {"Resize", ep_aware_resize_handler}, + }; + + return extended_handler_map; +} // ORT contrib ops and special cased ONNX ops where we have EP specific handling const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { @@ -104,12 +129,37 @@ const HandlerMap& OrtExtendedHandlers() { {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, }; + const auto& base_handlers = OrtHandlers(); + std::for_each(base_handlers.begin(), base_handlers.end(), [&map](const auto& entry) { map.insert(entry); }); + return map; }(); return extended_handler_map; } +// EPs that require Resize to stay in the current layout. +// The CUDA Resize kernel requires that the input is NCHW +// The ROCm EP is generated from the CUDA EP kernel so the same applies to it. +// TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" +// i.e. aligning with the generic nature of the ONNX spec. +// See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize. +// The QNN EP requires the Resize to remain in NHWC once the layout transformer makes that adjustment +// and moves the node to the kMSInternalNHWCDomain domain. We need it to be in this list so that the layout +// transformation inserts Transpose nodes around the Resize to convert from NCWH to NHWC. As there is no handler for +// the replacement Resize node in the kMSInternalNHWCDomain domain we will not push any Transpose nodes through it +// later. +const std::unordered_set EPsWithLayoutSensitiveResize() { + static std::unordered_set eps = { + kCudaExecutionProvider, + kRocmExecutionProvider, + kQnnExecutionProvider, + onnxruntime::utils::kInternalTestingExecutionProvider, // for testing the behavior + }; + + return eps; +} + CostCheckResult OrtEPCostCheck(const api::GraphRef& graph, const api::NodeRef& node, const std::vector& /*perm*/, const std::unordered_set& /*outputs_leading_to_transpose*/) { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..10203cd8b0eab 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -7,6 +7,12 @@ #include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h" namespace onnxruntime { +/// +/// Get the handlers for basic transpose optimization that are aware of any EP specific limitations. +/// +/// HandlerMap +const onnx_transpose_optimization::HandlerMap& OrtHandlers(); + /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. @@ -15,6 +21,13 @@ namespace onnxruntime { /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); +/// +/// Return set of execution providers that are known to have a layout sensitive implementation of Resize. +/// If the Resize is layout sensitive we do not push a Transpose through it.as that would change the layout. +/// +/// Set of execution provider names. +const std::unordered_set EPsWithLayoutSensitiveResize(); + /// /// Cost check function for transpose optimizer that takes into account implementation details of the /// ORT execution provider kernels. diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 5f17dc14657dd..a6a2a017c3a2f 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -20,7 +20,8 @@ Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_lev const logging::Logger& logger) const { auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr); + OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtHandlers()); if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 127c37bd84d0f..5cc03839faaab 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1282,7 +1282,7 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe continue; const auto& node = *p_node; - if (!node.GetExecutionProviderType().empty()) { + if (!node.GetExecutionProviderType().empty() && node.GetExecutionProviderType() != kCannExecutionProvider) { continue; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index aa60db4d07222..842e757305d7b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2471,7 +2471,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - if (!node.GetExecutionProviderType().empty()) { + if (!node.GetExecutionProviderType().empty() && node.GetExecutionProviderType() != kCudaExecutionProvider) { continue; } diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9401de64269b9..ae9e2bbd47780 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2278,7 +2278,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - if (!node.GetExecutionProviderType().empty()) { + if (!node.GetExecutionProviderType().empty() && node.GetExecutionProviderType() != kRocmExecutionProvider) { continue; } diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index fb9ce580ea2dc..2a36ee9c0313d 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -79,7 +79,7 @@ SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, continue; const auto& node = *p_node; - if (!node.GetExecutionProviderType().empty()) { + if (!node.GetExecutionProviderType().empty() && node.GetExecutionProviderType() != kSnpeExecutionProvider) { continue; } diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index f8da4e895913a..00436c286192e 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -22,6 +22,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" + using namespace ONNX_NAMESPACE; namespace onnxruntime { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index be399ce8db60d..423fd2e7072ac 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2643,12 +2643,12 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 18); // disable TransposeOptimizer for simplicity + 18); TransformerTester(build_test_case, check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 19); // disable TransposeOptimizer for simplicity + 19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 143a1eb8bec59..eacd3e56a77af 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -4425,12 +4425,13 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { testing::ContainerEq(fetches[0].Get().DataAsSpan())); } +// These tests uses internal testing EP with static kernels which requires a full build, +// and the NHWC Conv with requires contrib ops +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // Test a Transpose node followed by a Reshape that is logically equivalent to an Transpose can be merged. // The test graph was extracted from a model we were trying to use with the QNN EP. TEST(TransposeOptimizerTests, QnnTransposeReshape) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.onnx"); @@ -4473,13 +4474,9 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif } TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); @@ -4516,7 +4513,48 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif } + +// Validate handling for EP with layout specific Resize that prefers NHWC +TEST(TransposeOptimizerTests, QnnResizeOpset11) { + Status status; + auto model_uri = ORT_TSTR("testdata/nhwc_resize_scales_opset11.onnx"); + + SessionOptions so; + // Uncomment to debug + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + if (node.OpType() == "Resize") { + EXPECT_EQ(node.Domain(), kMSInternalNHWCDomain) << "Resize was not converted to NHWC layout"; + } + } + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2) << "Resize should have been wrapped in 2 Transpose nodes to convert to NHWC"; + + // And the post-Resize Transpose should have been pushed all the way to the end + GraphViewer viewer(graph); + EXPECT_EQ(graph.GetNode(viewer.GetNodesInTopologicalOrder().back())->OpType(), "Transpose"); +} +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + } // namespace test } // namespace onnxruntime