diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index 51b423e68110a..4b33ee77a892e 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -15,6 +15,8 @@ struct OrtCANNProviderOptions { onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities + int enable_cann_subgraph; // Flag indicating whether to generate subgraph + // automatically int dump_graphs; // Flag indicating if dumping graphs int dump_om_model; // Flag indicating if dumping om model std::string precision_mode; // Operator Precision Mode diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 4bcf71335d15e..06c3628eb301d 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1266,17 +1266,16 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe // the single operator operation mode of CANN if (info_.enable_cann_graph) { std::vector&& unsupported_nodes = SupportONNXModel(graph_viewer); - - if (unsupported_nodes.empty()) { - auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - } else { + if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) { auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); for (const auto& partition : partitions) { auto sub_graph = GetSubGraph(partition, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } + } else { + auto sub_graph = 
GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } else { InlinedVector candidates; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc index d1ba7544bc09e..d6cf9fad70ae5 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc @@ -20,6 +20,7 @@ constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "npu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kEnableCannGraph = "enable_cann_graph"; +constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph"; constexpr const char* kDumpGraphs = "dump_graphs"; constexpr const char* kDumpOmModel = "dump_om_model"; constexpr const char* kPrecisionMode = "precision_mode"; @@ -58,6 +59,7 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P cann::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph) + .AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph) .AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs) .AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model) .AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode) @@ -74,6 +76,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + 
{cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, @@ -89,6 +92,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.h b/onnxruntime/core/providers/cann/cann_execution_provider_info.h index 7ac43e9a8ed6f..9c1f9eb03b67e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.h @@ -18,6 +18,7 @@ struct CANNExecutionProviderInfo { size_t npu_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; bool enable_cann_graph{true}; + bool enable_cann_subgraph{false}; bool dump_graphs{false}; bool dump_om_model{true}; std::string precision_mode; diff --git a/onnxruntime/core/providers/cann/cann_provider_factory.cc b/onnxruntime/core/providers/cann/cann_provider_factory.cc index 4a130b9b0ca20..d3dc86f588f1d 100644 --- 
a/onnxruntime/core/providers/cann/cann_provider_factory.cc +++ b/onnxruntime/core/providers/cann/cann_provider_factory.cc @@ -76,6 +76,7 @@ struct CANN_Provider : Provider { info.npu_mem_limit = params->npu_mem_limit; info.arena_extend_strategy = params->arena_extend_strategy; info.enable_cann_graph = params->enable_cann_graph != 0; + info.enable_cann_subgraph = params->enable_cann_subgraph != 0; info.dump_graphs = params->dump_graphs != 0; info.dump_om_model = params->dump_om_model != 0; info.precision_mode = params->precision_mode; @@ -94,6 +95,7 @@ struct CANN_Provider : Provider { cann_options.npu_mem_limit = internal_options.npu_mem_limit; cann_options.arena_extend_strategy = internal_options.arena_extend_strategy; cann_options.enable_cann_graph = internal_options.enable_cann_graph; + cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph; cann_options.dump_graphs = internal_options.dump_graphs; cann_options.dump_om_model = internal_options.dump_om_model; cann_options.precision_mode = internal_options.precision_mode; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 41cf8be1d1412..f82cbcf63ca62 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2902,6 +2902,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider options->npu_mem_limit = SIZE_MAX; options->arena_extend_strategy = static_cast(0); options->enable_cann_graph = 1; + options->enable_cann_subgraph = 0; options->dump_graphs = 0; options->dump_om_model = 1; options->default_memory_arena_cfg = nullptr;