Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena
int enable_cann_graph; // Flag indicating if prioritizing the use of
// CANN's graph-running capabilities
int enable_cann_subgraph; // Flag indicating whether to generate subgraph
// automaticly

Check warning on line 19 in include/onnxruntime/core/providers/cann/cann_provider_options.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "automaticly" is a misspelling of "automatically" Raw Output: ./include/onnxruntime/core/providers/cann/cann_provider_options.h:19:62: "automaticly" is a misspelling of "automatically"
int dump_graphs; // Flag indicating if dumping graphs
int dump_om_model; // Flag indicating if dumping om model
std::string precision_mode; // Operator Precision Mode
Expand Down
9 changes: 4 additions & 5 deletions onnxruntime/core/providers/cann/cann_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1266,17 +1266,16 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
// the single operator operation mode of CANN
if (info_.enable_cann_graph) {
std::vector<NodeIndex>&& unsupported_nodes = SupportONNXModel(graph_viewer);

if (unsupported_nodes.empty()) {
auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
} else {
if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) {
auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes);

for (const auto& partition : partitions) {
auto sub_graph = GetSubGraph(partition, graph_viewer);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
}
} else {
auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
}
} else {
InlinedVector<NodeIndex> candidates;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "npu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kEnableCannGraph = "enable_cann_graph";
constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph";
constexpr const char* kDumpGraphs = "dump_graphs";
constexpr const char* kDumpOmModel = "dump_om_model";
constexpr const char* kPrecisionMode = "precision_mode";
Expand Down Expand Up @@ -58,6 +59,7 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P
cann::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
.AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph)
.AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph)
.AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs)
.AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model)
.AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode)
Expand All @@ -74,6 +76,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution
{cann::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
{cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)},
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
{cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)},
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
Expand All @@ -89,6 +92,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid
{cann::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))},
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
{cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)},
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
{cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)},
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct CANNExecutionProviderInfo {
size_t npu_mem_limit{std::numeric_limits<size_t>::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
bool enable_cann_graph{true};
bool enable_cann_subgraph{false};
bool dump_graphs{false};
bool dump_om_model{true};
std::string precision_mode;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cann/cann_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ struct CANN_Provider : Provider {
info.npu_mem_limit = params->npu_mem_limit;
info.arena_extend_strategy = params->arena_extend_strategy;
info.enable_cann_graph = params->enable_cann_graph != 0;
info.enable_cann_subgraph = params->enable_cann_subgraph != 0;
info.dump_graphs = params->dump_graphs != 0;
info.dump_om_model = params->dump_om_model != 0;
info.precision_mode = params->precision_mode;
Expand All @@ -94,6 +95,7 @@ struct CANN_Provider : Provider {
cann_options.npu_mem_limit = internal_options.npu_mem_limit;
cann_options.arena_extend_strategy = internal_options.arena_extend_strategy;
cann_options.enable_cann_graph = internal_options.enable_cann_graph;
cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph;
cann_options.dump_graphs = internal_options.dump_graphs;
cann_options.dump_om_model = internal_options.dump_om_model;
cann_options.precision_mode = internal_options.precision_mode;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2902,6 +2902,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider
options->npu_mem_limit = SIZE_MAX;
options->arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(0);
options->enable_cann_graph = 1;
options->enable_cann_subgraph = 0;
options->dump_graphs = 0;
options->dump_om_model = 1;
options->default_memory_arena_cfg = nullptr;
Expand Down
Loading