diff --git a/docs/Optimizer_Layering_Annotations.md b/docs/Optimizer_Layering_Annotations.md new file mode 100644 index 0000000000000..a268bd8fbe84f --- /dev/null +++ b/docs/Optimizer_Layering_Annotations.md @@ -0,0 +1,130 @@ +# Optimizer Layering Annotations + +## Overview + +Layering annotations are per-node metadata strings that guide graph partitioning by indicating which execution provider (EP) layer a node belongs to. They are loaded from the ONNX model's `NodeProto` metadata (key `"layer_ann"`) and consumed during the partitioning phase to influence EP assignment. + +## Execution Pipeline + +Graph optimizers run in ordered levels: + +``` +Level 0 (Basic) ─► Level 1 (Extended) ─► Partitioning ─► Level 2+ (Layout, etc.) +``` + +1. **Level 0 and Level 1** optimizers run **before** partitioning. At this point, layering annotations are present on nodes and must be preserved through any graph transformations. +2. **Partitioning** reads the annotations to assign nodes to execution providers. +3. After partitioning, `Graph::RemoveAllLayeringAnnotations()` clears all annotations. +4. **Level 2, 3, and 4** optimizers run **after** annotations have been cleared. They do not need to handle annotations. + +**Key rule: Only Level 1 (and Level 0) optimizers need to propagate layering annotations.** + +## Why Propagation Matters + +When an optimizer replaces, fuses, or decomposes nodes, the original annotated node is removed and new nodes are created. If the new nodes do not carry the original annotation, partitioning loses the assignment hint for that subgraph, potentially causing incorrect EP placement. + +## How to Propagate Annotations + +### Preferred: Use the `AddNode` Overload with `annotation_source` + +`Graph::AddNode` provides overloads that accept a `const Node& annotation_source` parameter. The new node automatically inherits the layering annotation from the source node. 
+ +```cpp +// Instead of: +Node& new_node = graph.AddNode(name, op_type, description, inputs, outputs); +// Missing annotation propagation! + +// Use: +Node& new_node = graph.AddNode(name, op_type, description, inputs, outputs, + original_node); // annotation_source +``` + +All standard `AddNode` signatures have a corresponding `annotation_source` variant: + +```cpp +// With const NodeAttributes* +Node& AddNode(name, op_type, description, + gsl::span inputs, + gsl::span outputs, + const Node& annotation_source, + const NodeAttributes* attributes = nullptr, + const std::string& domain = kOnnxDomain); + +// With NodeAttributes&& +Node& AddNode(name, op_type, description, + gsl::span inputs, + gsl::span outputs, + const Node& annotation_source, + NodeAttributes&& attributes, + const std::string& domain = kOnnxDomain); + +// initializer_list variants also available +``` + +### Legacy: `DuplicateNodeAnnotation` + +The utility function `optimizer_utils::DuplicateNodeAnnotation(src, dst)` copies annotations between existing nodes. This is still used when the annotation source is conditional (e.g., when the source node pointer may be null). Prefer the `AddNode` overload for unconditional propagation. + +### Automatic Propagation + +`Graph::AddNode(const Node& other)` — the copy overload used for duplicating nodes — automatically copies annotations. No additional action is needed when duplicating a node via this overload. + +## Post-Partitioning: Propagating EP Assignments + +Although Level 2+ optimizers do not deal with layering annotations directly (they have been cleared), they must still propagate **execution provider (EP) assignments**. EP assignments are the downstream result of the annotation-driven partitioning step. After partitioning, each node carries an EP assignment (e.g., `CUDAExecutionProvider`, `CPUExecutionProvider`) that determines where the node's kernel runs. 
+ +When a Level 2+ optimizer creates new nodes that replace or derive from existing ones, it must copy the EP assignment from the source node: + +```cpp +Node& new_node = graph.AddNode(name, op_type, description, inputs, outputs); +new_node.SetExecutionProviderType(original_node.GetExecutionProviderType()); +``` + +Failing to propagate the EP assignment causes the new node to fall back to the default provider (typically CPU), silently breaking the intended placement and potentially degrading performance or correctness. This requirement predates the layering annotation feature and applies to all optimizers that run after partitioning. + +> **Note:** The `AddNode` overload with `annotation_source` propagates only the layering annotation; the EP assignment must still be set separately. Layering annotations and EP assignments serve different stages of the pipeline and are managed independently. + +## When You Do NOT Need to Propagate Annotations + +- **Level 2+ optimizers** — annotations have already been consumed and cleared (but EP assignments must still be propagated, see above). +- **Training optimizers** — training runs after partitioning. +- **Optimizers that only remove nodes** (e.g., identity elimination) — no new nodes are created. +- **Optimizers that modify nodes in-place** — the annotation remains on the existing node. 
+ +## Examples + +### Fusion (replacing multiple nodes with one) + +```cpp +// GeluFusion: fusing Div + Erf + Add + Mul + Mul into a single Gelu +Node& gelu_node = graph.AddNode( + graph.GenerateNodeName("Gelu"), + "Gelu", "fused Gelu subgraphs", + {gelu_input}, {gelu_output}, + div_node); // propagate annotation from the root matched node +``` + +### Decomposition (replacing one node with many) + +```cpp +// STFT decomposition: each new node inherits from the original STFT node +auto [reshape_node, reshape_out] = AddNode(graph, "Reshape", ep, inputs, &stft); +auto [conv_node, conv_out] = AddNode(graph, "Conv", ep, conv_inputs, &stft); +auto [concat_node, concat_out] = AddNode(graph, "Concat", ep, concat_inputs, &stft); +``` + +### Conditional source (use DuplicateNodeAnnotation) + +```cpp +Node& q_node = graph.AddNode(...); +if (src_node) { + optimizer_utils::DuplicateNodeAnnotation(*src_node, q_node); +} +``` + +## Checklist for New Level 1 Optimizers + +1. Identify the "source" node whose annotation should propagate (typically the root of the matched pattern). +2. For every `graph.AddNode(...)` call that creates a replacement node, use the `annotation_source` overload. +3. If the source is conditional (may be null), use `optimizer_utils::DuplicateNodeAnnotation` after the `AddNode` call. +4. Test with an annotated model to verify annotations survive the transformation. 
diff --git a/include/onnxruntime/core/framework/resource_accountant.h b/include/onnxruntime/core/framework/resource_accountant.h index b072e27816463..7bb5a993d140b 100644 --- a/include/onnxruntime/core/framework/resource_accountant.h +++ b/include/onnxruntime/core/framework/resource_accountant.h @@ -45,18 +45,31 @@ class IResourceAccountant { virtual ResourceCount GetConsumedAmount() const = 0; virtual void AddConsumedAmount(const ResourceCount& amount) = 0; virtual void RemoveConsumedAmount(const ResourceCount& amount) = 0; - virtual ResourceCount ComputeResourceCount(const Node& node) const = 0; + virtual ResourceCount ComputeResourceCount(const Node& node) = 0; std::optional GetThreshold() const { return threshold_; } + void SetThreshold(const ResourceCount& threshold) { + threshold_ = threshold; + } + void SetStopAssignment() noexcept { stop_assignment_ = true; } bool IsStopIssued() const noexcept { return stop_assignment_; } + // Called before each GetCapability pass to discard pending weight tracking + // from a previous (discarded) pass. Default no-op for stats-based accountants. + virtual void ResetPendingWeights() {} + + // Called when a node's cost is committed (AccountForNode/AccountForAllNodes). + // Moves the node's pending weights into the committed set so they persist + // across GetCapability passes. Default no-op for stats-based accountants. 
+ virtual void CommitWeightsForNode(size_t /*node_index*/) {} + static std::string MakeUniqueNodeName(const Node& node); private: @@ -114,11 +127,6 @@ class NodeStatsRecorder { void DumpStats(const std::filesystem::path& model_path) const; - [[nodiscard]] static Status CreateAccountants( - const ConfigOptions& config_options, - const std::filesystem::path& model_path, - std::optional& acc_map); - private: void DumpStats(std::ostream& os) const; @@ -126,4 +134,9 @@ class NodeStatsRecorder { std::unique_ptr impl_; }; +Status CreateAccountants( + const ConfigOptions& config_options, + const std::filesystem::path& model_path, + std::optional& acc_map); + } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 58473a79ddaa6..c5351bc5dfef7 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -174,7 +174,14 @@ class Node { */ void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; } + void SetLayeringAnnotation(std::string annotation) { layering_annotation_ = std::move(annotation); } + + const std::string& GetLayeringAnnotation() const noexcept { return layering_annotation_; } + + const Graph* GetContainingGraph() const noexcept { return graph_; } + #if !defined(ORT_MINIMAL_BUILD) + /** Gets the Node's OpSchema. @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */ const ONNX_NAMESPACE::OpSchema* Op() const noexcept { return op_; } @@ -256,6 +263,13 @@ class Node { #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + + // Make sure that the annotation does not occupy memory after partitioning is done. + void ClearLayeringAnnotation() { + std::string t; + layering_annotation_.swap(t); + } + /** Gets a modifiable count of arguments for each of the Node's explicit inputs. 
@todo This should be removed in favor of a method that updates the input args and the count. Currently these operations are separate which is not a good setup. */ @@ -685,6 +699,8 @@ class Node { // Graph instances for subgraphs that are owned by this Node std::vector> subgraphs_; + std::string layering_annotation_; + // Can be saved? The node cannot be saved anymore if removable attributes have been cleared. bool can_be_saved_; }; @@ -1044,6 +1060,41 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi gsl::span output_args, NodeAttributes&& attributes, const std::string& domain = kOnnxDomain); + + /** Add a Node to this Graph, propagating the layering annotation from an existing node. + This is the preferred way to create new nodes in Level 1 (pre-partitioning) graph optimizers. + The new node automatically inherits the layering annotation from @p annotation_source, which + ensures correct layer-based partitioning when annotations are present. + @param name The Node name. Must be unique in this Graph. + @param op_type The operator type. e.g. ONNX operator name. + @param description Arbitrary description of the Node. + @param input_args The explicit inputs to this Node. + @param output_args The outputs from this Node. + @param annotation_source The node from which to inherit the layering annotation. + @param attributes Optional NodeAttributes to add. + @param domain The domain for the op_type. + @returns Reference to the new Node. + @remarks Use this overload in Level 1 optimizers that create nodes replacing or derived from + existing annotated nodes. See docs/Optimizer_Layering_Annotations.md for details. 
+ */ + Node& AddNode(const std::string& name, + const std::string& op_type, + const std::string& description, + gsl::span input_args, + gsl::span output_args, + const Node& annotation_source, + const NodeAttributes* attributes = nullptr, + const std::string& domain = kOnnxDomain); + + Node& AddNode(const std::string& name, + const std::string& op_type, + const std::string& description, + gsl::span input_args, + gsl::span output_args, + const Node& annotation_source, + NodeAttributes&& attributes, + const std::string& domain = kOnnxDomain); + Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, @@ -1057,6 +1108,21 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi attributes, domain); } + Node& AddNode(const std::string& name, + const std::string& op_type, + const std::string& description, + std::initializer_list input_args, + std::initializer_list output_args, + const Node& annotation_source, + const NodeAttributes* attributes = nullptr, + const std::string& domain = kOnnxDomain) { + return AddNode(name, op_type, description, + AsSpan(input_args), + AsSpan(output_args), + annotation_source, + attributes, domain); + } + Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, @@ -1070,16 +1136,46 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi attributes, domain); } + Node& AddNode(const std::string& name, + const std::string& op_type, + const std::string& description, + gsl::span input_args, + std::initializer_list output_args, + const Node& annotation_source, + const NodeAttributes* attributes = nullptr, + const std::string& domain = kOnnxDomain) { + return AddNode(name, op_type, description, + input_args, + AsSpan(output_args), + annotation_source, + attributes, domain); + } + + Node& AddNode(const std::string& name, + const std::string& op_type, + const std::string& description, + std::initializer_list 
input_args, + gsl::span output_args, + const NodeAttributes* attributes = nullptr, + const std::string& domain = kOnnxDomain) { + return AddNode(name, op_type, description, + AsSpan(input_args), + output_args, + attributes, domain); + } + Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, std::initializer_list input_args, gsl::span output_args, + const Node& annotation_source, const NodeAttributes* attributes = nullptr, const std::string& domain = kOnnxDomain) { return AddNode(name, op_type, description, AsSpan(input_args), output_args, + annotation_source, attributes, domain); } @@ -1322,10 +1418,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi The Graph needs to be Resolve()d after this call. @param func_to_inline + @param parent_annotation. Annotation inherited from the parent node that is being inlined. @returns Status indicating success or providing an error message. */ - Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline); + Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline, + const std::string& parent_annotation); /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will be used as a GraphProto attribute in another Node. @@ -1569,6 +1667,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // compiled model during partitioning, leaving them unused in the ORT Graph. To allow the memory to be freed // we need to manually run the cleanup that would usually happen as part of Graph::Resolve. Status RemovedUnusedInitializersOrtFormat(); + + // This examines all the nodes and removes any annotations that are only used for layering. + // This potentially saves memory. 
+ Status RemoveAllLayeringAnnotations(); + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // This friendship relationship should only be used to call Graph::Graph and diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index 8ef4fdb66e1e6..54e878761ba87 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -86,18 +86,32 @@ struct IndexedSubGraph { // Should call IsAccountingEnabled() first // Takes the previously computed ResourceCount for the node - // (usually during GetCapabiilty()) + // (usually during GetCapability()) // if present and adds it to the consumed amount void AccountForNode(size_t cost_index) const { assert(cost_index < nodes_costs.size()); resource_accountant->AddConsumedAmount(nodes_costs[cost_index]); + resource_accountant->CommitWeightsForNode(nodes[cost_index]); } - // This computes and accounts for the resource cost for the node that just - // been fused from other nodes, and the EP did not had a chance to compute the costs. - void ComputeAndAccountForNode(const Node& node) const { + // Accounts for all constituent nodes by summing their pre-stored costs. + // Use this when fusing nodes into a single node so the total cost + // reflects what was computed during GetCapability() (with correct + // cross-node weight deduplication already applied). + void AccountForAllNodes() const { assert(resource_accountant != nullptr); - resource_accountant->AddConsumedAmount(resource_accountant->ComputeResourceCount(node)); + for (size_t i = 0; i < nodes_costs.size(); ++i) { + resource_accountant->AddConsumedAmount(nodes_costs[i]); + resource_accountant->CommitWeightsForNode(nodes[i]); + } + } + + // Accounts for a node given its index and a pre-computed resource cost. + // Use this when the cost was computed externally (e.g. for a fused node). 
+ void AccountForNode(NodeIndex node_index, const ResourceCount& resource_count) const { + assert(resource_accountant != nullptr); + resource_accountant->AddConsumedAmount(resource_count); + resource_accountant->CommitWeightsForNode(node_index); } void SetAccountant(IResourceAccountant* res_accountant) { diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index f0a99bc11c8b3..44ff0256c33fe 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -325,13 +325,33 @@ static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "sessi /// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats" /// "limit > 0": enables Capacity Aware Partitioning for Cuda EP. `limit` is optional and when absent /// the provider may attempt to figure out the memory available automatically. +/// The setting with no pre-recorded stats is expected to look like: "limit > 0,". +/// In this case, the EP will calculate memory using the initializers referenced by the node. +/// This enables ad-hoc and flexible scenarios with no pre-recorded stats, but may be less accurate. /// The setting with no limit is expected to look like: ",file name for collected stats" -/// The EP will place nodes on device "file name" : +/// Finally, a setting with both limit and pre-recorded stats absent can contain a single comma: ",". +/// The EP will attempt to place nodes on device (currently only CUDA is supported) : /// this file is expected to be found at the same folder with the model. 
The file contains /// pre-recorded stats collected when running with kOrtSessionOptionsCollectNodeMemoryStatsToFile enforce (see above) static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings = "session.resource_cuda_partitioning_settings"; +/// +/// This is a setting that contains string annotations or annotation prefixes to be matched +/// against individual nodes metadata entry 'layer_ann' to guide layer assignment during partitioning. +/// The value is a semicolon separated list of strings or string prefixes per device. +/// Format: device1(annotation1, annotation2, ...); device2(annotation1, =annotation3, ...);... +/// Where: +/// - device1, device2, ... are the recognized device names to be matched against EPs configured in +/// the given session. +/// - annotation1, annotation2, ... are annotation prefixes to be matched against node annotations. Any +/// node annotation that starts with one of these prefixes will be matched. +/// - =annotation3 indicates an exact match for annotation3. Only node annotations that are exactly +/// equal to 'annotation3' will be matched. +/// TODO: add a list of recognized devices here. +/// +static const char* const kOrtSessionOptionsLayerAssignmentSettings = "session.layer_assignment_settings"; + // Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. 
(default) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 9cb2111670ba6..cc65142318d02 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/framework/kernel_lookup.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" +#include "core/framework/layering_annotations.h" #include "core/framework/resource_accountant.h" #include "core/graph/function.h" #include "core/graph/function_utils.h" @@ -69,6 +70,7 @@ struct PartitionParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::reference_wrapper on_partition_assignment_fn; + LayeringIndex* layering_index; }; } // namespace @@ -150,6 +152,7 @@ struct GetCapabilityForEPParams { IResourceAccountant* resource_accountant; std::reference_wrapper graph_optimizer_registry; std::reference_wrapper check_load_cancellation_fn; + LayeringIndex* layering_index; }; auto get_capabilities = [](const IExecutionProvider& ep, @@ -193,10 +196,94 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& capabilities = params.capabilities.get(); const auto& graph_optimizer_registry = params.graph_optimizer_registry.get(); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + InlinedVector assigned_filtered_in_nodes; + InlinedVector filtered_in_nodes; +#endif + // Helper to create a GraphViewer that filters nodes based on layering_index if present. 
+ auto create_graph_viewer = [&](std::unique_ptr& out_sub_graph, + std::unique_ptr& out_viewer) -> Status { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + if (params.layering_index) { + assigned_filtered_in_nodes.clear(); + filtered_in_nodes.clear(); + filtered_in_nodes.reserve(graph.NumberOfNodes()); + + auto rules_opt = params.layering_index->GetLayeringRulesForThisEp(ep_type); + if (rules_opt) { + assigned_filtered_in_nodes.reserve(rules_opt->get().size()); + } + + for (auto& node : graph.Nodes()) { + auto rule_idx_opt = params.layering_index->GetNodeAssignment(graph, node.Index()); + bool include = true; + if (rule_idx_opt) { + // If node has an assignment, include it only if it is assigned to this EP + if (!rules_opt || rules_opt->get().count(*rule_idx_opt) == 0) { + include = false; + } else { + assigned_filtered_in_nodes.push_back(node.Index()); + } + } + // If node has no assignment, it is included (available to any EP) + + if (include) { + filtered_in_nodes.push_back(&node); + } + } + ORT_RETURN_IF_ERROR(graph_utils::CreateFilteredIndexedGraph(filtered_in_nodes, graph, out_sub_graph)); + out_viewer = std::make_unique(graph, *out_sub_graph); + return Status::OK(); + } +#else + ORT_UNUSED_PARAMETER(out_sub_graph); +#endif + out_viewer = std::make_unique(graph); + return Status::OK(); + }; + // Helper to un-assign nodes that were assigned to this EP but not claimed by updated capabilities. 
+ auto reset_assignment_unclaimed_nodes = [&]() { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + if (params.layering_index) { + auto rules_opt = params.layering_index->GetLayeringRulesForThisEp(ep_type); + if (rules_opt) { + const auto& ep_rules = rules_opt->get(); + InlinedHashSet claimed; + for (const auto& cap : capabilities) { + if (cap && cap->sub_graph) { + for (auto idx : cap->sub_graph->nodes) claimed.insert(idx); + } + } + + // Check if all assigned filtered-in nodes are claimed + // and if not make them available for subsequent EPs + for (auto& node_index : assigned_filtered_in_nodes) { + if (claimed.count(node_index) == 0) { + auto rule_idx_opt = params.layering_index->GetNodeAssignment(graph, node_index); + if (rule_idx_opt && ep_rules.count(*rule_idx_opt) > 0) { + params.layering_index->MakeNodeUnassigned(graph, node_index); + } + } + } + assigned_filtered_in_nodes.clear(); + } + } +#endif + }; + { - const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + std::unique_ptr sub_graph_holder; + std::unique_ptr graph_viewer; + ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); + + if (params.resource_accountant) { + params.resource_accountant->ResetPendingWeights(); + } + capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + + reset_assignment_unclaimed_nodes(); + if (params.check_load_cancellation_fn()) { return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning was canceled by user request"); @@ -241,9 +328,33 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); - const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) 
+ if (params.layering_index && end_node > first_new_node) { + // We need to update the LayeringIndex with newly created nodes + // as the layout transformation may have created new nodes + // with inherited annotations + InlinedVector new_node_indices; + for (NodeIndex idx = first_new_node; idx < end_node; ++idx) { + if (graph.GetNode(idx) != nullptr) { + new_node_indices.push_back(idx); + } + } + params.layering_index->Update(graph, new_node_indices); + } +#endif + + std::unique_ptr sub_graph_holder; + std::unique_ptr graph_viewer; + ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); + + if (params.resource_accountant) { + params.resource_accountant->ResetPendingWeights(); + } + capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + + reset_assignment_unclaimed_nodes(); + if (params.check_load_cancellation_fn()) { return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "GetCapabilities was canceled by user request"); @@ -388,13 +499,13 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, fused_node->SetExecutionProviderType(provider_type); if (acc_enabled) { - // We account for the fused node. We operate under assumption - // that the fused node would use no more memory when the nodes we are fusing. - // and potentially less than that, and therefore, no threshold check is needed here. - // All threshold checks are done within the EP. - capability.ComputeAndAccountForNode(*fused_node); + // Account for all constituent nodes using the per-node costs computed + // during GetCapability() (which already includes within-pass weight dedup). + // Computing the cost for the newly created fused node would undercount + // because the fused node often doesn't expose all original initializers, + // and would commit weights for the wrong node index. 
+ capability.AccountForAllNodes(); } - result = fused_node; } else { // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. @@ -430,7 +541,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, const OnPartitionAssignmentFunction& on_partition_assignment_fn, const logging::Logger& logger, IResourceAccountant* resource_accountant, const GraphOptimizerRegistry& graph_optimizer_registry, - bool disable_model_compile) { + bool disable_model_compile, + LayeringIndex* layering_index) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -448,7 +560,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, check_load_cancellation_fn, on_partition_assignment_fn, logger, resource_accountant, - graph_optimizer_registry, disable_model_compile)); + graph_optimizer_registry, disable_model_compile, + layering_index)); } } @@ -474,7 +587,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(debug_graph_fn), resource_accountant, std::ref(graph_optimizer_registry), - std::cref(check_load_cancellation_fn)}; + std::cref(check_load_cancellation_fn), + layering_index}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -654,17 +768,17 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, } // expand any nodes that have an ONNX function definition but no matching ORT kernel -static Status InlineNodes(Graph& graph, bool& modified_graph) { +static Status InlineNodes(Graph& graph, bool& modified_graph, LayeringIndex* layering_index) { // recurse into nested graphs first so we process from bottom up for (auto& node : graph.Nodes()) { for (auto& entry : 
node.GetAttributeNameToMutableSubgraphMap()) { Graph* subgraph = entry.second; - ORT_RETURN_IF_ERROR(InlineNodes(*subgraph, modified_graph)); + ORT_RETURN_IF_ERROR(InlineNodes(*subgraph, modified_graph, layering_index)); } } - // See if the node with no provider can be inlined. If one such nodes can be - // successfully inlined, we re-run the partitioner on the modified graph. + // See if the node with no provider can be inlined. If one such node can be successfully inlined, + // we re-run the partitioner on the modified graph. // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating // using graph.Nodes(). InlinedVector nodes_to_inline; @@ -674,9 +788,50 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { } } + // Collect new node indices for nodes inlined from annotated parents so we can + // update the LayeringIndex in one batch. + InlinedVector new_node_indices; + for (auto* node : nodes_to_inline) { + // Check for an effective layering assignment: either from an explicit annotation + // on the node, or from an inherited assignment via the LayeringIndex (e.g., a function + // call node inside an annotated If/Loop subgraph that inherited its parent's rule). + const bool has_explicit_annotation = !node->GetLayeringAnnotation().empty(); + bool has_effective_assignment = has_explicit_annotation; + + if (layering_index != nullptr && !has_explicit_annotation) { + // The node may have an inherited-only assignment with no stored annotation string. + // Materialize the annotation on the node so Graph::InlineFunction propagates it + // to the newly created inlined nodes. 
+ auto rule_idx = layering_index->GetNodeAssignment(graph, node->Index()); + if (rule_idx) { + has_effective_assignment = true; + const auto& rules = layering_index->GetRules(); + if (*rule_idx < rules.rules.size()) { + node->SetLayeringAnnotation(rules.rules[*rule_idx].annotation); + } + } + } + + const int max_before = has_effective_assignment ? graph.MaxNodeIndex() : 0; + ORT_RETURN_IF_ERROR(graph.InlineFunction(*node)); modified_graph = true; + + if (has_effective_assignment) { + const int max_after = graph.MaxNodeIndex(); + for (int i = max_before; i < max_after; ++i) { + if (graph.GetNode(static_cast(i)) != nullptr) { + new_node_indices.push_back(static_cast(i)); + } + } + } + } + + // Update the LayeringIndex so the next partitioning round filters correctly + // for the newly inlined nodes that inherited their parent's annotation. + if (layering_index != nullptr && !new_node_indices.empty()) { + layering_index->Update(graph, new_node_indices); } return Status::OK(); @@ -1018,7 +1173,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_manager, const std::optional& acc_map, const GraphOptimizerRegistry& graph_optimizer_registry, - const logging::Logger& logger, bool disable_model_compile) { + const logging::Logger& logger, bool disable_model_compile) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -1046,12 +1201,13 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, check_load_cancellation_fn, on_partition_assignment_fn, logger, resource_accountant, graph_optimizer_registry, - disable_model_compile)); + disable_model_compile, + partition_params.layering_index)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. 
modified_graph = false; - ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph)); + ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph, partition_params.layering_index)); // Resolve and rerun graph partitioning and inlining if there was a change if (modified_graph) { @@ -1101,7 +1257,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) nullptr, std::ref(graph_optimizer_registry), - partition_params.check_load_cancellation_fn + partition_params.check_load_cancellation_fn, + partition_params.layering_index }; // clang-format on @@ -1135,7 +1292,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param Node& fused_node = graph.BeginFuseSubGraph(indexed_sub_graph, node_name); fused_node.SetExecutionProviderType(type); if (indexed_sub_graph.IsAccountingEnabled()) { - indexed_sub_graph.ComputeAndAccountForNode(fused_node); + indexed_sub_graph.AccountForAllNodes(); } // create filtered graph viewer for this set of nodes @@ -1143,6 +1300,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param // TODO: Could avoid the topological sort in the GraphViewer ctor by constructing from an existing // GraphViewer instance instead of the Graph (copying the topological order instead of recalculating). auto viewer = std::make_unique(graph, indexed_sub_graph); + compilation_entries.push_back(CompilationEntry{std::move(viewer), fused_node, *capability}); #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Compiling capabilities is not supported in this build."); @@ -1153,7 +1311,6 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // We will compile the fused nodes one by one, and fuse the subgraph if successful. 
   for (const auto& compilation_entry : compilation_entries) {
-    const bool acc_enabled = compilation_entry.capability.get().sub_graph->IsAccountingEnabled();
     Node& node = compilation_entry.fused_node;
     std::vector<NodeComputeInfo> single_node_compute_func;
     ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}},
@@ -1184,9 +1341,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
 
     // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
     graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
-    if (acc_enabled) {
-      compilation_entry.capability.get().sub_graph->ComputeAndAccountForNode(node);
-    }
+    // accounting was already done via AccountForAllNodes() when the fused node was created above.
   }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
@@ -1259,9 +1414,10 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
                                    const layout_transformation::TransformLayoutFunction& transform_layout_function,
                                    const ConfigOptions& config_options,
                                    const logging::Logger& logger,
+                                   LayeringIndex* layering_index,
                                    Mode mode,
                                    const epctx::ModelGenOptions& ep_context_gen_options,
-                                   const layout_transformation::DebugGraphFn& debug_graph_fn) const {
+                                   const layout_transformation::DebugGraphFn& debug_graph_fn) const {  // Added arg
   // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
   // 1. Execution providers' capabilities are checked one by one.
   // 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet.
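The comment above summarizes the partitioning strategy: providers are consulted in user-preference order, and each claims whatever still-unassigned nodes it can run. The following standalone sketch illustrates that greedy loop in isolation; `FakeNode`, `FakeProvider`, and `GreedyPartition` are illustrative names, not ONNX Runtime APIs, and real partitioning works on sub-graphs (capabilities) rather than individual nodes.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Simplified stand-in for a graph node: it is either unassigned (empty EP
// string) or already claimed by an execution provider.
struct FakeNode {
  std::string ep;  // empty => unassigned
};

// One provider in preference order: a name plus a predicate saying which
// nodes it is capable of running.
struct FakeProvider {
  std::string name;
  std::function<bool(size_t /*node index*/)> can_run;
};

// Greedy assignment: providers are consulted in the order the user
// registered them, and each one claims every still-unassigned node it is
// capable of running. Later providers never steal nodes from earlier ones.
void GreedyPartition(std::vector<FakeNode>& nodes,
                     const std::vector<FakeProvider>& providers) {
  for (const auto& provider : providers) {
    for (size_t i = 0; i < nodes.size(); ++i) {
      if (nodes[i].ep.empty() && provider.can_run(i)) {
        nodes[i].ep = provider.name;
      }
    }
  }
}
```

With a CPU provider registered last as the catch-all, every node ends up assigned, which mirrors why the CPU EP is conventionally the final fallback.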
@@ -1292,7 +1448,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(fused_node_unique_id), std::cref(transform_layout_function), std::cref(debug_graph_fn), - std::cref(on_partition_assignment_fn_)}; + std::cref(on_partition_assignment_fn_), + layering_index}; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1303,7 +1460,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(graph), std::cref(check_load_cancellation_fn), std::cref(on_partition_assignment_fn_), - }; + layering_index}; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1323,12 +1480,12 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, // We use this only if Resource Aware Partitioning is enabled for any of the EPs // The map is empty if not created if not enabled std::optional ep_acc_map; - ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map)); + ORT_RETURN_IF_ERROR(CreateAccountants(config_options, graph.ModelPath(), ep_acc_map)); bool disable_model_compile = config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "1"; ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, ep_acc_map, *graph_optimizer_registry_, logger, - disable_model_compile)); + disable_model_compile)); // Pass param if (ep_context_gen_options.enable) { ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_gen_options, logger)); diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index eb70b9f89933d..4de9d94781b18 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; +class LayeringIndex; class Model; struct ConfigOptions; @@ -60,6 
+61,7 @@ class GraphPartitioner { const layout_transformation::TransformLayoutFunction& transform_layout_function, const ConfigOptions& config_options, const logging::Logger& logger, + LayeringIndex* layering_index, Mode mode = Mode::kNormal, const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/framework/layering_annotations.cc b/onnxruntime/core/framework/layering_annotations.cc new file mode 100644 index 0000000000000..91df102abef17 --- /dev/null +++ b/onnxruntime/core/framework/layering_annotations.cc @@ -0,0 +1,584 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "core/graph/constants.h" +#include "core/common/narrow.h" +#include "core/common/parse_string.h" +#include "core/common/string_utils.h" +#include "core/framework/layering_annotations.h" +#include "core/framework/ortmemoryinfo.h" +#include "core/session/abi_devices.h" +#include "core/framework/execution_providers.h" +#include "core/graph/graph.h" + +#include + +namespace onnxruntime { + +common::Status LayeringRules::FromConfigString(const std::string& config_value, LayeringRules& rules) { + rules.rules.clear(); + if (config_value.empty()) { + return common::Status::OK(); + } + + // Track seen annotations to reject duplicates. + // Separate sets for exact and prefix match annotations. 
+ InlinedHashSet seen_exact_annotations; + InlinedHashSet seen_prefix_annotations; + + auto entries = utils::SplitString(config_value, ";"); + for (const auto& e : entries) { + auto entry = utils::TrimString(e); + if (entry.empty()) { + continue; + } + + const size_t open_paren = entry.find('('); + const size_t close_paren = entry.find(')'); + + if (open_paren == std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid layering config: Missing '(' in entry: ", entry); + } + if (close_paren == std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid layering config: Missing ')' in entry: ", entry); + } + if (close_paren < open_paren) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid layering config: ')' comes before '(' in entry: ", entry); + } + + std::string device = entry.substr(0, open_paren); + device = utils::TrimString(device); + + if (device.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid layering config: Empty device name in entry: ", entry); + } + + std::string annotations_list = entry.substr(open_paren + 1, close_paren - open_paren - 1); + auto annotations = utils::SplitString(annotations_list, ","); + for (auto& a : annotations) { + auto ann = utils::TrimString(a); + if (ann.empty()) { + continue; + } + + bool prefix_match = true; + if (ann[0] == '=') { + prefix_match = false; + ann = ann.substr(1); + ann = utils::TrimString(ann); + } + + if (ann.empty()) { + continue; + } + + // Check for duplicate annotation (same annotation string and match type) + auto& seen_set = prefix_match ? seen_prefix_annotations : seen_exact_annotations; + auto [it, inserted] = seen_set.insert(ann); + if (!inserted) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid layering config: Duplicate ", (prefix_match ? 
+                               "prefix" : "exact"),
+                               " match annotation '", ann, "' found in entry: ", entry);
+      }
+
+      rules.rules.push_back({device, std::move(ann), prefix_match});
+    }
+  }
+
+  return common::Status::OK();
+}
+
+LayeringRuleMatcher::LayeringRuleMatcher(const LayeringRules& rules) {
+  for (size_t i = 0; i < rules.rules.size(); ++i) {
+    const auto& rule = rules.rules[i];
+    ORT_ENFORCE(!rule.annotation.empty(), "Layering rule annotation cannot be empty");
+    if (rule.prefix_match) {
+      AddPrefixRule(rule.annotation, i);
+    } else {
+      AddExactRule(rule.annotation, i);
+    }
+  }
+}
+
+std::optional<size_t> LayeringRuleMatcher::Match(const std::string& node_annotation) const {
+  std::optional<size_t> best_match = std::nullopt;
+
+  // 1. Check prefix matches via the trie. Prefix matches have priority over exact matches.
+  const TrieNode* current = &root_;
+
+  // Annotations are never empty, so we omit checking the root.
+  for (char c : node_annotation) {
+    if (best_match && *best_match == 0) {
+      // Optimization: if we already found index 0, we can't do better.
+      return best_match;
+    }
+
+    auto child_it = current->children.find(c);
+    if (child_it == current->children.end()) {
+      break;
+    }
+    current = child_it->second.get();
+    if (current->rule_index) {
+      UpdateBestMatch(best_match, *current->rule_index);
+    }
+  }
+
+  if (best_match) {
+    return best_match;
+  }
+
+  // 2. Check exact matches (fallback)
+  auto it = exact_match_rules_.find(node_annotation);
+  if (it != exact_match_rules_.end()) {
+    best_match = it->second;
+  }
+
+  return best_match;
+}
+
+namespace {
+bool CaseInsensitiveCompare(std::string_view a, std::string_view b) {
+  return std::equal(a.begin(), a.end(), b.begin(), b.end(),
+                    [](char c1, char c2) {
+                      return std::tolower(static_cast<unsigned char>(c1)) ==
+                             std::tolower(static_cast<unsigned char>(c2));
+                    });
+}
+
+bool TryParseIndex(const std::string& str, uint32_t& index) {
+  if (str.empty()) return false;
+  return TryParseStringWithClassicLocale(str, index);
+}
+
+// Sentinel value representing an unknown/unavailable device type.
+// Used when an OrtEpDevice has neither hardware info nor memory info, +// so we cannot determine the actual device type. +constexpr OrtDevice::DeviceType kDeviceTypeUnknown = static_cast(-1); + +// Normalized view of an EP's device properties used by the matching logic. +// All fields are non-owning references or value types. +struct EpDeviceView { + std::string_view ep_name; + OrtDevice::DeviceType device_type; // OrtDevice::CPU, GPU, NPU, FPGA, or kDeviceTypeUnknown + uint32_t vendor_id; + OrtDevice::DeviceId device_id; + std::string_view vendor_string; // from OrtHardwareDevice::vendor (empty if unavailable) +}; + +bool MatchEpDevice(const EpDeviceView& ep, + std::string_view target_type_str, + std::string_view target_specifier, + std::string_view target_full) { + // "cpu" + if (CaseInsensitiveCompare(target_type_str, "cpu")) { + return ep.ep_name == kCpuExecutionProvider || + ep.device_type == OrtDevice::CPU; + } + // "gpu" + if (CaseInsensitiveCompare(target_type_str, "gpu")) { + if (target_specifier.empty()) { + if (ep.device_type == OrtDevice::GPU) return true; + // Heuristic fallback for common GPU EPs if hardware info is missing + return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider; + } + // "gpu:" or "gpu:" + if (ep.device_type == OrtDevice::GPU) { + uint32_t index = std::numeric_limits::max(); + if (TryParseIndex(std::string(target_specifier), index)) { + return ep.device_id == static_cast(index); + } + // gpu: + if (!ep.vendor_string.empty() && CaseInsensitiveCompare(ep.vendor_string, target_specifier)) { + return true; + } + if (CaseInsensitiveCompare(target_specifier, "nvidia") && + ep.vendor_id == OrtDevice::VendorIds::NVIDIA) return true; + if (CaseInsensitiveCompare(target_specifier, "amd") && + ep.vendor_id == OrtDevice::VendorIds::AMD) return true; + if (CaseInsensitiveCompare(target_specifier, "intel") && + ep.vendor_id == OrtDevice::VendorIds::INTEL) return true; + // Heuristic: gpu:nvidia -> CUDA + if 
(CaseInsensitiveCompare(target_specifier, "nvidia") && + ep.ep_name == kCudaExecutionProvider) return true; + } + return false; + } + // "accelerator" (not cpu) + if (CaseInsensitiveCompare(target_type_str, "accelerator")) { + // Match if the EP is not a known CPU provider and its device type + // is not definitively CPU. Unknown device type (no HW/mem info) + // is treated as a potential accelerator. + return ep.ep_name != kCpuExecutionProvider && ep.device_type != OrtDevice::CPU; + } + // "npu" + if (CaseInsensitiveCompare(target_type_str, "npu")) { + if (ep.device_type == OrtDevice::NPU) return true; + return ep.ep_name == kQnnExecutionProvider || ep.ep_name == kVitisAIExecutionProvider; + } + // "fpga" + if (CaseInsensitiveCompare(target_type_str, "fpga")) { + return ep.device_type == OrtDevice::FPGA; + } + // "cuda" + if (CaseInsensitiveCompare(target_type_str, "cuda")) { + return ep.ep_name == kCudaExecutionProvider; + } + // "dml" + if (CaseInsensitiveCompare(target_type_str, "dml")) { + return ep.ep_name == kDmlExecutionProvider; + } + // Fallback: exact EP name match + return ep.ep_name == target_full; +} + +void ParseDeviceTarget(const std::string& target_full, + std::string& target_type_str, + std::string& target_specifier) { + const auto colon_pos = target_full.find(':'); + target_type_str = (colon_pos == std::string::npos) ? target_full : target_full.substr(0, colon_pos); + target_specifier = (colon_pos != std::string::npos) ? target_full.substr(colon_pos + 1) : std::string(); +} + +} // namespace + +std::optional EpLayeringMatcher::Match(gsl::span ep_devices, + const LayerAnnotation& rule) { + std::string target_type_str, target_specifier; + ParseDeviceTarget(rule.device, target_type_str, target_specifier); + + for (const auto* ep_device_ptr : ep_devices) { + if (!ep_device_ptr) continue; + const OrtEpDevice& ep_device = *ep_device_ptr; + + // Build normalized view from OrtEpDevice. 
+ // Device type comes from either the hardware device or the memory info, + // with hardware device taking priority. If neither is available, + // device_type is set to kDeviceTypeUnknown. + OrtDevice::DeviceType device_type = kDeviceTypeUnknown; + bool has_hw = ep_device.device != nullptr; + if (has_hw) { + // Map OrtHardwareDeviceType to OrtDevice::DeviceType + switch (ep_device.device->type) { + case OrtHardwareDeviceType_GPU: + device_type = OrtDevice::GPU; + break; + case OrtHardwareDeviceType_NPU: + device_type = OrtDevice::NPU; + break; + case OrtHardwareDeviceType_CPU: + device_type = OrtDevice::CPU; + break; + default: + device_type = kDeviceTypeUnknown; + break; + } + } else if (ep_device.device_memory_info) { + device_type = ep_device.device_memory_info->device.Type(); + } + + EpDeviceView view{ + ep_device.ep_name, + device_type, + has_hw ? ep_device.device->vendor_id : 0u, + has_hw ? static_cast(ep_device.device->device_id) : OrtDevice::DeviceId{}, + has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}}; + + if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) { + return std::string(ep_device.ep_name); + } + } + return std::nullopt; +} + +std::optional EpLayeringMatcher::Match(const ExecutionProviders& providers, + const LayerAnnotation& rule) { + std::string target_type_str, target_specifier; + ParseDeviceTarget(rule.device, target_type_str, target_specifier); + + for (const auto& ep_shared_ptr : providers) { + if (!ep_shared_ptr) continue; + const IExecutionProvider& ep = *ep_shared_ptr; + const OrtDevice& device = ep.GetDevice(); + + EpDeviceView view{ + ep.Type(), + device.Type(), + device.Vendor(), + device.Id(), + {}}; // no vendor string available from IExecutionProvider + + if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) { + return std::string(ep.Type()); + } + } + return std::nullopt; +} + +LayeringIndex LayeringIndex::Create(const Graph& graph, + EpNameToLayeringIndices ep_map, 
+ LayeringIndexToEpName rule_map, + LayeringRules layering_rules) { + // 1. Create LayeringIndex instance with pre-computed maps + LayeringIndex index(std::move(layering_rules), std::move(ep_map), std::move(rule_map)); + + // 2. Traverse the graph and index nodes + index.ProcessGraph(graph, std::nullopt); + + return index; +} + +Status LayeringIndex::Create(const Graph& graph, + const std::string& config_string, + gsl::span ep_devices, + const ExecutionProviders& ep_providers, + const logging::Logger& logger, + std::optional& layering_index) { + LayeringRules rules; + ORT_RETURN_IF_ERROR(LayeringRules::FromConfigString(config_string, rules)); + + LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " layering rules from config."; + + if (rules.rules.empty()) { + // Return no index indicating no layering + layering_index.reset(); + return Status::OK(); + } + + // Identify which EPs satisfy which rules + EpNameToLayeringIndices ep_map; + LayeringIndexToEpName rule_map; + + size_t matched_rule_count = 0; + + for (size_t i = 0, lim = rules.rules.size(); i < lim; ++i) { + const auto& rule = rules.rules[i]; + + // 1. Try matching against ep_devices (from session options) + std::optional matched_ep; + if (!ep_devices.empty()) { + matched_ep = EpLayeringMatcher::Match(ep_devices, rule); + } + + // 2. If not matched, try matching against Registered EPs + if (!matched_ep) { + matched_ep = EpLayeringMatcher::Match(ep_providers, rule); + } + + if (matched_ep) { + const std::string& ep_type = *matched_ep; + ep_map[ep_type].insert(i); + // Ensure 1:1 mapping from rule index to EP type + // Note: A rule index refers to a unique entry in LayeringRules::rules vector. + // So 'i' is unique. 
+ rule_map[i] = ep_type; + matched_rule_count++; + LOGS(logger, VERBOSE) << "Layering Rule " << i << " (" << rule.device << " -> " << rule.annotation + << ") mapped to EP: " << ep_type; + } else { + LOGS(logger, WARNING) << "Layering Rule " << i << " (" << rule.device << " -> " << rule.annotation + << ") could not be mapped to any available Execution Provider."; + } + } + + LOGS(logger, INFO) << "LayeringIndex created. Matched " << matched_rule_count + << " out of " << rules.rules.size() << " rules to available Execution Providers."; + + layering_index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + return Status::OK(); +} + +void LayeringIndex::ProcessGraph(const Graph& graph, std::optional parent_layer_id) { + // 3. Create entry for this graph instance + bool was_updated = false; + std::optional new_index; + GraphLayeringIndex* current_graph_index_ptr = nullptr; + auto found = graph_index_.find(&graph); + if (found != graph_index_.end()) { + current_graph_index_ptr = &found->second; + } else { + new_index.emplace(); + current_graph_index_ptr = &(*new_index); + } + GraphLayeringIndex& current_graph_index = *current_graph_index_ptr; + + for (auto& node : graph.Nodes()) { + std::optional matched_rule_idx = std::nullopt; + + // 4. For every node query its annotation + const std::string& annotation = node.GetLayeringAnnotation(); + if (!annotation.empty()) { + // If it has an annotation try to match it + matched_rule_idx = matcher_.Match(annotation); + } + + // 5. 
If node has no annotation, inherit from subgraph parent node + if (!matched_rule_idx && parent_layer_id) { + matched_rule_idx = parent_layer_id; + } + + // Record assignment if we have a match + if (matched_rule_idx) { + const size_t rule_idx = *matched_rule_idx; + + // Only assign if this rule maps to a valid EP in our configuration + if (layering_index_to_ep_name_.count(rule_idx)) { + ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node.Index(), rule_idx)); + ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node.Index())); + was_updated = true; + } else { + // reset since no valid EP mapping + matched_rule_idx = std::nullopt; + } + } + + // Recurse for subgraphs + if (node.ContainsSubgraph()) { + const std::optional subgraph_parent_assignment = matched_rule_idx; + for (auto& [attr_name, subgraph] : node.GetAttributeNameToSubgraphMap()) { + ProcessGraph(*subgraph, subgraph_parent_assignment); + } + } + } + if (was_updated && new_index) { + graph_index_.emplace(&graph, std::move(*new_index)); + } +} + +void LayeringIndex::Update(const Graph& graph, gsl::span nodes) { + // Ensure we have an entry for this graph (creating it if it doesn't exist, though typically it should) + bool was_updated = false; + std::optional new_index; + GraphLayeringIndex* current_graph_index_ptr = nullptr; + auto found = graph_index_.find(&graph); + if (found != graph_index_.end()) { + current_graph_index_ptr = &found->second; + } else { + new_index.emplace(); + current_graph_index_ptr = &(*new_index); + } + + auto& current_graph_index = *current_graph_index_ptr; + + for (NodeIndex node_index : nodes) { + // GetMutableNode because we want to ClearLayeringAnnotation if we use it + const Node* node = graph.GetNode(node_index); + if (!node) { + continue; + } + + const std::string& annotation = node->GetLayeringAnnotation(); + if (!annotation.empty()) { + auto matched_rule_idx = matcher_.Match(annotation); + + if 
(matched_rule_idx) { + const size_t rule_idx = *matched_rule_idx; + + // Only assign if this rule maps to a valid EP in our configuration + if (layering_index_to_ep_name_.count(rule_idx)) { + // Check if already assigned to a DIFFERENT rule, if so clean up old mapping + auto prev_assign = current_graph_index.node_to_layering_index_.find(node_index); + if (prev_assign != current_graph_index.node_to_layering_index_.end()) { + size_t old_rule = prev_assign->second; + if (old_rule != rule_idx) { + current_graph_index.layer_to_node_ids_[old_rule].erase(node_index); + } + } + + ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node_index, rule_idx)); + ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node_index)); + was_updated = true; + } + } + } + } + if (was_updated && new_index) { + graph_index_.emplace(&graph, std::move(*new_index)); + } +} + +void LayeringRuleMatcher::AddExactRule(const std::string& annotation, size_t index) { + // Only store the first occurrence (lowest index) + exact_match_rules_.insert({annotation, index}); +} + +void LayeringRuleMatcher::AddPrefixRule(const std::string& annotation, size_t index) { + TrieNode* current = &root_; + for (char c : annotation) { + auto p = current->children.insert({c, nullptr}); + if (p.second) { + p.first->second = std::make_unique(); + } + current = p.first->second.get(); + } + + // Only store if strictly better (lower index) or not set + // Since we iterate rules 0..N, if a rule index is already set for this node, + // it corresponds to a higher priority rule, so we skip overwriting it. 
+  if (!current->rule_index) {
+    current->rule_index = index;
+  }
+}
+
+void LayeringRuleMatcher::UpdateBestMatch(std::optional<size_t>& current_best, size_t candidate) const {
+  if (!current_best || candidate < *current_best) {
+    current_best = candidate;
+  }
+}
+
+std::optional<std::reference_wrapper<const InlinedHashSet<size_t>>>
+LayeringIndex::GetLayeringRulesForThisEp(const std::string& ep_type) const {
+  auto hit = ep_name_to_layering_indices_.find(ep_type);
+  if (hit == ep_name_to_layering_indices_.end()) {
+    return {};
+  }
+  return hit->second;
+}
+
+std::optional<size_t> LayeringIndex::GetNodeAssignment(const Graph& graph, NodeIndex node_id) const {
+  auto hit = graph_index_.find(&graph);
+  if (hit == graph_index_.end()) {
+    return {};
+  }
+
+  // Nodes in a subgraph that were not annotated have already inherited their
+  // annotation, if any, from the parent node of the subgraph.
+  const auto& graph_layering_index = hit->second;
+  auto layer_hit = graph_layering_index.node_to_layering_index_.find(node_id);
+  if (layer_hit != graph_layering_index.node_to_layering_index_.end()) {
+    return layer_hit->second;
+  }
+  return {};
+}
+
+void LayeringIndex::MakeNodeUnassigned(const Graph& graph, NodeIndex node_id) {
+  auto hit = graph_index_.find(&graph);
+  if (hit == graph_index_.end()) {
+    return;
+  }
+  auto& graph_layering_index = hit->second;
+  auto node_to_layer_hit = graph_layering_index.node_to_layering_index_.find(node_id);
+  std::optional<size_t> layer_idx;
+  if (node_to_layer_hit != graph_layering_index.node_to_layering_index_.end()) {
+    // Get the layer index
+    layer_idx = node_to_layer_hit->second;
+    graph_layering_index.node_to_layering_index_.erase(node_to_layer_hit);
+  }
+  // Remove node from layer collection
+  if (layer_idx) {
+    auto layer_to_nodes_hit = graph_layering_index.layer_to_node_ids_.find(*layer_idx);
+    if (layer_to_nodes_hit != graph_layering_index.layer_to_node_ids_.end()) {
+      layer_to_nodes_hit->second.erase(node_id);
+      if (layer_to_nodes_hit->second.empty()) {
graph_layering_index.layer_to_node_ids_.erase(layer_to_nodes_hit); + } + } + } +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/layering_annotations.h b/onnxruntime/core/framework/layering_annotations.h new file mode 100644 index 0000000000000..5d58e9ace2471 --- /dev/null +++ b/onnxruntime/core/framework/layering_annotations.h @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "core/common/inlined_containers.h" +#include "core/common/status.h" +#include "core/graph/basic_types.h" +#include "core/common/logging/logging.h" +#include "gsl/gsl" +#include +#include +#include +#include + +struct OrtEpDevice; + +namespace onnxruntime { +class ExecutionProviders; +class Graph; + +/// +/// Annotation extracted from kOrtSessionOptionsLayerAssignmentSettings session configuration option. +/// +struct LayerAnnotation { + std::string device; + std::string annotation; + bool prefix_match; +}; + +/// +/// This struct is a container for layering rules extracted from the kOrtSessionOptionsLayerAssignmentSettings +/// session configuration option. +/// +struct LayeringRules { + std::vector rules; + /// + /// Parses the layering rules from the given configuration string. + /// The configuration string is in the following format.: + /// 'cpu(L1,L2); gpu(L3,=L4)' where cpu or gpu denote the target EP. + /// L1, L2, L3 are annotations that can be matched to node annotations in the graph. The '=' prefix denotes + /// exact match. The position of the annotation (L1, L2, L3) in the list denotes its priority in matching (left to right). + /// However, the prefix annotations will always have higher priority than the exact match annotations regardless + /// of their position in the list. 
In the above example, L1 has the highest priority, followed by L2, + /// then L3 and finally L4. The rules are separated by ';' and there can be multiple rules for different EPs. + /// + /// The configuration string to parse. + /// Output parameter where the parsed rules will be stored. + /// Status indicating success or failure (e.g. due to format errors). + static common::Status FromConfigString(const std::string& config_value, LayeringRules& rules); +}; + +/// +/// This class matches node annotations against layering rules. +/// +class LayeringRuleMatcher { + public: + explicit LayeringRuleMatcher(const LayeringRules& rules); + + /// + /// The method returns the index of the best matching rule for the given annotation + /// if it exists + /// + /// annotation retrieved from protobuf node metadata + /// index of the matching LayeringRule if it exists + std::optional Match(const std::string& node_annotation) const; + + private: + struct TrieNode { + InlinedHashMap> children; + std::optional rule_index; + }; + + TrieNode root_; + InlinedHashMap exact_match_rules_; + + void AddExactRule(const std::string& annotation, size_t index); + + void AddPrefixRule(const std::string& annotation, size_t index); + + void UpdateBestMatch(std::optional& current_best, size_t candidate) const; +}; + +namespace EpLayeringMatcher { +/// +/// Matches a list of available OrtEpDevices against the device string specified in the LayerAnnotation. +/// Returns the EP Type string of the first device that matches the rule. +/// +/// The list of available EP devices. +/// The rule containing the device designator. +/// Optional containing the matched EP type, nullopt otherwise. +std::optional Match(gsl::span ep_devices, + const LayerAnnotation& rule); + +/// +/// Matches a collection of ExecutionProviders against the device string specified in the LayerAnnotation. +/// Returns the EP Type string of the first provider that matches the rule. 
+/// +/// The collection of available Execution Providers. +/// The rule containing the device designator. +/// Optional containing the matched EP type, nullopt otherwise. +std::optional Match(const ExecutionProviders& providers, const LayerAnnotation& rule); +} // namespace EpLayeringMatcher + +// This class contains indexing information about the entire graph +// per sub-graph info is stored in graph_index_ +class LayeringIndex { + public: + // mapping of EP name/type to a set of LayeringRule indices mapped to that EP. + using EpNameToLayeringIndices = InlinedHashMap>; + // mapping of LayeringRule index to EP name/type, reverse of the above + using LayeringIndexToEpName = InlinedHashMap; + + /// + /// Creates a fully initialized LayeringIndex. + /// + /// The graph to traverse and index. + /// Pre-populated mapping of EP names to their applicable rule indices. + /// Pre-populated mapping of rule indices to EP names. + /// Matcher to resolve node annotations to rule indices. + static LayeringIndex Create(const Graph& graph, + EpNameToLayeringIndices ep_map, + LayeringIndexToEpName rule_map, + LayeringRules layering_rules); + + /// + /// Factory method that creates a LayeringIndex by parsing configuration, matching rules against + /// available devices/providers, and indexing the graph. + /// + /// The graph to index. + /// The configuration string containing layering rules. + /// Available OrtEpDevices to match rules against. + /// Available ExecutionProviders to match rules against (fallback). + /// Logger for reporting information/errors. + /// Output parameter for the created LayeringIndex. Returns no index if + /// no valid layering rules discovered. + /// Status indicating success or failure. 
+ static Status Create(const Graph& graph, + const std::string& config_string, + gsl::span ep_devices, + const ExecutionProviders& ep_providers, + const logging::Logger& logger, + std::optional& layering_index); + + // Returns the Layering Rule indices mapped to the EP if any + std::optional>> + GetLayeringRulesForThisEp(const std::string& ep_type) const; + + // Returns the parsed layering rules + const LayeringRules& GetRules() const noexcept { return rules_; } + + // This function returns an index for the Layering rule the node is assigned to if any + std::optional GetNodeAssignment(const Graph& graph, NodeIndex node_id) const; + + // This is used when an EP fails to claim a node during partitioning so we make it + // available for other EPs + void MakeNodeUnassigned(const Graph& graph, NodeIndex node_id); + /// + /// Updates the layering index for a specific set of nodes in a graph. + /// This checks if the nodes have annotations, and if so, matches them against the rules + /// and updates the assignment. + /// + /// The graph containing the nodes. + /// Indices of nodes to check and update. + void Update(const Graph& graph, gsl::span nodes); + + private: + LayeringRules rules_; + LayeringRuleMatcher matcher_; + // These stay constant + EpNameToLayeringIndices ep_name_to_layering_indices_; + LayeringIndexToEpName layering_index_to_ep_name_; + + using SetOfNodes = InlinedHashSet; + using LayerIndexToNodes = InlinedHashMap; + using NodeIndexToLayeringIndex = InlinedHashMap; + + /// + /// This struct contains the result of layering assignment for a graph. + /// The struct first reflects pre-assignment according to the configuration. + /// However, as we partition the graph, some nodes may be moved to unassigned sections + /// to make them available to subsequent partitioning passes. 
+  /// </summary>
+  struct GraphLayeringIndex {
+    // Node to layering idx assignment map 1:1
+    // If the node is not in this map, it is unassigned
+    NodeIndexToLayeringIndex node_to_layering_index_;
+    // This map contains mapping of LayeringRule index to the list of node ids
+    // Reverse from the above 1:M
+    LayerIndexToNodes layer_to_node_ids_;
+  };
+
+  LayeringIndex(LayeringRules layering_rules, EpNameToLayeringIndices ep_name_to_layering_indices,
+                LayeringIndexToEpName layering_index_to_ep_name)
+      : rules_(std::move(layering_rules)),
+        matcher_(rules_),
+        ep_name_to_layering_indices_(std::move(ep_name_to_layering_indices)),
+        layering_index_to_ep_name_(std::move(layering_index_to_ep_name)) {}
+
+  // Graph and sub-graphs mapping to their indices
+  InlinedHashMap<const Graph*, GraphLayeringIndex> graph_index_;
+
+  void ProcessGraph(const Graph& graph, std::optional<size_t> parent_layer_id);
+};
+
+}  // namespace onnxruntime
+
+#else
+namespace onnxruntime {
+class LayeringIndex;
+}
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/resource_accountant.cc b/onnxruntime/core/framework/resource_accountant.cc
index 0665cc1951e60..68610ebb4be17 100644
--- a/onnxruntime/core/framework/resource_accountant.cc
+++ b/onnxruntime/core/framework/resource_accountant.cc
@@ -11,24 +11,31 @@
 #include "core/framework/config_options.h"
 #include "core/framework/murmurhash3.h"
+#include "core/framework/tensorprotoutils.h"
 #include "core/graph/constants.h"
 #include "core/graph/graph.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include
+#include
 
 namespace onnxruntime {
 
 // Use this accountant if your resource can be counted with size_t type
-class SizeTAccountant : public IResourceAccountant {
+// This accountant uses NodeAllocationStats to compute resource consumption per node
+// which can be collected and saved to a file OR loaded from a file and used for partitioning.
+// This is currently used for CUDA EP.
+class SizeBasedStatsAccountant : public IResourceAccountant {
  public:
-  SizeTAccountant() = default;
-  ~SizeTAccountant() = default;
+  SizeBasedStatsAccountant() = default;
+  ~SizeBasedStatsAccountant() = default;
 
-  SizeTAccountant(size_t threshold, InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
+  SizeBasedStatsAccountant(size_t threshold, InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
       : IResourceAccountant(threshold), node_stats_(std::move(node_stats)) {}
 
-  explicit SizeTAccountant(InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
+  explicit SizeBasedStatsAccountant(size_t threshold) : IResourceAccountant(threshold) {}
+
+  explicit SizeBasedStatsAccountant(InlinedHashMap<std::string, NodeAllocationStats>&& node_stats)
       : IResourceAccountant(), node_stats_(std::move(node_stats)) {}
 
   ResourceCount GetConsumedAmount() const noexcept override {
@@ -46,20 +53,99 @@ class SizeTAccountant : public IResourceAccountant {
   }
 
-  ResourceCount ComputeResourceCount(const Node& node) const override {
-    const auto node_name = MakeUniqueNodeName(node);
-    auto hit = node_stats_.find(node_name);
-    if (hit != node_stats_.end()) {
-      const auto& stats = hit->second;
-      return stats.input_sizes + stats.initializers_sizes +
-             stats.total_dynamic_sizes + stats.total_temp_allocations;
+  ResourceCount ComputeResourceCount(const Node& node) override {
+    if (node_stats_) {
+      const auto node_name = MakeUniqueNodeName(node);
+      auto hit = node_stats_->find(node_name);
+      if (hit != node_stats_->end()) {
+        const auto& stats = hit->second;
+        return stats.input_sizes + stats.initializers_sizes +
+               stats.total_dynamic_sizes + stats.total_temp_allocations;
+      }
+      return static_cast<size_t>(0U);
+    } else {
+      const auto* graph = node.GetContainingGraph();
+      if (!graph) return static_cast<size_t>(0);
+
+      SafeInt<size_t> total_size = 0;
+      for (const auto* input_def : node.InputDefs()) {
+        if (!input_def->Exists()) continue;
+
+        const auto& name = input_def->Name();
+        constexpr bool check_outer_scope = true;
+        const auto* tensor_proto = graph->GetInitializer(name, check_outer_scope);
+
+        if (tensor_proto) {
+          // Skip if already committed from a previous partitioning iteration
+          if (committed_weights_.count(name) > 0) {
+            continue;
+          }
+
+          // Skip if already pending from another node in this GetCapability pass
+          if (pending_weights_.count(name) > 0) {
+            continue;
+          }
+
+          size_t size = 0;
+          auto status = utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &size);
+
+          if (status.IsOK()) {
+            total_size += size;
+            pending_weights_.insert(name);
+            pending_weights_by_node_[node.Index()].insert(name);
+          }
+        }
+      }
+
+      // Account for intermediate output tensors when shape info is available.
+      // GetSizeInBytesFromTensorTypeProto will only succeed when all dims are known
+      // (static shape) and a valid element type is present, so dynamic outputs are
+      // naturally skipped.
+      SafeInt<size_t> output_size = 0;
+      for (const auto* output_def : node.OutputDefs()) {
+        if (!output_def->Exists() || !output_def->HasTensorOrScalarShape()) continue;
+        const auto* type_proto = output_def->TypeAsProto();
+        if (!type_proto || !utils::HasTensorType(*type_proto)) continue;
+
+        size_t size = 0;
+        if (utils::GetSizeInBytesFromTensorTypeProto<0>(type_proto->tensor_type(), &size).IsOK()) {
+          output_size += size;
+        }
+      }
+
+      // Apply a safety multiplier for workspace/temp allocations we can't see
+      constexpr size_t kAdHocSafetyMultiplierPercent = 150;  // 1.5x
+      SafeInt<size_t> estimated = total_size + output_size;
+      return static_cast<size_t>(estimated * kAdHocSafetyMultiplierPercent / 100);
+    }
+  }
+
+  void ResetPendingWeights() override {
+    pending_weights_.clear();
+    pending_weights_by_node_.clear();
+  }
+
+  void CommitWeightsForNode(NodeIndex node_index) override {
+    auto it = pending_weights_by_node_.find(node_index);
+    if (it != pending_weights_by_node_.end()) {
+      for (const auto& name : it->second) {
+        pending_weights_.erase(name);
+      }
+      committed_weights_.insert(it->second.begin(), it->second.end());
+      pending_weights_by_node_.erase(it);
     }
-    return static_cast<size_t>(0U);
   }
 
  private:
   size_t consumed_amount_ = 0;
-  InlinedHashMap<std::string, NodeAllocationStats> node_stats_;
+  std::optional<InlinedHashMap<std::string, NodeAllocationStats>> node_stats_;
+  // Weights committed from previous partitioning iterations.
+  // These persist across GetCapability passes.
+  InlinedHashSet<std::string> committed_weights_;
+  // Flat set of all pending weight names for O(1) membership checks.
+  InlinedHashSet<std::string> pending_weights_;
+  // Same pending weights keyed by node index, used by CommitWeightsForNode.
+  InlinedHashMap<NodeIndex, InlinedHashSet<std::string>> pending_weights_by_node_;
 };
 
 struct NodeStatsRecorder::Impl {
@@ -155,10 +241,11 @@ static Status LoadNodeAllocationStats(
   return Status::OK();
 }
 
-Status NodeStatsRecorder::CreateAccountants(
+Status CreateAccountants(
     const ConfigOptions& config_options,
     const std::filesystem::path& model_path,
    std::optional<ResourceAccountantMap>& acc_map) {
+  std::optional<ResourceAccountantMap> result;
   // Check if CUDA partitioning settings are provided
   const std::string resource_partitioning_settings = config_options.GetConfigOrDefault(
       kOrtSessionOptionsResourceCudaPartitioningSettings, "");
 
   if (!resource_partitioning_settings.empty()) {
     auto splits = utils::SplitString(resource_partitioning_settings, ",", true);
     if (splits.size() == 2) {
-      if (splits[1].empty()) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid resource partitioning settings");
-      }
-
-      InlinedHashMap<std::string, NodeAllocationStats> loaded_stats;
-      ORT_RETURN_IF_ERROR(LoadNodeAllocationStats(model_path, splits[1], loaded_stats));
-
-      std::optional<ResourceAccountantMap> result;
       auto& map = result.emplace();
+      std::optional<size_t> cuda_memory_limit;
       if (!splits[0].empty()) {
-        size_t cuda_memory_limit = 0;
-        ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(std::string{splits[0]}, cuda_memory_limit));
-        cuda_memory_limit = SafeInt<size_t>(cuda_memory_limit) * 1024;  // to bytes
+        cuda_memory_limit.emplace(0U);
+        ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(std::string{splits[0]}, *cuda_memory_limit));
+        cuda_memory_limit = SafeInt<size_t>(*cuda_memory_limit) * 1024;  // to bytes
+      }
+
+      std::optional<InlinedHashMap<std::string, NodeAllocationStats>> loaded_stats;
+      if (!splits[1].empty()) {
+        loaded_stats.emplace();
+        ORT_RETURN_IF_ERROR(LoadNodeAllocationStats(model_path, splits[1], *loaded_stats));
+      }
+
+      if (cuda_memory_limit && loaded_stats) {
         map.insert_or_assign(kCudaExecutionProvider,
-                             std::make_unique<SizeTAccountant>(cuda_memory_limit,
-                                                               std::move(loaded_stats)));
-      } else {
+                             std::make_unique<SizeBasedStatsAccountant>(*cuda_memory_limit,
+                                                                        std::move(*loaded_stats)));
+      } else if (cuda_memory_limit) {
         map.insert_or_assign(kCudaExecutionProvider,
-                             std::make_unique<SizeTAccountant>(std::move(loaded_stats)));
+                             std::make_unique<SizeBasedStatsAccountant>(*cuda_memory_limit));
+      } else if (loaded_stats) {
+        map.insert_or_assign(kCudaExecutionProvider,
+                             std::make_unique<SizeBasedStatsAccountant>(std::move(*loaded_stats)));
+      } else {
+        map.insert_or_assign(kCudaExecutionProvider, std::make_unique<SizeBasedStatsAccountant>());
       }
-
-      acc_map = std::move(result);
     } else {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "Invalid format for: ", kOrtSessionOptionsResourceCudaPartitioningSettings,
@@ -196,6 +288,7 @@ Status NodeStatsRecorder::CreateAccountants(
     }
   }
 
+  acc_map = std::move(result);
   return Status::OK();
 }
 
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index bee7f048b7c6e..74fbe4d24de96 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -2531,5 +2531,18 @@ Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, std
   return UnpackInitializerData(initializer, std::filesystem::path(), unpacked_tensor);
 }
 
+std::optional<std::string> GetNodeProtoLayeringAnnotation(const ONNX_NAMESPACE::NodeProto& node_proto) {
+  std::optional<std::string> result;
+  for (const auto& prop : node_proto.metadata_props()) {
+    if (prop.key() == kNodeProtoLayerAnnotation) {
+      if (!prop.value().empty()) {
+        result = prop.value();
+        break;
+      }
+    }
+  }
+  return result;
+}
+
 }  // namespace utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h
index e7649c072416c..8b22e8d6d1c89 100644
--- a/onnxruntime/core/framework/tensorprotoutils.h
+++ b/onnxruntime/core/framework/tensorprotoutils.h
@@ -671,5 +671,15 @@ common::Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initiali
  */
 common::Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
                                      std::vector<uint8_t>& unpacked_tensor);
+
+constexpr const char* kNodeProtoLayerAnnotation = "layer_ann";
+
+/**
+ * This function examines the given node proto and looks into its metadata_props.
+ * It returns the first non-empty value found for the key kNodeProtoLayerAnnotation.
+ * A node is expected to have only one such annotation.
+ * If no non-empty annotation is found, std::nullopt is returned.
+ */
+std::optional<std::string> GetNodeProtoLayeringAnnotation(const ONNX_NAMESPACE::NodeProto& node_proto);
 }  // namespace utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 3599edbfcd357..e7da5a16930c6 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -3935,6 +3935,20 @@ Status Graph::RemovedUnusedInitializersOrtFormat() {
   auto result = ForThisAndAllSubgraphs(all_subgraphs, cleanup_func);
   return result;
 }
+
+Status Graph::RemoveAllLayeringAnnotations() {
+  std::vector<Graph*> all_subgraphs;
+  FindAllSubgraphs(all_subgraphs);
+  auto cleanup_func = [](Graph& graph) {
+    for (auto& node : graph.Nodes()) {
+      node.ClearLayeringAnnotation();
+    }
+    return Status::OK();
+  };
+
+  return ForThisAndAllSubgraphs(all_subgraphs, cleanup_func);
+}
+
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
 const std::string& Graph::Name() const noexcept {
@@ -4371,6 +4385,13 @@ Node& Graph::AddNode(const Node& other) {
                           &other.GetAttributes(),
                           other.Domain());
 
+  // Preserve layering annotation from the source node so that graph transformers
+  // that reconstruct nodes (or function inlining) retain the EP assignment hint.
+  const auto& annotation = other.GetLayeringAnnotation();
+  if (!annotation.empty()) {
+    new_node.SetLayeringAnnotation(annotation);
+  }
+
   return new_node;
 }
 
@@ -4396,6 +4417,13 @@ Node& Graph::AddNode(const NodeProto& node_proto,
                           &attributes,
                           node_proto.domain());
 
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+  auto maybe_annotation = utils::GetNodeProtoLayeringAnnotation(node_proto);
+  if (maybe_annotation) {
+    new_node.SetLayeringAnnotation(std::move(*maybe_annotation));
+  }
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
   // Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to
   // calling onnx::check_node
   // NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain,
@@ -4630,6 +4658,38 @@ Node& Graph::AddNode(const std::string& name,
   return *node;
 }
 
+Node& Graph::AddNode(const std::string& name,
+                     const std::string& op_type,
+                     const std::string& description,
+                     gsl::span<NodeArg* const> input_args,
+                     gsl::span<NodeArg* const> output_args,
+                     const Node& annotation_source,
+                     const NodeAttributes* attributes,
+                     const std::string& domain) {
+  auto& new_node = AddNode(name, op_type, description, input_args, output_args, attributes, domain);
+  const auto& annotation = annotation_source.GetLayeringAnnotation();
+  if (!annotation.empty()) {
+    new_node.SetLayeringAnnotation(annotation);
+  }
+  return new_node;
+}
+
+Node& Graph::AddNode(const std::string& name,
+                     const std::string& op_type,
+                     const std::string& description,
+                     gsl::span<NodeArg* const> input_args,
+                     gsl::span<NodeArg* const> output_args,
+                     const Node& annotation_source,
+                     NodeAttributes&& attributes,
+                     const std::string& domain) {
+  auto& new_node = AddNode(name, op_type, description, input_args, output_args, std::move(attributes), domain);
+  const auto& annotation = annotation_source.GetLayeringAnnotation();
+  if (!annotation.empty()) {
+    new_node.SetLayeringAnnotation(annotation);
+  }
+  return new_node;
+}
+
 bool Graph::RemoveNode(NodeIndex p_index) {
   auto node = GetNode(p_index);
   if (nullptr == node) {
@@ -6074,7 +6134,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
   return Status::OK();
 }
 
-Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) {
+Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline,
+                                  const std::string& parent_annotation) {
   auto to_node_arg = [this](const std::string& name) {
     return &this->GetOrCreateNodeArg(name, nullptr);
   };
@@ -6109,28 +6170,31 @@ Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_i
     for (const auto& node_attr : inlined_node->attribute()) {
       new_attr_map.insert_or_assign(node_attr.name(), node_attr);
     }
-    ORT_IGNORE_RETURN_VALUE(AddNode(inlined_node->name(), inlined_node->op_type(),
-                                    inlined_node->doc_string(), inputs, outputs,
-                                    &new_attr_map, inlined_node->domain()));
+    auto& new_node = AddNode(inlined_node->name(), inlined_node->op_type(),
+                             inlined_node->doc_string(), inputs, outputs,
+                             &new_attr_map, inlined_node->domain());
+
+    // Nodes that come from the function_proto currently cannot have annotations of
+    // their own, so we set the parent's annotation on them.
+    if (!parent_annotation.empty()) {
+      new_node.SetLayeringAnnotation(parent_annotation);
+    }
   }
 
   return Status::OK();
 }
 
 Status Graph::InlineFunction(Node& callnode) {
-  // Remove output edges. Requirement for RemoveNode() below.
- auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator - for (const auto& output_edge : output_edges) { - RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), - output_edge.GetDstArgIndex()); - } - // create a uniq_identifier to append to every node name and intermediate input\outputs // to make sure there are no unintended duplicates std::string base_uniq_identifier{"_inlfunc_"}; base_uniq_identifier.append(callnode.OpType()); const auto uniq_identifier = GenerateNodeName(base_uniq_identifier); + // Capture the parent function node's layering annotation before inlining. + // Inlined nodes that don't already have their own annotation will inherit this. + const std::string parent_annotation = callnode.GetLayeringAnnotation(); + // Replace a (function-call) node by an inlined graph. if (!callnode.GetFunctionBody()) { // This is the normal use-case: inlining a FunctionProto (representing @@ -6142,7 +6206,7 @@ Status Graph::InlineFunction(Node& callnode) { function_utils::Specialize(inlined_fp, callnode, uniq_identifier); // In this case, global Resolve() will take care of everything. - ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp)); + ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp, parent_annotation)); } else { // Uncommon scenario. Inlining a node representing a fused sub-graph. // TODO: Unclear that this feature is needed. Can this be removed? 
@@ -6161,11 +6225,18 @@ Status Graph::InlineFunction(Node& callnode) {
         outputs.push_back(&n_output);
       }
 
-      AddNode(subgraph_node.Name() + uniq_identifier, subgraph_node.OpType(), subgraph_node.Description(),
-              inputs,
-              outputs,
-              &subgraph_node.GetAttributes(),
-              subgraph_node.Domain());
+      auto& new_node = AddNode(subgraph_node.Name() + uniq_identifier, subgraph_node.OpType(),
+                               subgraph_node.Description(),
+                               inputs,
+                               outputs,
+                               &subgraph_node.GetAttributes(),
+                               subgraph_node.Domain());
+      if (!subgraph_node.GetLayeringAnnotation().empty()) {
+        new_node.SetLayeringAnnotation(subgraph_node.GetLayeringAnnotation());
+      } else if (!parent_annotation.empty()) {
+        // If the subgraph node doesn't have its own annotation, use the parent function node's annotation.
+        new_node.SetLayeringAnnotation(parent_annotation);
+      }
     }
   }
 
@@ -6192,9 +6263,15 @@ Status Graph::InlineFunction(Node& callnode) {
     }
   }
 
-  RemoveNode(callnode.Index());
+  // Requirement for RemoveNode() below.
+  // copy so RemoveEdge doesn't invalidate iterator
+  auto output_edges = callnode.GetRelationships().output_edges;
+  for (const auto& output_edge : output_edges) {
+    RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(),
+               output_edge.GetDstArgIndex());
+  }
 
-  // std::cout << "Graph after inlining\n\n" << *this << std::endl << std::flush;
+  RemoveNode(callnode.Index());
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc
index 0480263befdd1..85de654581161 100644
--- a/onnxruntime/core/graph/graph_utils.cc
+++ b/onnxruntime/core/graph/graph_utils.cc
@@ -32,6 +32,154 @@ static int GetIndexFromName(const Node& node, const std::string& name, bool is_i
   return static_cast<int>(index);
 }
 
+Status CreateFilteredIndexedGraph(gsl::span<const Node* const> nodes, const Graph& graph,
+                                  std::unique_ptr<IndexedSubGraph>& result) {
+  // Following data structures help determine the final inputs/outputs of the subgraph.
+  // Note: The 'subgraph' here refers to a graph that contains a subset of nodes in the 'src_graph'.
+
+  // Pre-pass: Identify all outputs produced by nodes within the subgraph.
+  // This allows O(1) checks to determine if an input is internal or from the boundary.
+  InlinedHashSet<NodeIndex> node_set;
+  InlinedHashSet<const NodeArg*> internal_outputs;
+  for (size_t i = 0, lim = nodes.size(); i < lim; i++) {
+    const auto& node = *nodes[i];
+    node_set.insert(node.Index());
+    for (const auto& output : node.OutputDefs()) {
+      internal_outputs.insert(output);
+    }
+  }
+
+  // Source graph output names
+  InlinedHashSet<std::string> graph_output_names;
+  for (const auto* output_arg : graph.GetOutputs()) {
+    graph_output_names.insert(output_arg->Name());
+  }
+
+  // These maps store the inputs and outputs of the subgraph.
+  // Value is order index to maintain deterministic order.
+  InlinedHashMap<const NodeArg*, int> subgraph_inputs, subgraph_outputs;
+
+  int input_order = 0;
+  int output_order = 0;
+
+  std::unique_ptr<IndexedSubGraph> indexed_sub_graph = std::make_unique<IndexedSubGraph>();
+  InlinedVector<std::string> initializers;
+
+  // Add nodes and identify boundary inputs/outputs
+  for (size_t i = 0, lim = nodes.size(); i < lim; i++) {
+    const auto& node = *nodes[i];
+    indexed_sub_graph->nodes.push_back(node.Index());
+
+    // Process Inputs: If an input is not produced internally, it's a subgraph input.
+    auto process_inputs = [&](gsl::span<const NodeArg* const> inputs) {
+      for (const auto& input : inputs) {
+        if (!input->Exists()) continue;
+
+        const auto* tensor_proto = graph.GetConstantInitializer(input->Name(), true);
+        if (tensor_proto != nullptr) {
+          initializers.push_back(input->Name());
+          continue;
+        }
+
+        // If not produced by this subgraph, it's a boundary input
+        if (internal_outputs.count(input) == 0) {
+          // Use insert to keep the first occurrence's order
+          auto emplace_result = subgraph_inputs.emplace(input, input_order);
+          if (emplace_result.second) {
+            ++input_order;
+          }
+        }
+      }
+    };
+
+    process_inputs(gsl::make_span(node.InputDefs().data(), node.InputDefs().size()));
+    process_inputs(gsl::make_span(node.ImplicitInputDefs().data(), node.ImplicitInputDefs().size()));
+
+    // Process Outputs: If an output is graph output OR consumed externally, it's a subgraph output.
+    for (const auto& output : node.OutputDefs()) {
+      if (!output->Exists()) continue;
+
+      bool is_boundary_output = false;
+
+      // 1. Is it a graph output?
+      if (graph_output_names.count(output->Name()) > 0) {
+        is_boundary_output = true;
+      } else {
+        // 2. Is it consumed by any node outside the subgraph?
+        for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
+          // Check if the edge uses this specific output
+          if (it->GetSrcArgIndex() < static_cast<int>(node.OutputDefs().size()) &&
+              node.OutputDefs()[it->GetSrcArgIndex()] == output) {
+            if (node_set.count(it->GetNode().Index()) == 0) {
+              is_boundary_output = true;
+              break;
+            }
+          }
+        }
+      }
+
+      if (is_boundary_output) {
+        subgraph_outputs.insert({output, output_order++});
+      }
+    }
+  }
+
+  std::multimap<int, const NodeArg*> inputs, outputs;
+
+  // Get the input order of the original graph
+  InlinedHashMap<const NodeArg*, int> original_inputs;
+  int order = 0;
+  for (const auto* input : graph.GetInputs()) {
+    original_inputs[input] = order++;
+  }
+
+  // input order needs to be consistent with original graph's input order
+  for (const auto& [node_arg, subgraph_input_order] : subgraph_inputs) {
+    const auto original_input_it = original_inputs.find(node_arg);
+
+    if (original_input_it != original_inputs.end()) {
+      inputs.emplace(
+          original_input_it->second,  // input order from original graph
+          node_arg);
+    } else {
+      inputs.emplace(
+          subgraph_input_order,  // input order from subgraph
+          node_arg);
+    }
+  }
+
+  // Sort outputs by the order they were added
+  for (const auto& [node_arg, subgraph_output_order] : subgraph_outputs) {
+    outputs.emplace(subgraph_output_order, node_arg);
+  }
+
+  std::unique_ptr<IndexedSubGraph::MetaDef> meta_def = std::make_unique<IndexedSubGraph::MetaDef>();
+  meta_def->name = "sub_graph";
+  meta_def->since_version = 1;
+
+  // Assign inputs and outputs to subgraph's meta_def
+  for (const auto& input : inputs) {
+    if (input.second->Exists()) {
+      meta_def->inputs.push_back(input.second->Name());
+    }
+  }
+
+  for (const auto& initializer : initializers) {
+    meta_def->constant_initializers.push_back(initializer);
+  }
+
+  for (const auto& output : outputs) {
+    if (output.second->Exists()) {
+      meta_def->outputs.push_back(output.second->Name());
+    }
+  }
+
+  indexed_sub_graph->SetMetaDef(std::move(meta_def));
+  result = std::move(indexed_sub_graph);
+
+  return Status::OK();
+}
+
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
 #if !defined(ORT_MINIMAL_BUILD)
@@ -1010,6 +1158,5 @@ NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) {
 }
 
 #endif  // !defined(ORT_MINIMAL_BUILD)
-
 }  // namespace graph_utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h
index 256a6fc81495d..2106da1a96327 100644
--- a/onnxruntime/core/graph/graph_utils.h
+++ b/onnxruntime/core/graph/graph_utils.h
@@ -475,5 +475,21 @@ NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg);
 
 #endif  // !defined(ORT_MINIMAL_BUILD)
 
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
+/// <summary>
+/// This function creates an indexed subgraph from a collection of nodes
+/// using the graph instance. The IndexedSubgraph can then be used to create
+/// a filtered GraphViewer instance that only contains the nodes in the collection.
+/// </summary>
+/// <param name="nodes">Nodes to include in the subgraph.</param>
+/// <param name="graph">The graph containing the nodes.</param>
+/// <param name="indexed_subgraph">Output parameter for the created IndexedSubGraph.</param>
+/// <returns>Status indicating success or failure.</returns>
+Status CreateFilteredIndexedGraph(gsl::span<const Node* const> nodes, const Graph& graph,
+                                  std::unique_ptr<IndexedSubGraph>& indexed_subgraph);
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
 }  // namespace graph_utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc
index f9ae13808cf2c..f3956d5e9e0f3 100644
--- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc
+++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc
@@ -605,7 +605,7 @@ void ApplyReshapeTransposeFusions(
         graph.GenerateNodeName("DQFusedMatMulNBits"),
         "MatMulNBits",
         "Fused from DQ+Reshape+Transpose+MatMul",
-        mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain);
+        mnb_inputs, mnb_outputs, *mm_node, &mnb_attrs, kMSDomain);
     mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType());
 
     graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx));
@@ -784,7 +784,7 @@ void ApplyDirectDQFusions(
graph.GenerateNodeName("DirectDQFusedMatMulNBits"), "MatMulNBits", "Fused from direct DQ(axis=0)+MatMul", - mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); + mnb_inputs, mnb_outputs, *mm_node, &mnb_attrs, kMSDomain); mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 9e35550e2f845..606e91ce91bbb 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -17,7 +17,7 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { // Add a Cast to convert Input from int64 to int32. -static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_type) { +static NodeArg* CastToInt32(Graph& graph, NodeArg* input, const Node& source_node) { auto data_type = input->TypeAsProto()->tensor_type().elem_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) { return input; @@ -36,13 +36,13 @@ static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_ "Cast Input from int64 to int32", std::array{input}, std::array{&cast32}, + source_node, nullptr, kOnnxDomain); // Add attribute: "to" = 6 node.AddAttribute("to", int64_t{ONNX_NAMESPACE::TensorProto_DataType_INT32}); - - node.SetExecutionProviderType(provider_type); + node.SetExecutionProviderType(source_node.GetExecutionProviderType()); return &cast32; } @@ -487,9 +487,9 @@ static void CreateEmbedLayernormNode(Graph& graph, NodeArg* segment_embedding, Node& layer_norm_node) { // Cast input_ids and segment_ids to int32 if needed. 
- input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType()); + input_ids = CastToInt32(graph, input_ids, layer_norm_node); if (segment_ids != nullptr && segment_embedding != nullptr) { - segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType()); + segment_ids = CastToInt32(graph, segment_ids, layer_norm_node); } NodeArg place_holder("", nullptr); @@ -514,7 +514,7 @@ static void CreateEmbedLayernormNode(Graph& graph, "fused EmbedLayerNorm subgraphs ", embed_layer_norm_input_defs, std::array{layer_norm_node.MutableOutputDefs()[0], &mask_index}, - {}, kMSDomain); + layer_norm_node, nullptr, kMSDomain); // Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value // will be used. diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index 641bfbf388623..e2f448bf70734 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -178,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons "Gelu", "fused Gelu subgraphs ", gelu_input_defs, - {}, {}, op_domain); + {}, div, nullptr, op_domain); // Assign provider to this new node. Provider should be same as the provider for old node. 
   gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
diff --git a/onnxruntime/core/optimizer/gemm_sum_fusion.cc b/onnxruntime/core/optimizer/gemm_sum_fusion.cc
index be3c90a822fe2..c84e34a6d0dbe 100644
--- a/onnxruntime/core/optimizer/gemm_sum_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_sum_fusion.cc
@@ -41,7 +41,8 @@ Status GemmSumFusion::Apply(Graph& graph, Node& gemm_node, RewriteRuleEffect& mo
                                       "Fused Gemm with Sum",
                                       new_gemm_input_defs,
                                       new_gemm_output_defs,
-                                      {},
+                                      gemm_node,
+                                      nullptr,
                                       gemm_node.Domain());
   new_gemm_node.AddAttribute("transA", static_cast<int64_t>(transA));
   new_gemm_node.AddAttribute("transB", static_cast<int64_t>(transB));
diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
index da454b67aecf4..a66ad987cfaef 100644
--- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -80,7 +80,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
                                       "Fused Gemm with Transpose",
                                       new_gemm_input_defs,
                                       {},
-                                      {},
+                                      gemm_node,
+                                      nullptr,
                                       gemm_node.Domain());
   new_gemm_node.AddAttribute("transA", static_cast<int64_t>(transA));
   new_gemm_node.AddAttribute("transB", static_cast<int64_t>(transB));
diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index 3ade3864255ea..c10e070ef8f09 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -474,7 +474,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                     "LayerNormalization",
                                     "fused LayerNorm subgraphs ",
                                     layer_norm_input_defs,
-                                    {}, {}, kOnnxDomain);
+                                    {}, mul_node, nullptr, kOnnxDomain);
 
   // Get constant "epsilon" from "Add2" node if available. Else, default value will be used.
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name()); @@ -719,7 +719,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr InlinedVector layer_norm_input_defs{x_input, scale}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/SimplifiedLayerNormFusion/"), "SimplifiedLayerNormalization", - "fused LayerNorm subgraphs ", layer_norm_input_defs, {}, {}, kOnnxDomain); + "fused LayerNorm subgraphs ", layer_norm_input_defs, {}, mul_node, nullptr, kOnnxDomain); // Get constant "epsilon" from "Add" node if available. Else, default value will be used. const ONNX_NAMESPACE::TensorProto* tensor_proto = diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 5db61877811aa..f567609c979a9 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -7,6 +7,7 @@ #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/utils.h" #include #include @@ -204,7 +205,8 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type); Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name, {is_input ? gemm_input_defs[0] : new_arg, shape_arg}, - {is_input ? new_arg : gemm_output_defs[0]}); + {is_input ? 
new_arg : gemm_output_defs[0]}, + matmul_node); reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); return &reshape_node; }; @@ -217,7 +219,8 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm", - "fused Matmul and Add", gemm_input_defs, gemm_output_defs); + "fused Matmul and Add", gemm_input_defs, gemm_output_defs, + matmul_node); gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); if (need_reshape) { diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 871571ea64881..be52e26a2901f 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -227,6 +227,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& "Generated from Matmul BatchNormalization fusion", {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, matmul_node.MutableOutputDefs(), + matmul_node, nullptr, kOnnxDomain); diff --git a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc index 9d53e28921784..c79e4142a9ee2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc @@ -10,6 +10,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/optimizer/utils.h" namespace onnxruntime { @@ -53,6 +54,7 @@ Status DuplicateDQForOutputEdge(const graph_utils::GraphEdge& original_dq_output MakeString("Added by ", kTransformerName), dq_inputs, {&new_dq_output_nodearg}, + original_dq_node, &original_dq_node.GetAttributes(), 
original_dq_node.Domain()); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index b8252bc7a75b4..0d732a71b7ed0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -194,6 +194,8 @@ Status InsertQDQPairs(Graph& graph, gsl::span insertion } } + optimizer_utils::DuplicateNodeAnnotation(*src_node, q_node); + // Add edge from src to Q node. src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg; graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0); @@ -221,6 +223,10 @@ Status InsertQDQPairs(Graph& graph, gsl::span insertion &dq_attrs, // attributes qdq_domain); + if (src_node) { + optimizer_utils::DuplicateNodeAnnotation(*src_node, dq_node); + } + ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); Node* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 5a6eb82c3e6c0..ba3ea09564c17 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -189,14 +189,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_q"), &weight_q_type_proto); Node& weight_q_node = graph.AddNode( graph.GenerateNodeArgName(node.Name() + "_weight_q"), QDQ::QOpName, "Weight Q node", - {node.MutableInputDefs()[1], weight_scale_arg, &weight_zp_arg}, {&weight_q_arg}, nullptr, node.Domain()); + {node.MutableInputDefs()[1], weight_scale_arg, &weight_zp_arg}, {&weight_q_arg}, node, nullptr, node.Domain()); // DQ from 
int8 to float32. NodeArg& weight_dq_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), weight_arg->TypeAsProto()); Node& weight_dq_node = graph.AddNode(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), QDQ::DQOpName, "Weight DQ node", - {&weight_q_arg, weight_scale_arg, &weight_zp_arg}, {&weight_dq_arg}, nullptr, node.Domain()); + {&weight_q_arg, weight_scale_arg, &weight_zp_arg}, {&weight_dq_arg}, node, nullptr, node.Domain()); graph.AddEdge(weight_q_node.Index(), weight_dq_node.Index(), 0, 0); node.MutableInputDefs()[1] = &weight_dq_arg; graph.AddEdge(weight_dq_node.Index(), node.Index(), 0, 1); @@ -211,14 +211,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph weight_scale_arg->TypeAsProto()); Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Bias scale node", - {dq_0.MutableInputDefs()[1], weight_scale_arg}, {&bias_scale_arg}, nullptr, node.Domain()); + {dq_0.MutableInputDefs()[1], weight_scale_arg}, {&bias_scale_arg}, node, nullptr, node.Domain()); // fp_bias / scale. NodeArg& bias_div_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), bias_arg->TypeAsProto()); Node& div_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node", - {node.MutableInputDefs()[2], &bias_scale_arg}, {&bias_div_arg}, nullptr, node.Domain()); + {node.MutableInputDefs()[2], &bias_scale_arg}, {&bias_div_arg}, node, nullptr, node.Domain()); graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1); // Round(fp_bias / scale). 
@@ -226,7 +226,7 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), bias_arg->TypeAsProto()); Node& round_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node", - {&bias_div_arg}, {&bias_div_round_arg}, nullptr, node.Domain()); + {&bias_div_arg}, {&bias_div_round_arg}, node, nullptr, node.Domain()); graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0); // Cast(Round(fp_bias / scale)) to int32. @@ -236,7 +236,7 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph NodeArg& bias_int32_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &bias_int32_type_proto); Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias INT32 node", - {&bias_div_round_arg}, {&bias_int32_arg}, nullptr, node.Domain()); + {&bias_div_round_arg}, {&bias_int32_arg}, node, nullptr, node.Domain()); cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32)); graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0); @@ -245,7 +245,7 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), bias_arg->TypeAsProto()); Node& bias_dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node", - {&bias_int32_arg, &bias_scale_arg}, {&bias_dq_arg}, nullptr, node.Domain()); + {&bias_int32_arg, &bias_scale_arg}, {&bias_dq_arg}, node, nullptr, node.Domain()); if (!is_per_tensor_scale) { bias_dq_node.AddAttribute("axis", static_cast<int64_t>(0)); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc index 9bd91e7916ecb..94fc7f6c03fa1 100644 ---
a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc @@ -134,6 +134,7 @@ Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, con "DeQuantizeLinear from WhereDummyDq GraphTransformer", {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg}, {&dummy_dq_arg}, + node, nullptr, dq_node->Domain()); diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 6a2b4295093d8..167952356ff58 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -495,7 +495,8 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { NodeArg* shape_arg = &graph_utils::AddInitializerWithOrtValue(graph, shape_initializer_proto); Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg}, - {contiguous_reshapes.back().get().MutableOutputDefs()[0]}); + {contiguous_reshapes.back().get().MutableOutputDefs()[0]}, + reshape); reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType()); graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node); diff --git a/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc index f72f74e3b4a5c..8caea2c150990 100644 --- a/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc +++ b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc @@ -492,6 +492,7 @@ bool FuseSliceConcatToSpaceToDepth(Node& concat, Graph& graph, const logging::Lo : "Fused Slice*4 + Concat into SpaceToDepth + channel permutation", {space_to_depth_input}, space_to_depth_outputs, + concat, nullptr, kOnnxDomain); space_to_depth.AddAttribute("blocksize", kBlockSize); @@ -517,6 +518,7 @@ bool 
FuseSliceConcatToSpaceToDepth(Node& concat, Graph& graph, const logging::Lo "Reorder SpaceToDepth channels to preserve Slice+Concat block order", {space_to_depth.MutableOutputDefs()[0], gather_indices_arg}, {}, + concat, nullptr, kOnnxDomain); gather.AddAttribute("axis", static_cast<int64_t>(kChannelAxis)); diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index 60ab064465f2f..c84e60e64bd2d 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -58,27 +58,43 @@ NodeArg* AddShapeInitializer(Graph& graph, const char* name, const int64_t (&sha std::pair<Node*, NodeArg*> AddNode(Graph& graph, const char* op_type, ProviderType execution_provider_type, - gsl::span<NodeArg* const> inputs) { + gsl::span<NodeArg* const> inputs, + const Node* annotation_source = nullptr) { auto def_name = graph.GenerateNodeArgName(op_type); auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); - Node& node = graph.AddNode(graph.GenerateNodeName(op_type), - op_type, - "", - inputs, - {node_arg}); + Node& node = annotation_source + ? graph.AddNode(graph.GenerateNodeName(op_type), + op_type, + "", + inputs, + {node_arg}, + *annotation_source) + : graph.AddNode(graph.GenerateNodeName(op_type), + op_type, + "", + inputs, + {node_arg}); node.SetExecutionProviderType(execution_provider_type); return std::make_pair(&node, node_arg); } std::pair<Node*, NodeArg*> AddNodeCast(Graph& graph, NodeArg* in, - ONNX_NAMESPACE::TensorProto_DataType data_type) { + ONNX_NAMESPACE::TensorProto_DataType data_type, + const Node* annotation_source = nullptr) { auto def_name = graph.GenerateNodeArgName("Cast"); auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); - Node& node = graph.AddNode(graph.GenerateNodeName("Cast"), - "Cast", - "", - {in}, - {node_arg}); + Node& node = annotation_source + ?
graph.AddNode(graph.GenerateNodeName("Cast"), + "Cast", + "", + {in}, + {node_arg}, + *annotation_source) + : graph.AddNode(graph.GenerateNodeName("Cast"), + "Cast", + "", + {in}, + {node_arg}); node.AddAttribute("to", static_cast<int64_t>(data_type)); node.SetExecutionProviderType(kCpuExecutionProvider); return std::make_pair(&node, node_arg); } @@ -238,7 +254,7 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve Node* reshape_signal_node = nullptr; NodeArg* reshape_output = nullptr; std::tie(reshape_signal_node, reshape_output) = - AddNode(graph, "Reshape", stft.GetExecutionProviderType(), signal_reshaped_inputs); + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), signal_reshaped_inputs, &stft); NodeArg* real_weights_final = real_weights; NodeArg* imag_weights_final = imaginary_weights; @@ -246,11 +262,11 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve // When we are missing a window function if (real_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { std::tie(std::ignore, real_weights_final) = - AddNodeCast(graph, real_weights_final, data_type); + AddNodeCast(graph, real_weights_final, data_type, &stft); } if (imag_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { std::tie(std::ignore, imag_weights_final) = - AddNodeCast(graph, imag_weights_final, data_type); + AddNodeCast(graph, imag_weights_final, data_type, &stft); } } else { // When we have a window function @@ -261,7 +277,7 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve if (window->TypeAsProto()->tensor_type().elem_type() != GetDataType()) { Node* window_cast_node = nullptr; std::tie(window_cast_node, window_final) = - AddNodeCast(graph, window, GetDataType()); + AddNodeCast(graph, window, GetDataType(), &stft); window_recipient = window_cast_node; } @@ -269,7 +285,7 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve Node*
window_reshape_node; NodeArg* window_reshaped = nullptr; std::tie(window_reshape_node, window_reshaped) = - AddNode(graph, "Reshape", kCpuExecutionProvider, window_reshaped_inputs); + AddNode(graph, "Reshape", kCpuExecutionProvider, window_reshaped_inputs, &stft); if (!window_recipient) { window_recipient = window_reshape_node; } @@ -277,17 +293,17 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve NodeArg* scale_real_weights_inputs[] = {real_weights, window_reshaped}; NodeArg* windowed_real_weights_output = nullptr; std::tie(std::ignore, windowed_real_weights_output) = - AddNode(graph, "Mul", kCpuExecutionProvider, scale_real_weights_inputs); + AddNode(graph, "Mul", kCpuExecutionProvider, scale_real_weights_inputs, &stft); NodeArg* scale_imag_weights_inputs[] = {imaginary_weights, window_reshaped}; NodeArg* windowed_imag_weights_output = nullptr; std::tie(std::ignore, windowed_imag_weights_output) = - AddNode(graph, "Mul", kCpuExecutionProvider, scale_imag_weights_inputs); + AddNode(graph, "Mul", kCpuExecutionProvider, scale_imag_weights_inputs, &stft); std::tie(std::ignore, real_weights_final) = - AddNodeCast(graph, windowed_real_weights_output, data_type); + AddNodeCast(graph, windowed_real_weights_output, data_type, &stft); std::tie(std::ignore, imag_weights_final) = - AddNodeCast(graph, windowed_imag_weights_output, data_type); + AddNodeCast(graph, windowed_imag_weights_output, data_type, &stft); } // Add Convolution (reals) @@ -295,7 +311,7 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve Node* real_conv_node = nullptr; NodeArg* real_conv_output = nullptr; std::tie(real_conv_node, real_conv_output) = - AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_real_inputs); + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_real_inputs, &stft); real_conv_node->AddAttribute("strides", std::vector<int64_t>{1, frame_step_value}); // Add Convolution (imaginary) @@ -303,7 +319,7 @@ Status
STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve Node* imag_conv_node = nullptr; NodeArg* imag_conv_output = nullptr; std::tie(imag_conv_node, imag_conv_output) = - AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_imag_inputs); + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_imag_inputs, &stft); imag_conv_node->AddAttribute("strides", std::vector<int64_t>{1, frame_step_value}); // Concatenate @@ -311,21 +327,21 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve Node* concat_node = nullptr; NodeArg* concatenated_conv_output = nullptr; std::tie(concat_node, concatenated_conv_output) = - AddNode(graph, "Concat", stft.GetExecutionProviderType(), concatenate_inputs); + AddNode(graph, "Concat", stft.GetExecutionProviderType(), concatenate_inputs, &stft); concat_node->AddAttribute("axis", static_cast<int64_t>(0)); // Unsqueeze Reshape NodeArg* unsqueeze_reshape_inputs[] = {concatenated_conv_output, unsqueezed_shape}; NodeArg* unsqueezed_output = nullptr; std::tie(std::ignore, unsqueezed_output) = - AddNode(graph, "Reshape", stft.GetExecutionProviderType(), unsqueeze_reshape_inputs); + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), unsqueeze_reshape_inputs, &stft); // Transpose NodeArg* transpose_inputs[] = {unsqueezed_output}; Node* transpose_node = nullptr; NodeArg* transpose_output = nullptr; std::tie(transpose_node, transpose_output) = - AddNode(graph, "Transpose", stft.GetExecutionProviderType(), transpose_inputs); + AddNode(graph, "Transpose", stft.GetExecutionProviderType(), transpose_inputs, &stft); transpose_node->AddAttribute("perm", std::vector<int64_t>{1, 3, 2, 0}); signal_recipient = reshape_signal_node; diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 29b603da56e29..467d0c090070f 100755 ---
a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -531,6 +531,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { // Add Q auto new_q_node = MakeQuantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size"), dq_node.GetAttributeInt("output_dtype"), dq_node.GetAttributeInt("saturate")); + new_q_node->SetLayeringAnnotation(dq_node.GetLayeringAnnotation()); auto q_node_outputs = new_q_node->Outputs(); // copy value info from the dq input for the type information, and update the shape to match next_node's output @@ -543,6 +544,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { // Add DQ auto new_dq_node = MakeDequantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size")); + new_dq_node->SetLayeringAnnotation(dq_node.GetLayeringAnnotation()); auto dq_node_outputs = new_dq_node->Outputs(); // straight copy of value info as the type and shape are the same as next_node's output @@ -1007,6 +1009,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // (see Case 2). if (consumers->nodes.size() > 0) { auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); + squeeze_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; ctx.graph.CopyValueInfo(value_to_modify, sq_out); @@ -1075,6 +1078,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // Case 3: Add an Unsqueeze node. 
auto unsqueeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Unsqueeze", input, axes); + unsqueeze_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& unsqueeze = *unsqueeze_ptr; std::string_view unsq_out = unsqueeze.Outputs()[0]; ctx.graph.CopyValueInfo(input, unsq_out); @@ -1207,6 +1211,7 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. auto transpose_inv_ptr = MakeTranspose(graph, constant_to_modify, perm_inv); + transpose_inv_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; graph.CopyValueInfo(constant_to_modify, transpose_out); @@ -1267,6 +1272,7 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // the other Transpose. 
const std::vector<int64_t>& perm_combined = ComposePerm(*perm2, perm); auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); + transpose_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& transpose = *transpose_ptr; std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); @@ -1301,6 +1307,7 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // Case 4: Add a new Transpose op auto transpose_ptr = MakeTranspose(graph, input, perm); + transpose_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& transpose = *transpose_ptr; std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); @@ -1376,6 +1383,7 @@ std::string_view TransposeOutput(api::GraphRef& graph, api::NodeRef& node, size_ // X -> Node -> Y, Transpose auto transpose = MakeTranspose(graph, "", perm); + transpose->SetLayeringAnnotation(node.GetLayeringAnnotation()); // X -> Node -> *Y', Transpose -> Y *shape/dtype not set graph.MoveOutput(node, i, *transpose, 0); @@ -1730,6 +1738,7 @@ static bool HandleShape(HandlerArgs& args) { // X -> Shape -> Y, Gather std::vector<std::string_view> gather_inputs{"", perm_const}; auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); + gather_ptr->SetLayeringAnnotation(args.node.GetLayeringAnnotation()); api::NodeRef& gather = *gather_ptr; gather.SetAttributeInt("axis", 0); @@ -1773,6 +1782,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); std::vector<std::string_view> gather_inputs{input_name, gather_indices_const}; auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); + gather_ptr->SetLayeringAnnotation(node.GetLayeringAnnotation()); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0];
graph.CopyValueInfo(input_name, gather_output); @@ -2221,6 +2231,7 @@ static bool HandleTile(HandlerArgs& args) { std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); std::vector<std::string_view> gather_inputs{repeats_inp, perm_inv_const}; auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); + gather_node_ptr->SetLayeringAnnotation(args.node.GetLayeringAnnotation()); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; args.ctx.graph.CopyValueInfo(repeats_inp, gather_output); @@ -2271,6 +2282,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) { // despite computing the same value. Use an Identity op instead. std::vector<std::string_view> single_empty_input{""}; auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1); + identity_ptr->SetLayeringAnnotation(args.node.GetLayeringAnnotation()); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); @@ -2303,6 +2315,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector<int64_t>& n // use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node, // and remove the Reshape node.
new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1); + new_node->SetLayeringAnnotation(args.node.GetLayeringAnnotation()); args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0); args.ctx.graph.RemoveNode(args.node); } else { @@ -2973,6 +2986,7 @@ static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_ // Add Q auto new_q_node = MakeQuantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size"), q_node.GetAttributeInt("output_dtype"), q_node.GetAttributeInt("saturate")); + new_q_node->SetLayeringAnnotation(transpose_node.GetLayeringAnnotation()); auto new_q_node_output = new_q_node->Outputs()[0]; // Copy value info from the q output for the type information, and update the shape to match Transpose's input @@ -2985,6 +2999,7 @@ static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_ // Add new DQ. auto new_dq_node = MakeDequantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size")); + new_dq_node->SetLayeringAnnotation(transpose_node.GetLayeringAnnotation()); auto new_dq_node_output = new_dq_node->Outputs()[0]; ctx.graph.CopyValueInfo(transpose_input_name, new_dq_node_output); diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 6ff4da05fbf57..4ee5a65b9b9fb 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -258,6 +258,18 @@ class NodeRef { /// <returns>Id</returns> virtual int64_t Id() const = 0; + /// <summary> + /// Get the layering annotation of the node. + /// </summary> + /// <returns>annotation</returns> + virtual std::string_view GetLayeringAnnotation() const = 0; + + /// <summary> + /// Set layering annotation + /// </summary> + /// <param name="annotation"></param> + virtual void SetLayeringAnnotation(std::string_view annotation) = 0; + virtual ~NodeRef() {}; }; diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 6a02ca3578da2..5d5ed663cca05 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -105,6 +105,14 @@ class ApiNode final : public api::NodeRef { int SinceVersion() const override; int64_t Id() const override; + std::string_view GetLayeringAnnotation() const override { + return node_.GetLayeringAnnotation(); + } + + void SetLayeringAnnotation(std::string_view annotation) override { + node_.SetLayeringAnnotation(std::string(annotation)); + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); }; @@ -763,6 +771,9 @@ std::unique_ptr<api::NodeRef> ApiGraph::CopyNode(const api::NodeRef& source_node source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); + const auto& layering_annotation = source_node.GetLayeringAnnotation(); + node.SetLayeringAnnotation(std::string(layering_annotation)); + std::unique_ptr<api::NodeRef> new_node = std::make_unique<ApiNode>(node, graph_); new_node->CopyAttributes(source_node); diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 4a323eefe1fe7..6d40b389d5fa3 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -495,6 +495,13 @@ bool IsScalar(const NodeArg& input_arg) { return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); } +void DuplicateNodeAnnotation(const Node& src, Node& dst) { + const auto& src_annotation = src.GetLayeringAnnotation(); + if (!src_annotation.empty()) {
+ dst.SetLayeringAnnotation(src_annotation); + } +} + template <typename T> bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value, bool is_constant) { diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 857640f861238..2f9b48df7a75f 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -175,6 +175,8 @@ bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_outp // Check if NodeArg takes in a scalar tensor. bool IsScalar(const NodeArg& input_arg); +void DuplicateNodeAnnotation(const Node& src, Node& dst); + #endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace optimizer_utils diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4c735fa2d5650..c6354b1c533cd 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -3106,16 +3106,20 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } auto threshold = resource_accountant->GetThreshold(); - if (!threshold.has_value()) { + if (!threshold) { // info_.gpu_mem_limit is for BFC arena size_t free_memory, total_memory; if (0 != cudaMemGetInfo(&free_memory, &total_memory)) { memory_threshold = info_.gpu_mem_limit; + LOGS(logger, INFO) + << "CUDA_EP failed to get available GPU memory info.
Using info_.gpu_mem_limit instead: " << info_.gpu_mem_limit; } else { memory_threshold = std::min(free_memory, info_.gpu_mem_limit); + LOGS(logger, VERBOSE) + << "CUDA_EP Using threshold: " << memory_threshold << " Free memory reported: " << free_memory; } } else { - memory_threshold = std::get<0>(threshold.value()); + memory_threshold = std::get<0>(*threshold); } consumed_memory = std::get<0>(resource_accountant->GetConsumedAmount()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b873c95b496bb..2ba52a3e989bd 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -29,6 +29,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/kernel_type_str_resolver.h" #include "core/framework/kernel_type_str_resolver_utils.h" +#include "core/framework/layering_annotations.h" #include "core/framework/mldata_type_utils.h" #include "core/framework/TensorSeq.h" #include "core/framework/tensorprotoutils.h" @@ -1518,11 +1519,33 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool } } + LayeringIndex* layering_index = nullptr; +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + std::optional<LayeringIndex> layering_index_storage; + const auto layering_config = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsLayerAssignmentSettings, ""); + if (!layering_config.empty()) { + ORT_RETURN_IF_ERROR_SESSIONID_(LayeringIndex::Create(graph, layering_config, {}, execution_providers_, + *session_logger_, layering_index_storage)); + if (layering_index_storage) { + layering_index = &layering_index_storage.value(); + } + } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Do partitioning based on execution providers' capabilities.
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, - session_options_.config_options, *session_logger_, + session_options_.config_options, *session_logger_, layering_index, mode, session_options_.GetEpContextGenerationOptions(), debug_graph_fn)); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + if (layering_index) { + // Layering annotations may be present even if the index is not built, although unlikely. + ORT_RETURN_IF_ERROR_SESSIONID_(graph.RemoveAllLayeringAnnotations()); + // We are currently not using it beyond this point. Clear it to free up memory. + layering_index = nullptr; + layering_index_storage.reset(); + } +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // Get graph optimizations loop level from session config, if not present, set to default value of 1 as per // the definition of kOrtSessionOptionsGraphOptimizationsLoopLevel. unsigned int graph_optimizations_loop_level = static_cast<unsigned int>(std::stoi( @@ -2039,6 +2062,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, transform_layout_fn, sess_options.config_options, logger, + nullptr /*layering_index*/, GraphPartitioner::Mode::kOrtFormatLoad)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 37a74a5de22a6..9834902cea2b1 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3031,7 +3031,23 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, "src_graph is a ModelEditorGraph which doesn't support Graph_GetGraphView."); } const GraphViewer& graph_viewer = ep_graph->GetGraphViewer(); - const Graph& graph = graph_viewer.GetGraph(); + + // Create subgraph's node set and convert them to internal Node + InlinedHashSet<NodeIndex> node_set; + InlinedVector<const Node*> internal_nodes; +
internal_nodes.reserve(num_nodes); + for (size_t i = 0; i < num_nodes; i++) { + const EpNode* ep_node = EpNode::ToInternal(nodes[i]); + if (ep_node != nullptr) { + const Node& node = ep_node->GetInternalNode(); + node_set.insert(node.Index()); + internal_nodes.push_back(&node); + } else { + std::ostringstream oss; + oss << "node indexed [" << i << "] appears to be a ModelEditorNode"; + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + } // Create a GraphViewer with filtered info // TODO: Investigate whether utils::MakeComputeCapability can be extended and reused instead @@ -3040,178 +3056,93 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, // Following data structures help determine the final inputs/outputs of the subgraph. // Note: The 'subgraph' here refers to a graph contains a subset of nodes in the 'src_graph'. - // Subgraph's node set - const std::unordered_set<NodeIndex> node_set = [&]() { - std::unordered_set<NodeIndex> node_set; - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - const EpNode* ep_node = EpNode::ToInternal(ort_node); - if (ep_node != nullptr) { - node_set.insert(ep_node->GetInternalNode().Index()); - } + // Pre-pass: Identify all outputs produced by nodes within the subgraph. + // This allows O(1) checks to determine if an input is internal or from the boundary. + InlinedHashSet<const NodeArg*> internal_outputs; + for (size_t i = 0, lim = internal_nodes.size(); i < lim; i++) { + const auto& node = *internal_nodes[i]; + for (const auto& output : node.OutputDefs()) { + internal_outputs.insert(output); } - - return node_set; - }(); } // Source graph output names - std::unordered_set<std::string> graph_output_names; + InlinedHashSet<std::string> graph_output_names; for (const auto* output_arg : graph_viewer.GetOutputs()) { graph_output_names.insert(output_arg->Name()); } // These maps store the inputs and outputs of the subgraph.
- // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration - // to determine the final inputs and outputs of the subgraph. - std::unordered_map subgraph_inputs, subgraph_outputs; - - // This map stores the node's output that will be consumed by another node outside of this subgraph. - // So the node's output should be put into the subgraph's output list. - std::unordered_map subgraph_outputs_to_add; - - // This map stores the node's output that is original graph's output. - // So the node's output should be put into the subgraph's output list. - std::unordered_map graph_outputs_to_add; + // Value is order index to maintain deterministic order. + InlinedHashMap subgraph_inputs, subgraph_outputs; - std::unordered_set erased; - - // This is the relative ordering that ensures node's input or output being added to the 'subgraph_inputs', - // 'subgraph_outputs', 'subgraph_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. - // Items added earlier receive a smaller order index than items added later. - // When constructing the final subgraph's input or output lists, entries with smaller - // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; - // node arg to its consumer nodes. - // Note: graph.GetConsumerNodes() is not available in minimal build, in order to use unified implementation across - // all builds, this map is needed to determine if node arg is consumed by other nodes. 
- std::unordered_map> node_arg_to_consumer_nodes; - - std::vector initializers; + InlinedVector initializers; - // Add nodes - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - const EpNode* ep_node = EpNode::ToInternal(ort_node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, - "node is a ModelEditorNode which doesn't support Graph_GetGraphView."); - } - const Node& node = ep_node->GetInternalNode(); + // Add nodes and identify boundary inputs/outputs + for (size_t i = 0, lim = internal_nodes.size(); i < lim; i++) { + const auto& node = *internal_nodes[i]; indexed_sub_graph->nodes.push_back(node.Index()); - for (const auto& input : node.InputDefs()) { - if (!input->Exists()) { - continue; - } + // Process Inputs: If an input is not produced internally, it's a subgraph input. + auto process_inputs = [&](gsl::span inputs) { + for (const auto& input : inputs) { + if (!input->Exists()) continue; - if (graph_viewer.IsConstantInitializer(input->Name(), true)) { - initializers.push_back(input->Name()); - continue; - } - const auto& it = subgraph_outputs.find(input); - if (it != subgraph_outputs.end()) { - subgraph_outputs.erase(it); - erased.insert(input); - } else if (erased.find(input) == erased.end()) { - // Only when input is neither in output list nor erased list, add the input to input list - subgraph_inputs.insert({input, input_order++}); - } - } + if (graph_viewer.IsConstantInitializer(input->Name(), true)) { + initializers.push_back(input->Name()); + continue; + } - for (const auto& input : node.ImplicitInputDefs()) { - if (!input->Exists()) { - continue; + // If not produced by this subgraph, it's a boundary input + if (internal_outputs.count(input) == 0) { + // Use insert to keep the first occurrence's order + auto p = subgraph_inputs.emplace(input, input_order); + if (p.second) { + input_order++; + } + } } + }; - if (graph_viewer.IsConstantInitializer(input->Name(), true)) { - 
initializers.push_back(input->Name()); - continue; - } - const auto& it = subgraph_outputs.find(input); - if (it != subgraph_outputs.end()) { - subgraph_outputs.erase(it); - erased.insert(input); - } else if (erased.find(input) == erased.end()) { - // Only when input is neither in output list nor erased list, add the input to input list - subgraph_inputs.insert({input, input_order++}); - } - } + process_inputs(gsl::make_span(node.InputDefs().data(), node.InputDefs().size())); + process_inputs(gsl::make_span(node.ImplicitInputDefs().data(), node.ImplicitInputDefs().size())); - // For output searching, there are two special cases, - // One is, if subgraph's node output is parent graph's output. the node output should - // be also added to the subgraph's output list - // The other one is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, - // if the output is connected to nodes that don't belong to the subgraph, the output need to be added - // to the output list + // Process Outputs: If an output is graph output OR consumed externally, it's a subgraph output. for (const auto& output : node.OutputDefs()) { - if (!output->Exists()) { - continue; - } + if (!output->Exists()) continue; + + bool is_boundary_output = false; - const auto& it = subgraph_inputs.find(output); - if (it != subgraph_inputs.end()) { - subgraph_inputs.erase(it); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - auto has_consumer_nodes = [&](const std::string& node_arg_str) -> bool { - // Same implementation as Graph::PopulateNodeArgToProducerConsumerLookupsFromNodes() - if (node_arg_to_consumer_nodes.empty()) { - for (const auto& node : graph.Nodes()) { - node.ForEachDef([&](const NodeArg& node_arg, bool is_input) { - if (is_input) { - node_arg_to_consumer_nodes[node_arg.Name()].insert(node.Index()); - } - }); + // 1. Is it a graph output? 
+ if (graph_output_names.count(output->Name()) > 0) { + is_boundary_output = true; + } else { + // 2. Is it consumed by any node outside the subgraph? + for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + // Check if the edge uses this specific output + if (it->GetSrcArgIndex() < static_cast(node.OutputDefs().size()) && + node.OutputDefs()[it->GetSrcArgIndex()] == output) { + if (node_set.count(it->GetNode().Index()) == 0) { + is_boundary_output = true; + break; } } - return node_arg_to_consumer_nodes.find(node_arg_str) != node_arg_to_consumer_nodes.end(); - }; - - if (has_consumer_nodes(output->Name())) { - // Only when output is neither in input list nor erased list, - // and the output is consumed by another node, add the output to output list - subgraph_outputs.insert({output, output_order++}); } } - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - // This output is the graph's output. - // So the output should be put into the subgraph's output list. - graph_outputs_to_add.insert({output, output_order++}); - } - } - - if (node.GetOutputEdgesCount() > node.OutputDefs().size()) { - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - const auto& node_idx = it->GetNode().Index(); - - if (node_set.find(node_idx) == node_set.end()) { - // This output will be consumed by another node outside of this subgraph. - // So the output should be put into the subgraph's output list. - const NodeArg* output = nullptr; - - // The dst_arg_index from GetDstArgIndex() could be the index for explicit/implicit input defs of the node. - // We need to get the correct input index accordingly. 
(See Graph::BuildConnections() in graph.cc for more details) - if (it->GetDstArgIndex() < static_cast(it->GetNode().InputDefs().size())) { - output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; - } else { - output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - it->GetNode().InputDefs().size()]; - } - subgraph_outputs_to_add.insert({output, output_order++}); - } + if (is_boundary_output) { + subgraph_outputs.insert({output, output_order++}); } } } - subgraph_outputs.insert(subgraph_outputs_to_add.begin(), subgraph_outputs_to_add.end()); - subgraph_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - std::multimap inputs, outputs; // Get the input order of the original graph - std::unordered_map original_inputs; + InlinedHashMap original_inputs; int order = 0; for (const auto* input : graph_viewer.GetInputs()) { original_inputs[input] = order++; @@ -3219,22 +3150,22 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, // input order needs to be consistent with original graph's input order for (const auto& [node_arg, subgraph_input_order] : subgraph_inputs) { - const auto& original_input_it = original_inputs.find(node_arg); + const auto original_input_it = original_inputs.find(node_arg); if (original_input_it != original_inputs.end()) { - inputs.insert(std::make_pair( + inputs.emplace( original_input_it->second, // input order from original graph - node_arg)); + node_arg); } else { - inputs.insert(std::make_pair( + inputs.emplace( subgraph_input_order, // input order from subgraph - node_arg)); + node_arg); } } // Sort outputs by the order they were added - for (auto it = subgraph_outputs.begin(), end = subgraph_outputs.end(); it != end; ++it) { - outputs.insert(std::pair(it->second, it->first)); + for (const auto& [node_arg, subgraph_output_order] : subgraph_outputs) { + outputs.emplace(subgraph_output_order, node_arg); } std::unique_ptr meta_def = std::make_unique(); @@ -3259,7 +3190,8 @@ 
ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, } indexed_sub_graph->SetMetaDef(std::move(meta_def)); - auto new_graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + const Graph& graph = graph_viewer.GetGraph(); + auto new_graph_viewer = std::make_unique(graph, *indexed_sub_graph); std::unique_ptr result; ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(new_graph_viewer), std::move(indexed_sub_graph), result)); diff --git a/onnxruntime/python/tools/layering/layer_annotate.py b/onnxruntime/python/tools/layering/layer_annotate.py new file mode 100644 index 0000000000000..738c528b28754 --- /dev/null +++ b/onnxruntime/python/tools/layering/layer_annotate.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import logging +import pathlib + +import onnx + + +def get_logger(name, level=logging.DEBUG): + logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s") + logger = logging.getLogger(name) + logger.setLevel(level) + return logger + + +def getargs(): + argparser = argparse.ArgumentParser( + description="Read a config file with a list of node annotations and apply them to an ONNX model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + argparser.add_argument( + "--config_file_path", + type=pathlib.Path, + required=True, + help="Path to the configuration file with node annotations.", + ) + argparser.add_argument( + "--model_path", + type=pathlib.Path, + required=True, + help="Path to a single model to process.", + ) + argparser.add_argument( + "--annotated_model", + type=pathlib.Path, + required=True, + help="Path to write the annotated model to.", + ) + + return argparser.parse_args() + + +def read_annotation_config(config_file_path): + """ + Reads a configuration file to map substrings to annotations. + + The file format is expected to be: + annotation_string: substring1, substring2, ... 
+ + The same annotation string can appear multiple times. + The node names in the configuration are treated as substrings. + + Args: + config_file_path (str or Path): Path to the configuration file. + + Returns: + list: A list of tuples (substring, annotation_string). + """ + substring_annotations = [] + with open(config_file_path) as f: + for unstripped_line in f: + line = unstripped_line.strip() + if not line: + continue + parts = line.split(":", 1) + if len(parts) < 2: + continue + annotation = parts[0].strip() + substrings = parts[1].split(",") + for substr in substrings: + substring = substr.strip() + if substring: + substring_annotations.append((substring, annotation)) + return substring_annotations + + +def process_nodes(nodes, substring_annotations): + """ + Helper function to process a list of nodes sequentially. + """ + logger = get_logger("annotate_model") + logger.info(f"Processing {len(nodes)} nodes.") + + for node in nodes: + matched_annotation = None + for substring, annotation in substring_annotations: + if substring in node.name: + matched_annotation = annotation + + if matched_annotation: + # Check if annotation already exists + entry = None + for prop in node.metadata_props: + if prop.key == "layer_ann": + entry = prop + break + + if entry: + entry.value = matched_annotation + else: + entry = node.metadata_props.add() + entry.key = "layer_ann" + entry.value = matched_annotation + + # Recurse into subgraphs for control flow nodes + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + annotate_graph(attr.g, substring_annotations) + elif attr.type == onnx.AttributeProto.GRAPHS: + for sub_graph in attr.graphs: + annotate_graph(sub_graph, substring_annotations) + + +def annotate_graph(graph, substring_annotations): + """ + Recursively applies annotations to nodes where a configured substring appears in the node name. + + This function iterates over all nodes in the given graph. 
It checks if any + substring from the configuration appears in the node's name. If matched, + it adds or updates a metadata property with key 'layer_ann' containing + the annotation string. If multiple substrings match, the last one defined + in the configuration list applies. + + It also handles control flow nodes (like 'If' or 'Loop') by recursively + processing their subgraphs (attributes of type GRAPH or GRAPHS). + + Args: + graph (onnx.GraphProto): The ONNX graph to process. + substring_annotations (list): A list of tuples (substring, annotation_string). + """ + process_nodes(graph.node, substring_annotations) + + +def annotate_model(model, substring_annotations): + """ + Annotates an ONNX model with metadata based on a provided mapping. + + This function serves as the entry point to annotate the model's graph. + It delegates the work to `annotate_graph`, which recursively processes + all nodes in the main graph and any nested subgraphs. + + Args: + model (onnx.ModelProto): The ONNX model to annotate. + substring_annotations (list): A list of tuples (substring, annotation_string). 
+ """ + annotate_graph(model.graph, substring_annotations) + + +if __name__ == "__main__": + args = getargs() + logger = get_logger("annotate_model") + + # Read the mapping from the configuration file + substring_annotations = read_annotation_config(args.config_file_path) + + logger.info(f"Loading model from {args.model_path}") + onnx_model = onnx.load(args.model_path, load_external_data=False) + + logger.info(f"Applying annotations from {args.config_file_path}") + annotate_model(onnx_model, substring_annotations) + + logger.info(f"Saving annotated model to {args.annotated_model}") + onnx.save_model(onnx_model, args.annotated_model) diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 699d1b1a2c27a..9e28882b9a65d 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -662,5 +662,161 @@ TEST(FunctionTest, Test_GH_issue_16438) { status = session_object.Initialize(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); } + +// Verify that when a function node with a layering annotation is inlined, +// the inlined nodes inherit the parent function node's annotation. +TEST(FunctionTest, InlinedNodesInheritLayeringAnnotation) { + // Parse and build a Model with a local function (multi-node body: Constant + Mul). + ONNX_NAMESPACE::OnnxParser parser(basic_code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + auto& logger = DefaultLoggingManager().DefaultLogger(); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, logger)); + + Graph& graph = model->MainGraph(); + ASSERT_STATUS_OK(graph.Resolve()); + + // Find the function call node (local.myfun) and annotate it. 
+ Node* func_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "myfun") { + func_node = &node; + break; + } + } + ASSERT_NE(func_node, nullptr) << "Could not find function call node 'myfun'"; + ASSERT_TRUE(func_node->CanBeInlined()); + + const std::string annotation = "TestLayerAnnotation"; + func_node->SetLayeringAnnotation(annotation); + + // Inline the function node. + ASSERT_STATUS_OK(graph.InlineFunction(*func_node)); + ASSERT_STATUS_OK(graph.Resolve()); + + // After inlining, the original function call node is removed and replaced + // by the function body nodes (a Mul node; the Constant becomes an initializer). + // Verify every remaining node inherited the annotation. + int node_count = 0; + for (const auto& node : graph.Nodes()) { + ++node_count; + EXPECT_EQ(node.GetLayeringAnnotation(), annotation) + << "Node '" << node.Name() << "' (op: " << node.OpType() + << ") did not inherit the parent function's layering annotation."; + } + EXPECT_GT(node_count, 0) << "Expected at least one inlined node in the graph."; +} + +// Verify that when a function node with no layering annotation is inlined, +// the inlined nodes remain unannotated. 
+TEST(FunctionTest, InlinedNodesNoAnnotationWhenParentUnannotated) { + ONNX_NAMESPACE::OnnxParser parser(basic_code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + auto& logger = DefaultLoggingManager().DefaultLogger(); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, logger)); + + Graph& graph = model->MainGraph(); + ASSERT_STATUS_OK(graph.Resolve()); + + Node* func_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "myfun") { + func_node = &node; + break; + } + } + ASSERT_NE(func_node, nullptr); + // Do NOT set any annotation on the function node. + ASSERT_TRUE(func_node->GetLayeringAnnotation().empty()); + + ASSERT_STATUS_OK(graph.InlineFunction(*func_node)); + ASSERT_STATUS_OK(graph.Resolve()); + + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetLayeringAnnotation().empty()) + << "Node '" << node.Name() << "' should not have a layering annotation " + << "when the parent function node was unannotated."; + } +} + +// Verify annotation inheritance with two calls to the same function, +// where each call has a different annotation. 
+TEST(FunctionTest, InlinedNodesInheritDistinctAnnotationsPerCallSite) { + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[N] x) => (float[N] y) + { + y1 = local.myfun (x) + y = local.myfun (y1) + } + + < + opset_import: [ "" : 16 ], + domain: "local" + > + myfun (lx) => (ly) { + two = Constant () + ly = Mul (lx, two) + } + )"; + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()); + + auto& logger = DefaultLoggingManager().DefaultLogger(); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, logger)); + + Graph& graph = model->MainGraph(); + ASSERT_STATUS_OK(graph.Resolve()); + + // Collect the two function call nodes in graph order. + std::vector func_nodes; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "myfun") { + func_nodes.push_back(&node); + } + } + ASSERT_EQ(func_nodes.size(), 2u); + + // Annotate each call site differently. + func_nodes[0]->SetLayeringAnnotation("AnnotationA"); + func_nodes[1]->SetLayeringAnnotation("AnnotationB"); + + // Inline the first call, then the second. + ASSERT_STATUS_OK(graph.InlineFunction(*func_nodes[0])); + ASSERT_STATUS_OK(graph.InlineFunction(*func_nodes[1])); + ASSERT_STATUS_OK(graph.Resolve()); + + // After inlining both calls, the graph should have nodes from both expansions. + // Each group should carry its respective annotation. 
+ bool found_a = false; + bool found_b = false; + for (const auto& node : graph.Nodes()) { + const auto& ann = node.GetLayeringAnnotation(); + EXPECT_TRUE(ann == "AnnotationA" || ann == "AnnotationB") + << "Node '" << node.Name() << "' has unexpected annotation: '" << ann << "'"; + if (ann == "AnnotationA") found_a = true; + if (ann == "AnnotationB") found_b = true; + } + EXPECT_TRUE(found_a) << "No node found with AnnotationA"; + EXPECT_TRUE(found_b) << "No node found with AnnotationB"; +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/layering_annotations_test.cc b/onnxruntime/test/framework/layering_annotations_test.cc new file mode 100644 index 0000000000000..f865be7bfc686 --- /dev/null +++ b/onnxruntime/test/framework/layering_annotations_test.cc @@ -0,0 +1,1763 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "core/framework/execution_providers.h" +#include "core/framework/ortmemoryinfo.h" +#include "core/framework/layering_annotations.h" +#include "core/session/abi_devices.h" +#include "core/framework/execution_provider.h" +#include "core/framework/ortdevice.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" // For Model, Graph +#include "gtest/gtest.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/test_environment.h" + +namespace onnxruntime { +namespace test { + +TEST(LayeringRuleMatcherTest, ExactMatches) { + LayeringRules rules; + rules.rules.push_back({"Device1", "Annotation1", false}); // Index 0 + rules.rules.push_back({"Device2", "Annotation2", false}); // Index 1 + + LayeringRuleMatcher matcher(rules); + + { + auto result = matcher.Match("Annotation1"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + { + auto result = matcher.Match("Annotation2"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); + } 
+ { + auto result = matcher.Match("Annotation3"); + EXPECT_FALSE(result.has_value()); + } +} + +TEST(LayeringRuleMatcherTest, PrefixMatches) { + LayeringRules rules; + rules.rules.push_back({"Device1", "Prefix1", true}); // Index 0: =Prefix1 + rules.rules.push_back({"Device2", "Pre", true}); // Index 1: =Pre + + LayeringRuleMatcher matcher(rules); + + // "Prefix1Suffix" matches "Prefix1" (idx 0) and "Pre" (idx 1). 0 < 1, so 0. + { + auto result = matcher.Match("Prefix1Suffix"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + + // "PreSuffix" matches "Pre" (idx 1). "Prefix1" does not match. + { + auto result = matcher.Match("PreSuffix"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); + } + + // "Other" matches nothing + { + auto result = matcher.Match("Other"); + EXPECT_FALSE(result.has_value()); + } +} + +TEST(LayeringRuleMatcherTest, PriorityPrefixOverExact) { + // Prefix matches should take precedence over exact matches regardless of order. + + // Case 1: Prefix rule comes before Exact rule + { + LayeringRules rules; + rules.rules.push_back({"Device1", "A", true}); // Index 0: =A (Prefix) + rules.rules.push_back({"Device2", "AB", false}); // Index 1: AB (Exact) + + LayeringRuleMatcher matcher(rules); + // "AB" matches prefix "A" (idx 0) and exact "AB" (idx 1). + // Since prefix matches are checked first and returned if found, we expect 0. + auto result = matcher.Match("AB"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + + // Case 2: Exact rule comes before Prefix rule + { + LayeringRules rules; + rules.rules.push_back({"Device1", "AB", false}); // Index 0: AB (Exact) + rules.rules.push_back({"Device2", "A", true}); // Index 1: =A (Prefix) + + LayeringRuleMatcher matcher(rules); + // "AB" matches exact "AB" (idx 0) and prefix "A" (idx 1). + // Priority says Prefix matches are returned first. 
+ auto result = matcher.Match("AB"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); + } +} + +TEST(LayeringRuleMatcherTest, LongestOrShortestPrefixPriority) { + // If multiple prefix rules match, the one with the lowest index (earliest in config) wins. + + // Case 1: Shorter prefix first + { + LayeringRules rules; + rules.rules.push_back({"Device1", "A", true}); // Index 0 + rules.rules.push_back({"Device2", "AB", true}); // Index 1 + + LayeringRuleMatcher matcher(rules); + // "ABC" matches "A" (0) and "AB" (1). Since 0 < 1, best match is 0. + auto result = matcher.Match("ABC"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + + // Case 2: Longer prefix first + { + LayeringRules rules; + rules.rules.push_back({"Device1", "AB", true}); // Index 0 + rules.rules.push_back({"Device2", "A", true}); // Index 1 + + LayeringRuleMatcher matcher(rules); + // "ABC" matches "AB" (0) and "A" (1). Since 0 < 1, best match is 0. + auto result = matcher.Match("ABC"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } +} + +TEST(LayeringRuleMatcherTest, OverlappingExactMatchPriority) { + // If duplicates exist, first one wins. + LayeringRules rules; + rules.rules.push_back({"Device1", "A", false}); // Index 0 + rules.rules.push_back({"Device2", "A", false}); // Index 1 + + LayeringRuleMatcher matcher(rules); + auto result = matcher.Match("A"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); +} + +TEST(LayeringRuleMatcherTest, OverlappingPrefixMatchPriority) { + // If duplicates exist, first one wins. 
+ LayeringRules rules; + rules.rules.push_back({"Device1", "A", true}); // Index 0 + rules.rules.push_back({"Device2", "A", true}); // Index 1 + + LayeringRuleMatcher matcher(rules); + auto result = matcher.Match("AB"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); +} + +namespace { + +// Helper to construct OrtEpDevice wrappers for testing +struct TestEpDevice { + std::string ep_name; + OrtHardwareDevice hw_device; + bool has_hw_device = false; + OrtMemoryInfo mem_info; + bool has_mem_info = false; + + // We need to keep the structures alive while OrtEpDevice points to them + OrtEpDevice Get() const { + OrtEpDevice ep; + ep.ep_name = ep_name; + ep.device = has_hw_device ? &hw_device : nullptr; + ep.device_memory_info = has_mem_info ? &mem_info : nullptr; + return ep; + } +}; + +TestEpDevice CreateEp(const std::string& name) { + TestEpDevice ep; + ep.ep_name = name; + return ep; +} + +TestEpDevice CreateHwEp(const std::string& name, OrtHardwareDeviceType type, uint32_t vendor_id = 0, + uint32_t device_id = 0, const std::string& vendor_str = std::string()) { + TestEpDevice ep; + ep.ep_name = name; + ep.hw_device = {type, vendor_id, device_id, vendor_str, {}}; + ep.has_hw_device = true; + return ep; +} + +TestEpDevice CreateMemEp(const std::string& name, OrtDevice::DeviceType type, int device_id = 0) { + TestEpDevice ep; + ep.ep_name = name; + // Note: OrtMemoryInfo name doesn't matter for logic now, but required for ctor + ep.mem_info = OrtMemoryInfo("TestMem", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(type, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, + static_cast(device_id)), + OrtMemType::OrtMemTypeDefault); + ep.has_mem_info = true; + return ep; +} + +} // namespace + +TEST(EpLayeringMatcherTest, MatchCPU) { + LayerAnnotation rule = {"CPU", "Anno1", false}; + + // Case 1: EP Name kCpuExecutionProvider + { + auto test_ep = CreateEp(kCpuExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = 
{&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCpuExecutionProvider); + } + + // Case 2: Hardware Device CPU + { + auto test_ep = CreateHwEp("SomeCPU_EP", OrtHardwareDeviceType_CPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "SomeCPU_EP"); + } + + // Case 3: Memory Info CPU + { + auto test_ep = CreateMemEp("MemCPU_EP", OrtDevice::CPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MemCPU_EP"); + } +} + +TEST(EpLayeringMatcherTest, MatchGPU) { + LayerAnnotation rule = {"GPU", "Anno1", false}; + + // Case 1: Hardware Device GPU + { + auto test_ep = CreateHwEp("MyGPU_EP", OrtHardwareDeviceType_GPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyGPU_EP"); + } + + // Case 2: Memory Info GPU + { + auto test_ep = CreateMemEp("MemGPU_EP", OrtDevice::GPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MemGPU_EP"); + } + + // Case 3: Heuristic kCudaExecutionProvider + { + auto test_ep = CreateEp(kCudaExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCudaExecutionProvider); + } + + // Case 4: Heuristic kDmlExecutionProvider + { + auto test_ep = CreateEp(kDmlExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; 
+ auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kDmlExecutionProvider); + } +} + +TEST(EpLayeringMatcherTest, MatchSpecificGPU_VendorString) { + LayerAnnotation rule = {"gpu:nvidia", "Anno1", false}; + + // Case 1: Vendor String Match + { + auto test_ep = CreateHwEp("MyNvidia_EP", OrtHardwareDeviceType_GPU, 0, 0, "NVIDIA"); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyNvidia_EP"); + } + + // Case 2: Vendor String Mismatch + { + auto test_ep = CreateHwEp("MyAMD_EP", OrtHardwareDeviceType_GPU, 0, 0, "AMD"); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + EXPECT_FALSE(result.has_value()); + } +} + +TEST(EpLayeringMatcherTest, MatchSpecificGPU_VendorId) { + LayerAnnotation rule_intel = {"gpu:intel", "Anno1", false}; + LayerAnnotation rule_nvidia = {"gpu:nvidia", "Anno2", false}; + LayerAnnotation rule_amd = {"gpu:amd", "Anno3", false}; + + // Case 1: Vendor ID Match Intel + { + auto test_ep = CreateHwEp("Intel_EP", OrtHardwareDeviceType_GPU, OrtDevice::VendorIds::INTEL); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule_intel); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "Intel_EP"); + } + + // Case 2: Vendor ID Match Nvidia + { + auto test_ep = CreateHwEp("Nvidia_EP", OrtHardwareDeviceType_GPU, OrtDevice::VendorIds::NVIDIA); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule_nvidia); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "Nvidia_EP"); + } + + // Case 3: Vendor ID Match AMD + { + auto test_ep = CreateHwEp("AMD_EP", OrtHardwareDeviceType_GPU, OrtDevice::VendorIds::AMD); + 
OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule_amd); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "AMD_EP"); + } +} + +TEST(EpLayeringMatcherTest, MatchSpecificGPU_Heuristic) { + LayerAnnotation rule = {"gpu:nvidia", "Anno1", false}; + + // Case 1: kCudaExecutionProvider -> nvidia + { + // Need an EP with GPU HW type but generic vendor info to trigger the heuristic + auto test_ep_hw = CreateHwEp(kCudaExecutionProvider, OrtHardwareDeviceType_GPU); + OrtEpDevice ep_device = test_ep_hw.Get(); + std::vector devices = {&ep_device}; + + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCudaExecutionProvider); + } +} + +TEST(EpLayeringMatcherTest, MatchSpecificGPU_Index) { + LayerAnnotation rule = {"gpu:1", "Anno1", false}; + + // Case 1: ID Match + { + auto test_ep = CreateHwEp("GPU1", OrtHardwareDeviceType_GPU, 0, 1); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "GPU1"); + } + + // Case 2: ID Mismatch + { + auto test_ep = CreateHwEp("GPU0", OrtHardwareDeviceType_GPU, 0, 0); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + EXPECT_FALSE(result.has_value()); + } +} + +TEST(EpLayeringMatcherTest, MatchAccelerator) { + LayerAnnotation rule = {"accelerator", "Anno1", false}; + + // Case 1: CPU EP should NOT match + { + auto test_ep = CreateEp(kCpuExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + EXPECT_FALSE(result.has_value()); + } + + // Case 2: Custom EP, No HW/Mem info, considered accelerator + { + auto test_ep = CreateEp("MyCustomAccel"); + OrtEpDevice ep_device = 
test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyCustomAccel"); + } + + // Case 3: GPU HW is an accelerator + { + auto test_ep = CreateHwEp("MyGPU", OrtHardwareDeviceType_GPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyGPU"); + } +} + +TEST(EpLayeringMatcherTest, MatchNPU) { + LayerAnnotation rule = {"npu", "Anno1", false}; + + // Case 1: Hardware NPU + { + auto test_ep = CreateHwEp("MyNPU", OrtHardwareDeviceType_NPU); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyNPU"); + } + + // Case 2: QNN Heuristic + { + auto test_ep = CreateEp(kQnnExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kQnnExecutionProvider); + } +} + +TEST(EpLayeringMatcherTest, MatchFPGA) { + LayerAnnotation rule = {"fpga", "Anno1", false}; + + // Case 1: MemInfo says FPGA + { + auto test_ep = CreateMemEp("MyFPGA", OrtDevice::FPGA); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyFPGA"); + } +} + +TEST(EpLayeringMatcherTest, MatchDirectDesignators) { + LayerAnnotation rule_cuda = {"cuda", "A", false}; + LayerAnnotation rule_dml = {"dml", "B", false}; + + { + auto test_ep = CreateEp(kCudaExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule_cuda); + ASSERT_TRUE(result.has_value()); 
+ EXPECT_EQ(*result, kCudaExecutionProvider); + } + { + auto test_ep = CreateEp(kDmlExecutionProvider); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule_dml); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kDmlExecutionProvider); + } +} + +TEST(EpLayeringMatcherTest, MatchExactEPName) { + LayerAnnotation rule = {"MyCustomEP", "Anno1", false}; + + { + auto test_ep = CreateEp("MyCustomEP"); + OrtEpDevice ep_device = test_ep.Get(); + std::vector devices = {&ep_device}; + auto result = EpLayeringMatcher::Match(devices, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyCustomEP"); + } +} + +namespace { + +// Minimal concrete implementation of IExecutionProvider for testing +class MockExecutionProvider : public IExecutionProvider { + public: + MockExecutionProvider(const std::string& type, OrtDevice device) + : IExecutionProvider(type, device) {} + + std::shared_ptr GetKernelRegistry() const override { return nullptr; } +}; + +} // namespace + +TEST(EpLayeringMatcherTest, MatchExecutionProviders_CPU) { + LayerAnnotation rule = {"CPU", "Anno1", false}; + ExecutionProviders providers; + + // Add CPU provider + auto cpu_ep = std::make_shared(kCpuExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCpuExecutionProvider, cpu_ep)); + + // Add a GPU provider (should be skipped for CPU rule) + auto gpu_ep = std::make_shared(kCudaExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCudaExecutionProvider, gpu_ep)); + + auto result = EpLayeringMatcher::Match(providers, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCpuExecutionProvider); +} + +TEST(EpLayeringMatcherTest, MatchExecutionProviders_GPU) { + LayerAnnotation rule = {"GPU", "Anno1", false}; + ExecutionProviders providers; + + // Add CPU provider (should be skipped) + auto 
cpu_ep = std::make_shared(kCpuExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCpuExecutionProvider, cpu_ep)); + + // Add CUDA provider (GPU) + auto gpu_ep = std::make_shared(kCudaExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCudaExecutionProvider, gpu_ep)); + + auto result = EpLayeringMatcher::Match(providers, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCudaExecutionProvider); +} + +TEST(EpLayeringMatcherTest, MatchExecutionProviders_GPU_Specific) { + LayerAnnotation rule = {"gpu:nvidia", "Anno1", false}; // Assumes heuristics or vendor ID logic + ExecutionProviders providers; + + // Add CPU provider + auto cpu_ep = std::make_shared(kCpuExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCpuExecutionProvider, cpu_ep)); + + // Add CUDA provider (NVIDIA vendor ID) + auto gpu_ep = std::make_shared(kCudaExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0)); + ASSERT_STATUS_OK(providers.Add(kCudaExecutionProvider, gpu_ep)); + + auto result = EpLayeringMatcher::Match(providers, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, kCudaExecutionProvider); +} + +TEST(EpLayeringMatcherTest, MatchExecutionProviders_NoMatch) { + LayerAnnotation rule = {"GPU", "Anno1", false}; + ExecutionProviders providers; + + // Only CPU provider available + auto cpu_ep = std::make_shared(kCpuExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCpuExecutionProvider, cpu_ep)); + + auto result = EpLayeringMatcher::Match(providers, rule); + EXPECT_FALSE(result.has_value()); +} + +TEST(EpLayeringMatcherTest, MatchExecutionProviders_Accelerator) { + LayerAnnotation rule = {"accelerator", "Anno1", false}; + ExecutionProviders providers; + + // Add CPU 
+ auto cpu_ep = std::make_shared(kCpuExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add(kCpuExecutionProvider, cpu_ep)); + + // Add custom accelerator + auto accel_ep = std::make_shared("MyAccel", OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, 0, 0)); + ASSERT_STATUS_OK(providers.Add("MyAccel", accel_ep)); + + auto result = EpLayeringMatcher::Match(providers, rule); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, "MyAccel"); +} + +TEST(LayeringIndexTest, AssignNodesBasedOnAnnotations) { + // 1. Setup Graph + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + // Create nodes + // Node 0: "AnnotatedNode" -> Annotated with "RuleA" + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* output_arg0 = &graph.GetOrCreateNodeArg("output0", &type_proto); + Node& node0 = graph.AddNode("node0", "Abs", "Node 0", {input_arg}, {output_arg0}); + node0.SetLayeringAnnotation("RuleA"); + + // Node 1: "UnannotatedNode" -> No annotation + NodeArg* output_arg1 = &graph.GetOrCreateNodeArg("output1", &type_proto); + Node& node1 = graph.AddNode("node1", "Abs", "Node 1", {output_arg0}, {output_arg1}); + // No annotation + + // Node 2: "AnnotatedNode2" -> Annotated with "RuleB" + NodeArg* output_arg2 = &graph.GetOrCreateNodeArg("output2", &type_proto); + Node& node2 = graph.AddNode("node2", "Abs", "Node 2", {output_arg1}, {output_arg2}); + node2.SetLayeringAnnotation("RuleB"); + + ASSERT_STATUS_OK(graph.Resolve()); + + // 2. 
Setup Rules and Matcher + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); // Index 0 + rules.rules.push_back({"DeviceB", "RuleB", false}); // Index 1 + LayeringRuleMatcher matcher(rules); + + // 3. Setup Pre-computed Mappings (simulating Partitioning Manager) + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + ep_map["DeviceB"].insert(1); + + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + rule_map[1] = "DeviceB"; + + // 4. Create LayeringIndex + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + + // 5. Verify Assignments + // Node 0: Annotated "RuleA" -> Index 0 -> DeviceA + auto assign0 = index.GetNodeAssignment(graph, node0.Index()); + ASSERT_TRUE(assign0.has_value()); + EXPECT_EQ(*assign0, 0u); + + // Node 1: Unannotated -> Should generally map to nothing (unless defaulting logic exists, + // but current impl leaves unannotated in main graph as unassigned) + auto assign1 = index.GetNodeAssignment(graph, node1.Index()); + EXPECT_FALSE(assign1.has_value()); + + // Node 2: Annotated "RuleB" -> Index 1 -> DeviceB + auto assign2 = index.GetNodeAssignment(graph, node2.Index()); + ASSERT_TRUE(assign2.has_value()); + EXPECT_EQ(*assign2, 1u); +} + +TEST(LayeringIndexTest, AssignNodeWithInvalidEpMapping) { + // Scenario: Node annotated with a rule that maps to an EP that is NOT present/valid + + // 1. 
Setup Graph with one node annotated "RuleX" + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& node = graph.AddNode("node", "Abs", "Node", {input_arg}, {output_arg}); + node.SetLayeringAnnotation("RuleX"); + + ASSERT_STATUS_OK(graph.Resolve()); + + // 2. Setup Rules: RuleX exists at index 0, maps to "PhantomDevice" + LayeringRules rules; + rules.rules.push_back({"PhantomDevice", "RuleX", false}); // Index 0 + + // 3. Setup Mappings: But "PhantomDevice" is NOT in the mappings (simulating EP unavailable) + LayeringIndex::EpNameToLayeringIndices ep_map; + // ep_map["PhantomDevice"] is empty/missing + + LayeringIndex::LayeringIndexToEpName rule_map; + // rule_map[0] is missing + + // 4. Create Index + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + // 5. Verify: Node should NOT be assigned because the mapped EP is missing + auto assign = index.GetNodeAssignment(graph, node.Index()); + EXPECT_FALSE(assign.has_value()); +} + +TEST(LayeringIndexTest, SubgraphInheritance) { + // Scenario: Annotated Node containing a subgraph. + // Nodes inside subgraph (unannotated) should inherit parent's assignment. + + // 1. 
Setup Parent Graph + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + NodeArg* cond_arg = &graph.GetOrCreateNodeArg("cond", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + // Create "If" node + Node& if_node = graph.AddNode("if_node", "If", "If Node", {cond_arg}, {output_arg}); + if_node.SetLayeringAnnotation("RuleA"); // Annotate Parent + + auto build_subgraph = [](ONNX_NAMESPACE::GraphProto& proto, const std::string& graph_name, + const std::string& node_name, const std::string& input_name, const std::string& output_name) { + proto.set_name(graph_name); + // Inputs: Implicit from outer scope for 'cond' + + auto* node = proto.add_node(); + node->set_name(node_name); + node->set_op_type("Identity"); + node->add_input(input_name); + node->add_output(output_name); + + auto* out_vi = proto.add_output(); + out_vi->set_name(output_name); + out_vi->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + }; + + // Create Subgraph (then_branch) + ONNX_NAMESPACE::GraphProto then_graph_proto; + build_subgraph(then_graph_proto, "then_graph", "sub_node", "cond", "sub_out"); + if_node.AddAttribute("then_branch", then_graph_proto); + + // Create 'else_branch' + ONNX_NAMESPACE::GraphProto else_graph_proto; + build_subgraph(else_graph_proto, "else_graph", "else_sub_node", "cond", "else_sub_out"); + if_node.AddAttribute("else_branch", else_graph_proto); + + // First Resolve to create subgraph instances + ASSERT_STATUS_OK(graph.Resolve()); + + // Get subgraph instances (checked to ensure they exist) + Graph* then_graph = 
if_node.GetMutableGraphAttribute("then_branch"); + ASSERT_NE(then_graph, nullptr); + Graph* else_graph = if_node.GetMutableGraphAttribute("else_branch"); + ASSERT_NE(else_graph, nullptr); + + // 2. Setup Rules + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); // Index 0 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + + // 3. Create Index + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + + // 4. Verify Parent Assignment + auto assign_parent = index.GetNodeAssignment(graph, if_node.Index()); + ASSERT_TRUE(assign_parent.has_value()); + EXPECT_EQ(*assign_parent, 0u); + + // 5. Verify Subgraph Node Assignment (Inheritance) + bool validated_then = false; + for (const auto& node : then_graph->Nodes()) { + if (node.OpType() == "Identity") { + auto assign_sub = index.GetNodeAssignment(*then_graph, node.Index()); + ASSERT_TRUE(assign_sub.has_value()) << "Subgraph node should inherit parent annotation"; + EXPECT_EQ(*assign_sub, 0u); + validated_then = true; + } + } + ASSERT_TRUE(validated_then); +} + +TEST(LayeringIndexTest, SubgraphOverride) { + // Scenario: Annotated Node containing a subgraph. + // Node inside subgraph HAS annotation -> Should override inheritance. 
+ + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + NodeArg* cond_arg = &graph.GetOrCreateNodeArg("cond", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& if_node = graph.AddNode("if_node", "If", "If Node", {cond_arg}, {output_arg}); + if_node.SetLayeringAnnotation("RuleA"); // Annotate Parent = Rule A (Index 0) + + auto build_subgraph = [](ONNX_NAMESPACE::GraphProto& proto, const std::string& graph_name, + const std::string& node_name, const std::string& input_name, const std::string& output_name) { + proto.set_name(graph_name); + + auto* node = proto.add_node(); + node->set_name(node_name); + node->set_op_type("Identity"); + node->add_input(input_name); + node->add_output(output_name); + + auto* out_vi = proto.add_output(); + out_vi->set_name(output_name); + out_vi->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + }; + + ONNX_NAMESPACE::GraphProto then_graph_proto; + build_subgraph(then_graph_proto, "then_graph", "sub_node", "cond", "sub_out"); + if_node.AddAttribute("then_branch", then_graph_proto); + + ONNX_NAMESPACE::GraphProto else_graph_proto; + build_subgraph(else_graph_proto, "else_graph", "else_sub_node", "cond", "else_sub_out"); + if_node.AddAttribute("else_branch", else_graph_proto); + + ASSERT_STATUS_OK(graph.Resolve()); + + Graph* then_graph = if_node.GetMutableGraphAttribute("then_branch"); + ASSERT_NE(then_graph, nullptr); + + // Find sub_node to set annotation + Node* sub_node = nullptr; + for (auto& node : then_graph->Nodes()) { + if (node.Name() == "sub_node") { + 
sub_node = &node; + break; + } + } + ASSERT_NE(sub_node, nullptr); + + // OVERRIDE: Annotate sub_node with Rule B + sub_node->SetLayeringAnnotation("RuleB"); + + // Rules: RuleA(0)->DeviceA, RuleB(1)->DeviceB + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); + rules.rules.push_back({"DeviceB", "RuleB", false}); + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + ep_map["DeviceB"].insert(1); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + rule_map[1] = "DeviceB"; + + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + + // Verify Parent = 0 + auto assign_parent = index.GetNodeAssignment(graph, if_node.Index()); + ASSERT_TRUE(assign_parent.has_value()); + EXPECT_EQ(*assign_parent, 0u); + + // Verify Sub = 1 (Override) + auto assign_sub = index.GetNodeAssignment(*then_graph, sub_node->Index()); + ASSERT_TRUE(assign_sub.has_value()); + EXPECT_EQ(*assign_sub, 1u); +} + +TEST(LayeringIndexTest, UpdateIndex) { + // 1. Setup Graph with one node + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& node = graph.AddNode("node", "Abs", "Node", {input_arg}, {output_arg}); + ASSERT_STATUS_OK(graph.Resolve()); + + // 2. 
Setup Rules and Index + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); // Index 0 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + + // Creates index (node has no annotation, so not assigned) + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + EXPECT_FALSE(index.GetNodeAssignment(graph, node.Index()).has_value()); + + // 3. Update Node with Annotation + node.SetLayeringAnnotation("RuleA"); + + // 4. Call Update + std::vector nodes_to_update = {node.Index()}; + index.Update(graph, nodes_to_update); + + // 5. Verify Assignment + auto assignment = index.GetNodeAssignment(graph, node.Index()); + ASSERT_TRUE(assignment.has_value()); + EXPECT_EQ(*assignment, 0u); +} + +TEST(LayeringRulesTest, LayeringRulesParsing) { + // Test empty string + { + LayeringRules rules; + ASSERT_STATUS_OK(LayeringRules::FromConfigString("", rules)); + EXPECT_TRUE(rules.rules.empty()); + } + + // Test simple valid string + { + LayeringRules rules; + ASSERT_STATUS_OK(LayeringRules::FromConfigString("EP1(Annotation1)", rules)); + ASSERT_EQ(rules.rules.size(), 1u); + EXPECT_EQ(rules.rules[0].device, "EP1"); + EXPECT_EQ(rules.rules[0].annotation, "Annotation1"); + EXPECT_TRUE(rules.rules[0].prefix_match); + } + + // Test multiple annotations for one device + { + LayeringRules rules; + ASSERT_STATUS_OK(LayeringRules::FromConfigString("EP1(Annotation1, Annotation2)", rules)); + ASSERT_EQ(rules.rules.size(), 2u); + EXPECT_EQ(rules.rules[0].device, "EP1"); + EXPECT_EQ(rules.rules[0].annotation, "Annotation1"); + EXPECT_TRUE(rules.rules[0].prefix_match); + EXPECT_EQ(rules.rules[1].device, "EP1"); + EXPECT_EQ(rules.rules[1].annotation, "Annotation2"); + EXPECT_TRUE(rules.rules[1].prefix_match); + } + + // Test multiple devices + { + LayeringRules rules; + 
ASSERT_STATUS_OK(LayeringRules::FromConfigString("EP1(Annotation1); EP2(Annotation2)", rules)); + ASSERT_EQ(rules.rules.size(), 2u); + EXPECT_EQ(rules.rules[0].device, "EP1"); + EXPECT_EQ(rules.rules[0].annotation, "Annotation1"); + EXPECT_TRUE(rules.rules[0].prefix_match); + EXPECT_EQ(rules.rules[1].device, "EP2"); + EXPECT_EQ(rules.rules[1].annotation, "Annotation2"); + EXPECT_TRUE(rules.rules[1].prefix_match); + } + + // Test exact match + { + LayeringRules rules; + ASSERT_STATUS_OK(LayeringRules::FromConfigString("EP1(=Annotation1)", rules)); + ASSERT_EQ(rules.rules.size(), 1u); + EXPECT_EQ(rules.rules[0].device, "EP1"); + EXPECT_EQ(rules.rules[0].annotation, "Annotation1"); + EXPECT_FALSE(rules.rules[0].prefix_match); + } + + // Test trimming whitespace + { + LayeringRules rules; + ASSERT_STATUS_OK(LayeringRules::FromConfigString(" EP1 ( Annotation1 , =Annotation2 ) ; EP2 ( Annotation3 ) ", rules)); + ASSERT_EQ(rules.rules.size(), 3u); + EXPECT_EQ(rules.rules[0].device, "EP1"); + EXPECT_EQ(rules.rules[0].annotation, "Annotation1"); + EXPECT_TRUE(rules.rules[0].prefix_match); + EXPECT_EQ(rules.rules[1].device, "EP1"); + EXPECT_EQ(rules.rules[1].annotation, "Annotation2"); + EXPECT_FALSE(rules.rules[1].prefix_match); + EXPECT_EQ(rules.rules[2].device, "EP2"); + EXPECT_EQ(rules.rules[2].annotation, "Annotation3"); + EXPECT_TRUE(rules.rules[2].prefix_match); + } +} + +TEST(LayeringRulesTest, FromConfigString_InvalidFormat) { + LayeringRules rules; + + // Error: Missing parentheses structure entirely + EXPECT_FALSE(LayeringRules::FromConfigString("Device1Annotation1", rules).IsOK()); + + // Error: Missing closing parenthesis + EXPECT_FALSE(LayeringRules::FromConfigString("Device1(Annotation1", rules).IsOK()); + + // Error: Missing opening parenthesis (or only closing present) + EXPECT_FALSE(LayeringRules::FromConfigString("Device1Annotation1)", rules).IsOK()); + + // Error: Parentheses reversed + EXPECT_FALSE(LayeringRules::FromConfigString("Device1)Annotation1(", 
rules).IsOK()); + + // Error: Empty device name (starts with parenthesis) + EXPECT_FALSE(LayeringRules::FromConfigString("(Annotation1)", rules).IsOK()); +} + +TEST(LayeringRulesTest, FromConfigString_IgnoresEmptyEntries) { + LayeringRules rules; + // "; ;" should result in 0 rules but Status::OK + ASSERT_STATUS_OK(LayeringRules::FromConfigString("; ;", rules)); + EXPECT_TRUE(rules.rules.empty()); +} + +TEST(LayeringRulesTest, FromConfigString_RejectsDuplicateAnnotations) { + LayeringRules rules; + + // Duplicate prefix annotation within the same device + EXPECT_FALSE(LayeringRules::FromConfigString("EP1(Ann1, Ann1)", rules).IsOK()); + + // Duplicate prefix annotation across different devices + EXPECT_FALSE(LayeringRules::FromConfigString("EP1(Ann1); EP2(Ann1)", rules).IsOK()); + + // Duplicate exact annotation within the same device + EXPECT_FALSE(LayeringRules::FromConfigString("EP1(=Ann1, =Ann1)", rules).IsOK()); + + // Duplicate exact annotation across different devices + EXPECT_FALSE(LayeringRules::FromConfigString("EP1(=Ann1); EP2(=Ann1)", rules).IsOK()); + + // Same annotation but different match types (prefix vs exact) should be OK + ASSERT_STATUS_OK(LayeringRules::FromConfigString("EP1(Ann1, =Ann1)", rules)); + ASSERT_EQ(rules.rules.size(), 2u); + EXPECT_TRUE(rules.rules[0].prefix_match); + EXPECT_FALSE(rules.rules[1].prefix_match); +} + +TEST(LayeringIndexTest, MakeNodeUnassigned_PreservesEpRuleMapping) { + // Scenario: All nodes for a rule are unassigned in one graph. + // ep_name_to_layering_indices_ must still contain the rule so that + // sibling subgraphs (or the same graph on a subsequent pass) can still + // use it for filtering. + + // 1. 
Setup Graph with two nodes, both annotated with the same rule + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + // Create nodes + // Node 0: "AnnotatedNode" -> Annotated with "RuleA" + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* mid_arg = &graph.GetOrCreateNodeArg("mid", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& node0 = graph.AddNode("node0", "Abs", "Node 0", {input_arg}, {mid_arg}); + node0.SetLayeringAnnotation("RuleA"); + Node& node1 = graph.AddNode("node1", "Abs", "Node 1", {mid_arg}, {output_arg}); + node1.SetLayeringAnnotation("RuleA"); + + ASSERT_STATUS_OK(graph.Resolve()); + + // 2. Setup Rules: RuleA -> DeviceA + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); // Index 0 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + + // 3. Create Index + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + + // Both nodes should be assigned + ASSERT_TRUE(index.GetNodeAssignment(graph, node0.Index()).has_value()); + ASSERT_TRUE(index.GetNodeAssignment(graph, node1.Index()).has_value()); + + // 3. 
Unassign both nodes (simulating EP failing to claim them) + index.MakeNodeUnassigned(graph, node0.Index()); + index.MakeNodeUnassigned(graph, node1.Index()); + + // Nodes should be unassigned + EXPECT_FALSE(index.GetNodeAssignment(graph, node0.Index()).has_value()); + EXPECT_FALSE(index.GetNodeAssignment(graph, node1.Index()).has_value()); + + // 4. CRITICAL: ep_name_to_layering_indices_ must still map DeviceA -> {0} + // so that other graphs/passes can still use this rule for filtering. + auto rules_opt = index.GetLayeringRulesForThisEp("DeviceA"); + ASSERT_TRUE(rules_opt.has_value()) << "EP-to-rule mapping should not be erased when nodes are unassigned"; + EXPECT_EQ(rules_opt->get().count(0), 1u); +} + +TEST(LayeringIndexTest, UpdateAfterFullUnassignment_RestoresVisibility) { + // Scenario: All nodes for a rule are unassigned, then Update() adds + // a new node matching the same rule. The new node must be visible + // to the EP via GetLayeringRulesForThisEp. + + // 1. Setup Graph with one annotated node + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& node0 = graph.AddNode("node0", "Abs", "Node 0", {input_arg}, {output_arg}); + node0.SetLayeringAnnotation("RuleA"); + ASSERT_STATUS_OK(graph.Resolve()); + + // 2. 
Setup Rules: RuleA -> DeviceA + LayeringRules rules; + rules.rules.push_back({"DeviceA", "RuleA", false}); // Index 0 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["DeviceA"].insert(0); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "DeviceA"; + + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + ASSERT_TRUE(index.GetNodeAssignment(graph, node0.Index()).has_value()); + + // 3. Unassign the only node + index.MakeNodeUnassigned(graph, node0.Index()); + EXPECT_FALSE(index.GetNodeAssignment(graph, node0.Index()).has_value()); + + // 4. Simulate layout transform adding a new node with inherited annotation + NodeArg* new_output_arg = &graph.GetOrCreateNodeArg("new_output", &type_proto); + Node& new_node = graph.AddNode("new_node", "Abs", "Node with inherited assignment", + {output_arg}, {new_output_arg}); + new_node.SetLayeringAnnotation("RuleA"); // Inherits parent's annotation + ASSERT_STATUS_OK(graph.Resolve()); + + // Record the new node index + NodeIndex new_node_index = new_node.Index(); + + // 5. Update index with the new node + std::vector new_nodes = {new_node_index}; + index.Update(graph, new_nodes); + + // 6. New node should be assigned to rule 0 + auto assign = index.GetNodeAssignment(graph, new_node.Index()); + ASSERT_TRUE(assign.has_value()); + EXPECT_EQ(*assign, 0u); + + // 7. CRITICAL: The rule must still be visible for DeviceA + auto rules_opt = index.GetLayeringRulesForThisEp("DeviceA"); + ASSERT_TRUE(rules_opt.has_value()) << "EP-to-rule mapping must be intact for Update to be effective"; + EXPECT_EQ(rules_opt->get().count(0), 1u); +} + +// ============================================================================ +// Tests for graph_partitioner.cc LayeringIndex integration +// These tests exercise behaviors from GetCapabilityForEP, InlineNodes, and +// the partitioning pipeline when a LayeringIndex is present. 
+// ============================================================================ + +// Helper to create a simple linear graph: input -> node0 -> node1 -> ... -> output +namespace { + +struct SimpleGraphHelper { + std::unique_ptr model; + Graph* graph = nullptr; + std::vector node_indices; + + static SimpleGraphHelper Create(int num_nodes, const std::string& op_type = "Abs") { + SimpleGraphHelper h; + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + h.model = std::make_unique("test_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + h.graph = &h.model->MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + NodeArg* prev_arg = &h.graph->GetOrCreateNodeArg("input", &type_proto); + + for (int i = 0; i < num_nodes; ++i) { + std::string out_name = (i == num_nodes - 1) ? 
"output" : "mid_" + std::to_string(i); + NodeArg* out_arg = &h.graph->GetOrCreateNodeArg(out_name, &type_proto); + Node& node = h.graph->AddNode("node_" + std::to_string(i), op_type, + "Node " + std::to_string(i), {prev_arg}, {out_arg}); + h.node_indices.push_back(node.Index()); + prev_arg = out_arg; + } + return h; + } +}; + +LayeringIndex CreateTwoEpIndex(const Graph& graph, + const std::string& ep_a, const std::string& annotation_a, + const std::string& ep_b, const std::string& annotation_b) { + LayeringRules rules; + rules.rules.push_back({ep_a, annotation_a, false}); // Index 0 + rules.rules.push_back({ep_b, annotation_b, false}); // Index 1 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map[ep_a].insert(0); + ep_map[ep_b].insert(1); + + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = ep_a; + rule_map[1] = ep_b; + + return LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); +} + +} // namespace + +TEST(LayeringIndexPartitionerTest, FilteredGraphViewerExcludesOtherEpNodes) { + // Validates the filtering logic in create_graph_viewer (GetCapabilityForEP): + // When layering_index is present, nodes assigned to other EPs should be excluded + // from the GraphViewer presented to the current EP. 
+ + // Setup: 3-node chain, node0 -> RuleA (DeviceA), node1 -> unannotated, node2 -> RuleB (DeviceB) + auto h = SimpleGraphHelper::Create(3); + auto* node0 = h.graph->GetNode(h.node_indices[0]); + auto* node2 = h.graph->GetNode(h.node_indices[2]); + node0->SetLayeringAnnotation("RuleA"); + node2->SetLayeringAnnotation("RuleB"); + ASSERT_STATUS_OK(h.graph->Resolve()); + + auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB"); + + // Verify: From DeviceA's perspective, node2 should be excluded + auto rules_a = index.GetLayeringRulesForThisEp("DeviceA"); + ASSERT_TRUE(rules_a.has_value()); + + // node0 should be assigned to rule 0 (DeviceA) + auto assign0 = index.GetNodeAssignment(*h.graph, h.node_indices[0]); + ASSERT_TRUE(assign0.has_value()); + EXPECT_EQ(*assign0, 0u); + + // node1 should be unassigned (available to any EP) + auto assign1 = index.GetNodeAssignment(*h.graph, h.node_indices[1]); + EXPECT_FALSE(assign1.has_value()); + + // node2 should be assigned to rule 1 (DeviceB) + auto assign2 = index.GetNodeAssignment(*h.graph, h.node_indices[2]); + ASSERT_TRUE(assign2.has_value()); + EXPECT_EQ(*assign2, 1u); + + // Simulate the filtering logic from create_graph_viewer: + // For DeviceA: include nodes with no assignment OR assignment in DeviceA's rules + InlinedVector filtered_for_device_a; + for (auto& node : h.graph->Nodes()) { + auto rule_idx_opt = index.GetNodeAssignment(*h.graph, node.Index()); + bool include = true; + if (rule_idx_opt) { + // Node has assignment - include only if it belongs to DeviceA + if (rules_a->get().count(*rule_idx_opt) == 0) { + include = false; + } + } + if (include) { + filtered_for_device_a.push_back(&node); + } + } + + // DeviceA should see node0 (assigned to it) and node1 (unassigned), but NOT node2 + EXPECT_EQ(filtered_for_device_a.size(), 2u); + bool found_node0 = false, found_node1 = false, found_node2 = false; + for (const auto* n : filtered_for_device_a) { + if (n->Index() == h.node_indices[0]) 
found_node0 = true;
+    if (n->Index() == h.node_indices[1]) found_node1 = true;
+    if (n->Index() == h.node_indices[2]) found_node2 = true;
+  }
+  EXPECT_TRUE(found_node0) << "DeviceA's assigned node should be included";
+  EXPECT_TRUE(found_node1) << "Unassigned node should be included for any EP";
+  EXPECT_FALSE(found_node2) << "DeviceB's assigned node should be excluded from DeviceA's view";
+}
+
+TEST(LayeringIndexPartitionerTest, FilteredGraphViewerForDeviceBExcludesDeviceANodes) {
+  // Mirror of the above test but from DeviceB's perspective.
+
+  auto h = SimpleGraphHelper::Create(3);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  auto* node2 = h.graph->GetNode(h.node_indices[2]);
+  node0->SetLayeringAnnotation("RuleA");
+  node2->SetLayeringAnnotation("RuleB");
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB");
+
+  auto rules_b = index.GetLayeringRulesForThisEp("DeviceB");
+  ASSERT_TRUE(rules_b.has_value());
+
+  // Simulate filtering for DeviceB
+  InlinedVector<const Node*> filtered_for_device_b;
+  for (auto& node : h.graph->Nodes()) {
+    auto rule_idx_opt = index.GetNodeAssignment(*h.graph, node.Index());
+    bool include = true;
+    if (rule_idx_opt) {
+      if (rules_b->get().count(*rule_idx_opt) == 0) {
+        include = false;
+      }
+    }
+    if (include) {
+      filtered_for_device_b.push_back(&node);
+    }
+  }
+
+  // DeviceB should see node1 (unassigned) and node2 (assigned to it), but NOT node0
+  EXPECT_EQ(filtered_for_device_b.size(), 2u);
+  bool found_node0 = false, found_node1 = false, found_node2 = false;
+  for (const auto* n : filtered_for_device_b) {
+    if (n->Index() == h.node_indices[0]) found_node0 = true;
+    if (n->Index() == h.node_indices[1]) found_node1 = true;
+    if (n->Index() == h.node_indices[2]) found_node2 = true;
+  }
+  EXPECT_FALSE(found_node0) << "DeviceA's assigned node should be excluded from DeviceB's view";
+  EXPECT_TRUE(found_node1) << "Unassigned node should be included for any EP";
+
EXPECT_TRUE(found_node2) << "DeviceB's assigned node should be included";
+}
+
+TEST(LayeringIndexPartitionerTest, ResetUnclaimedNodesRemovesAssignment) {
+  // Validates the reset_assignment_unclaimed_nodes logic:
+  // Nodes that were pre-assigned to an EP via layering but NOT claimed in capabilities
+  // should be unassigned so subsequent EPs can pick them up.
+
+  auto h = SimpleGraphHelper::Create(4);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  auto* node1 = h.graph->GetNode(h.node_indices[1]);
+  auto* node2 = h.graph->GetNode(h.node_indices[2]);
+
+  node0->SetLayeringAnnotation("RuleA");
+  node1->SetLayeringAnnotation("RuleA");
+  node2->SetLayeringAnnotation("RuleA");
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  LayeringRules rules;
+  rules.rules.push_back({"DeviceA", "RuleA", false});  // Index 0
+
+  LayeringIndex::EpNameToLayeringIndices ep_map;
+  ep_map["DeviceA"].insert(0);
+  LayeringIndex::LayeringIndexToEpName rule_map;
+  rule_map[0] = "DeviceA";
+
+  auto index = LayeringIndex::Create(*h.graph, std::move(ep_map), std::move(rule_map), std::move(rules));
+
+  // All 3 nodes should be assigned initially
+  ASSERT_TRUE(index.GetNodeAssignment(*h.graph, h.node_indices[0]).has_value());
+  ASSERT_TRUE(index.GetNodeAssignment(*h.graph, h.node_indices[1]).has_value());
+  ASSERT_TRUE(index.GetNodeAssignment(*h.graph, h.node_indices[2]).has_value());
+
+  // Simulate: EP only claims node0 and node2 (not node1)
+  InlinedHashSet<NodeIndex> claimed;
+  claimed.insert(h.node_indices[0]);
+  claimed.insert(h.node_indices[2]);
+
+  auto ep_rules_opt = index.GetLayeringRulesForThisEp("DeviceA");
+  ASSERT_TRUE(ep_rules_opt.has_value());
+  const auto& ep_rules = ep_rules_opt->get();
+
+  // Replicate reset_assignment_unclaimed_nodes logic:
+  // For each assigned-filtered-in node, if not claimed, unassign it
+  std::vector<NodeIndex> assigned_filtered_in = {h.node_indices[0], h.node_indices[1], h.node_indices[2]};
+  for (auto node_index : assigned_filtered_in) {
+    if (claimed.count(node_index) ==
0) {
+      auto rule_idx_opt = index.GetNodeAssignment(*h.graph, node_index);
+      if (rule_idx_opt && ep_rules.count(*rule_idx_opt) > 0) {
+        index.MakeNodeUnassigned(*h.graph, node_index);
+      }
+    }
+  }
+
+  // node0 and node2 should still be assigned
+  EXPECT_TRUE(index.GetNodeAssignment(*h.graph, h.node_indices[0]).has_value());
+  EXPECT_TRUE(index.GetNodeAssignment(*h.graph, h.node_indices[2]).has_value());
+  // node1 should be unassigned (not claimed by EP)
+  EXPECT_FALSE(index.GetNodeAssignment(*h.graph, h.node_indices[1]).has_value());
+}
+
+TEST(LayeringIndexPartitionerTest, UpdateAfterLayoutTransformAddsNewNodes) {
+  // Validates the LayeringIndex update after layout transformation creates new nodes.
+  // In GetCapabilityForEP, after layout transform, new nodes with inherited annotations
+  // are added and the index is updated.
+
+  auto h = SimpleGraphHelper::Create(1);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  node0->SetLayeringAnnotation("RuleA");
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB");
+
+  // Record the max node index before "layout transformation"
+  const NodeIndex first_new_node = h.graph->MaxNodeIndex();
+
+  // Simulate layout transformation adding new nodes with inherited annotation
+  ONNX_NAMESPACE::TypeProto type_proto;
+  type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  NodeArg* extra_out = &h.graph->GetOrCreateNodeArg("extra_output", &type_proto);
+  NodeArg* output_arg = &h.graph->GetOrCreateNodeArg("output", nullptr);  // reuse existing
+  Node& new_node = h.graph->AddNode("new_node", "Abs", "Node with inherited annotation",
+                                    {output_arg}, {extra_out});
+  new_node.SetLayeringAnnotation("RuleA");  // Inherits parent's annotation
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  const NodeIndex end_node = h.graph->MaxNodeIndex();
+
+  // Collect new node indices (as done in graph_partitioner.cc)
+  InlinedVector<NodeIndex>
new_node_indices;
+  for (NodeIndex idx = first_new_node; idx < end_node; ++idx) {
+    if (h.graph->GetNode(idx) != nullptr) {
+      new_node_indices.push_back(idx);
+    }
+  }
+
+  // Update index
+  ASSERT_FALSE(new_node_indices.empty());
+  index.Update(*h.graph, new_node_indices);
+
+  // New node should be assigned to rule 0 (DeviceA)
+  auto assign = index.GetNodeAssignment(*h.graph, new_node.Index());
+  ASSERT_TRUE(assign.has_value());
+  EXPECT_EQ(*assign, 0u);
+
+  // And the annotation string should be on the node
+  EXPECT_EQ(new_node.GetLayeringAnnotation(), "RuleA");
+}
+
+TEST(LayeringIndexPartitionerTest, UpdateWithUnannotatedNewNodeRemainsUnassigned) {
+  // New nodes created by layout transform that do NOT have annotations
+  // should remain unassigned after Update.
+
+  auto h = SimpleGraphHelper::Create(1);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  node0->SetLayeringAnnotation("RuleA");
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB");
+
+  // Add a new node WITHOUT annotation
+  ONNX_NAMESPACE::TypeProto type_proto;
+  type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  NodeArg* extra_out = &h.graph->GetOrCreateNodeArg("extra_output", &type_proto);
+  NodeArg* output_arg = &h.graph->GetOrCreateNodeArg("output", nullptr);
+  Node& new_node = h.graph->AddNode("unannotated_node", "Abs", "No annotation",
+                                    {output_arg}, {extra_out});
+  // Deliberately NOT setting annotation
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  std::vector<NodeIndex> new_nodes = {new_node.Index()};
+  index.Update(*h.graph, new_nodes);
+
+  // New node should remain unassigned
+  auto assign = index.GetNodeAssignment(*h.graph, new_node.Index());
+  EXPECT_FALSE(assign.has_value());
+}
+
+TEST(LayeringIndexPartitionerTest, InlineAnnotationMaterialization) {
+  // Validates the InlineNodes logic where a node has an inherited-only assignment
+  // (no explicit annotation string) and the
annotation is materialized before inlining.
+  // This tests the code path:
+  //   if (layering_index != nullptr && !has_explicit_annotation) {
+  //     auto rule_idx = layering_index->GetNodeAssignment(graph, node->Index());
+  //     if (rule_idx) { ... node->SetLayeringAnnotation(rules.rules[*rule_idx].annotation); }
+  //   }
+
+  // Setup: A graph where a node is assigned via inheritance (subgraph scenario)
+  // but has no explicit annotation string on it.
+  std::unordered_map<std::string, int> domain_to_version;
+  domain_to_version[kOnnxDomain] = 12;
+  Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
+              DefaultLoggingManager().DefaultLogger());
+  Graph& graph = model.MainGraph();
+
+  ONNX_NAMESPACE::TypeProto type_proto;
+  type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto);
+  NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto);
+
+  // Create a node without explicit annotation
+  Node& node = graph.AddNode("inherited_node", "Abs", "Node with inherited assignment",
+                             {input_arg}, {output_arg});
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  // Create index where the node is somehow assigned (e.g., through inheritance)
+  LayeringRules rules;
+  rules.rules.push_back({"DeviceA", "RuleA", false});  // Index 0
+
+  LayeringIndex::EpNameToLayeringIndices ep_map;
+  ep_map["DeviceA"].insert(0);
+  LayeringIndex::LayeringIndexToEpName rule_map;
+  rule_map[0] = "DeviceA";
+
+  auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules));
+
+  // The node has no annotation, so it shouldn't be assigned yet
+  ASSERT_TRUE(node.GetLayeringAnnotation().empty());
+  EXPECT_FALSE(index.GetNodeAssignment(graph, node.Index()).has_value());
+
+  // Now simulate what InlineNodes does: manually annotate and update
+  // This simulates the case where GetNodeAssignment returns a
value
+  // for a node in a subgraph that inherited its parent's assignment.
+  node.SetLayeringAnnotation("RuleA");
+  std::vector<NodeIndex> updated = {node.Index()};
+  index.Update(graph, updated);
+
+  // After materialization + update, the node should be properly assigned
+  auto assign = index.GetNodeAssignment(graph, node.Index());
+  ASSERT_TRUE(assign.has_value());
+  EXPECT_EQ(*assign, 0u);
+
+  // And the annotation string should be on the node
+  EXPECT_EQ(node.GetLayeringAnnotation(), "RuleA");
+}
+
+TEST(LayeringIndexPartitionerTest, UpdateBatchMultipleNewAnnotatedNodes) {
+  // Tests that Update correctly handles a batch of multiple new nodes,
+  // some annotated with different rules. This mirrors the behavior after
+  // layout transformation creates several new nodes.
+
+  auto h = SimpleGraphHelper::Create(1);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  node0->SetLayeringAnnotation("RuleA");
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB");
+
+  ONNX_NAMESPACE::TypeProto type_proto;
+  type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+  // Add 3 new nodes: one for RuleA, one for RuleB, one unannotated
+  NodeArg* out1 = &h.graph->GetOrCreateNodeArg("new_out1", &type_proto);
+  NodeArg* out2 = &h.graph->GetOrCreateNodeArg("new_out2", &type_proto);
+  NodeArg* out3 = &h.graph->GetOrCreateNodeArg("new_out3", &type_proto);
+  NodeArg* output = &h.graph->GetOrCreateNodeArg("output", nullptr);
+
+  Node& new_a = h.graph->AddNode("new_a", "Abs", "", {output}, {out1});
+  new_a.SetLayeringAnnotation("RuleA");
+
+  Node& new_b = h.graph->AddNode("new_b", "Abs", "", {out1}, {out2});
+  new_b.SetLayeringAnnotation("RuleB");
+
+  Node& new_none = h.graph->AddNode("new_none", "Abs", "", {out2}, {out3});
+  // No annotation
+
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  std::vector<NodeIndex> new_nodes = {new_a.Index(), new_b.Index(), new_none.Index()};
+  index.Update(*h.graph,
new_nodes);
+
+  // new_a -> RuleA -> rule index 0
+  auto assign_a = index.GetNodeAssignment(*h.graph, new_a.Index());
+  ASSERT_TRUE(assign_a.has_value());
+  EXPECT_EQ(*assign_a, 0u);
+
+  // new_b -> RuleB -> rule index 1
+  auto assign_b = index.GetNodeAssignment(*h.graph, new_b.Index());
+  ASSERT_TRUE(assign_b.has_value());
+  EXPECT_EQ(*assign_b, 1u);
+
+  // new_none -> unassigned
+  auto assign_none = index.GetNodeAssignment(*h.graph, new_none.Index());
+  EXPECT_FALSE(assign_none.has_value());
+}
+
+TEST(LayeringIndexPartitionerTest, MakeUnassignedThenReassignViaPrefixRule) {
+  // Test that prefix rules work correctly after unassign+update cycle.
+  // This covers the interaction between MakeNodeUnassigned, prefix matching,
+  // and Update.
+
+  std::unordered_map<std::string, int> domain_to_version;
+  domain_to_version[kOnnxDomain] = 12;
+  Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
+              DefaultLoggingManager().DefaultLogger());
+  Graph& graph = model.MainGraph();
+
+  ONNX_NAMESPACE::TypeProto type_proto;
+  type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  NodeArg* input_arg = &graph.GetOrCreateNodeArg("input", &type_proto);
+  NodeArg* output_arg = &graph.GetOrCreateNodeArg("output", &type_proto);
+
+  Node& node = graph.AddNode("node", "Abs", "Node", {input_arg}, {output_arg});
+  node.SetLayeringAnnotation("Layer_GPU_Compute");
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  // Prefix rule: "Layer_GPU" matches "Layer_GPU_Compute"
+  LayeringRules rules;
+  rules.rules.push_back({"GPUDevice", "Layer_GPU", true});  // Index 0, prefix match
+
+  LayeringIndex::EpNameToLayeringIndices ep_map;
+  ep_map["GPUDevice"].insert(0);
+  LayeringIndex::LayeringIndexToEpName rule_map;
+  rule_map[0] = "GPUDevice";
+
+  auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules));
+
+  // Node should be assigned via prefix match
+  auto assign =
index.GetNodeAssignment(graph, node.Index());
+  ASSERT_TRUE(assign.has_value());
+  EXPECT_EQ(*assign, 0u);
+
+  // Unassign the node
+  index.MakeNodeUnassigned(graph, node.Index());
+  EXPECT_FALSE(index.GetNodeAssignment(graph, node.Index()).has_value());
+
+  // Add a new node with a different annotation that also matches the prefix
+  NodeArg* new_out = &graph.GetOrCreateNodeArg("new_output", &type_proto);
+  Node& new_node = graph.AddNode("new_node", "Abs", "Node with inherited annotation",
+                                 {output_arg}, {new_out});
+  new_node.SetLayeringAnnotation("Layer_GPU_Memory");
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  std::vector<NodeIndex> new_nodes = {new_node.Index()};
+  index.Update(graph, new_nodes);
+
+  // New node should also be assigned via prefix match
+  auto new_assign = index.GetNodeAssignment(graph, new_node.Index());
+  ASSERT_TRUE(new_assign.has_value());
+  EXPECT_EQ(*new_assign, 0u);
+}
+
+TEST(LayeringIndexPartitionerTest, NoLayeringIndexAllNodesVisible) {
+  // When layering_index is nullptr (no layering configuration),
+  // all nodes should be visible to all EPs. This verifies the baseline
+  // behavior that the filtering code path is only active when layering is enabled.
+ + auto h = SimpleGraphHelper::Create(3); + auto* node0 = h.graph->GetNode(h.node_indices[0]); + auto* node2 = h.graph->GetNode(h.node_indices[2]); + + // Even if nodes have annotations, without a LayeringIndex, everything is visible + node0->SetLayeringAnnotation("RuleA"); + node2->SetLayeringAnnotation("RuleB"); + ASSERT_STATUS_OK(h.graph->Resolve()); + + // Without LayeringIndex, a standard GraphViewer should see all nodes + GraphViewer viewer(*h.graph); + EXPECT_EQ(viewer.NumberOfNodes(), 3); + + // All nodes accessible + EXPECT_NE(viewer.GetNode(h.node_indices[0]), nullptr); + EXPECT_NE(viewer.GetNode(h.node_indices[1]), nullptr); + EXPECT_NE(viewer.GetNode(h.node_indices[2]), nullptr); +} + +TEST(LayeringIndexPartitionerTest, EpWithNoLayeringRulesSeesAllUnassignedNodes) { + // An EP that has no rules in the LayeringIndex (i.e., GetLayeringRulesForThisEp returns nullopt) + // should still see unassigned nodes, but nodes assigned to other EPs are excluded. + // This is the behavior for a CPU fallback EP not mentioned in layering config, + // as implemented in graph_partitioner.cc create_graph_viewer: + // if (!rules_opt || rules_opt->get().count(*rule_idx_opt) == 0) { include = false; } + + auto h = SimpleGraphHelper::Create(4); + auto* node0 = h.graph->GetNode(h.node_indices[0]); + auto* node2 = h.graph->GetNode(h.node_indices[2]); + node0->SetLayeringAnnotation("RuleA"); + node2->SetLayeringAnnotation("RuleB"); + // node1 and node3 are unannotated + ASSERT_STATUS_OK(h.graph->Resolve()); + + auto index = CreateTwoEpIndex(*h.graph, "DeviceA", "RuleA", "DeviceB", "RuleB"); + + // "CPUDevice" has no rules in the index + auto rules_cpu = index.GetLayeringRulesForThisEp("CPUDevice"); + EXPECT_FALSE(rules_cpu.has_value()); + + // Replicate create_graph_viewer filtering logic for an EP with no rules. + // When rules_opt is nullopt, any node with an assignment is excluded: + // if (!rules_opt || ...) { include = false; } + // Unassigned nodes remain included. 
+  InlinedVector<const Node*> filtered_for_cpu;
+  for (auto& node : h.graph->Nodes()) {
+    auto rule_idx_opt = index.GetNodeAssignment(*h.graph, node.Index());
+    bool include = true;
+    if (rule_idx_opt) {
+      if (!rules_cpu || rules_cpu->get().count(*rule_idx_opt) == 0) {
+        include = false;
+      }
+    }
+    if (include) {
+      filtered_for_cpu.push_back(&node);
+    }
+  }
+
+  // CPUDevice should see only the 2 unassigned nodes (node1, node3).
+  // node0 (RuleA/DeviceA) and node2 (RuleB/DeviceB) are excluded.
+  EXPECT_EQ(filtered_for_cpu.size(), 2u);
+
+  bool found[4] = {};
+  for (const auto* n : filtered_for_cpu) {
+    for (size_t i = 0; i < std::size(found); ++i) {
+      if (n->Index() == h.node_indices[i]) found[i] = true;
+    }
+  }
+  EXPECT_FALSE(found[0]) << "node0 assigned to DeviceA should be excluded";
+  EXPECT_TRUE(found[1]) << "node1 unassigned should be included";
+  EXPECT_FALSE(found[2]) << "node2 assigned to DeviceB should be excluded";
+  EXPECT_TRUE(found[3]) << "node3 unassigned should be included";
+}
+
+TEST(LayeringIndexPartitionerTest, MultipleRulesForSameEp) {
+  // An EP can have multiple rules assigned to it. All nodes matching any of its
+  // rules should be visible to it, while nodes matching other EP rules should not.
+
+  auto h = SimpleGraphHelper::Create(4);
+  auto* node0 = h.graph->GetNode(h.node_indices[0]);
+  auto* node1 = h.graph->GetNode(h.node_indices[1]);
+  auto* node2 = h.graph->GetNode(h.node_indices[2]);
+
+  node0->SetLayeringAnnotation("RuleA1");
+  node1->SetLayeringAnnotation("RuleA2");
+  node2->SetLayeringAnnotation("RuleB");
+  // node3 unannotated
+  ASSERT_STATUS_OK(h.graph->Resolve());
+
+  // DeviceA has two rules: RuleA1 (index 0) and RuleA2 (index 1)
+  // DeviceB has one rule: RuleB (index 2)
+  LayeringRules rules;
+  rules.rules.push_back({"DeviceA", "RuleA1", false});  // Index 0
+  rules.rules.push_back({"DeviceA", "RuleA2", false});  // Index 1
+  rules.rules.push_back({"DeviceB", "RuleB", false});   // Index 2
+
+  LayeringIndex::EpNameToLayeringIndices ep_map;
+  ep_map["DeviceA"].insert(0);
+  ep_map["DeviceA"].insert(1);
+  ep_map["DeviceB"].insert(2);
+
+  LayeringIndex::LayeringIndexToEpName rule_map;
+  rule_map[0] = "DeviceA";
+  rule_map[1] = "DeviceA";
+  rule_map[2] = "DeviceB";
+
+  auto index = LayeringIndex::Create(*h.graph, std::move(ep_map), std::move(rule_map), std::move(rules));
+
+  auto rules_a = index.GetLayeringRulesForThisEp("DeviceA");
+  ASSERT_TRUE(rules_a.has_value());
+  EXPECT_EQ(rules_a->get().size(), 2u);  // Both rule indices 0 and 1
+
+  // Simulate filtering for DeviceA
+  InlinedVector<const Node*> filtered_for_a;
+  for (auto& node : h.graph->Nodes()) {
+    auto rule_idx_opt = index.GetNodeAssignment(*h.graph, node.Index());
+    bool include = true;
+    if (rule_idx_opt) {
+      if (rules_a->get().count(*rule_idx_opt) == 0) {
+        include = false;
+      }
+    }
+    if (include) {
+      filtered_for_a.push_back(&node);
+    }
+  }
+
+  // DeviceA should see node0, node1 (both its rules), and node3 (unassigned) = 3 nodes
+  // node2 (RuleB/DeviceB) should be excluded
+  EXPECT_EQ(filtered_for_a.size(), 3u);
+
+  bool found[4] = {};
+  for (const auto* n : filtered_for_a) {
+    for (int i = 0; i < 4; ++i) {
+      if (n->Index() == h.node_indices[i]) found[i] = true;
+    }
+  }
+  EXPECT_TRUE(found[0]);
// node0 - RuleA1
+  EXPECT_TRUE(found[1]);   // node1 - RuleA2
+  EXPECT_FALSE(found[2]);  // node2 - RuleB (excluded)
+  EXPECT_TRUE(found[3]);   // node3 - unassigned
+}
+
+}  // namespace test
+}  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
\ No newline at end of file
diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc
new file mode 100644
index 0000000000000..a102fe4e7770b
--- /dev/null
+++ b/onnxruntime/test/framework/resource_accountant_test.cc
@@ -0,0 +1,327 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/resource_accountant.h"
+#include "core/graph/indexed_sub_graph.h"
+#include "core/graph/constants.h"
+#include "core/graph/model.h"
+
+#include "gtest/gtest.h"
+
+#include "test/util/include/asserts.h"
+#include "test/util/include/test_environment.h"
+
+namespace onnxruntime {
+namespace test {
+
+// Test accountant mimicking SizeBasedStatsAccountant ad-hoc path:
+// Uses pending/committed weight sets so that:
+// - Within a GetCapability pass, shared weights are deduped
+// - Across passes, only committed weights persist and pending are discarded
+class TestDedupAccountant : public IResourceAccountant {
+ public:
+  TestDedupAccountant() = default;
+
+  ResourceCount GetConsumedAmount() const override {
+    return consumed_;
+  }
+
+  void AddConsumedAmount(const ResourceCount& amount) noexcept override {
+    if (std::holds_alternative<size_t>(amount)) {
+      consumed_ += std::get<size_t>(amount);
+    }
+  }
+
+  void RemoveConsumedAmount(const ResourceCount& amount) noexcept override {
+    if (std::holds_alternative<size_t>(amount)) {
+      consumed_ -= std::get<size_t>(amount);
+    }
+  }
+
+  ResourceCount ComputeResourceCount(const Node& node) override {
+    const auto* graph = node.GetContainingGraph();
+    if (graph == nullptr) {
+      return static_cast<size_t>(0);
+    }
+
+    size_t total = 0;
+    for (const auto* input_def
: node.InputDefs()) {
+      if (!input_def->Exists()) {
+        continue;
+      }
+      const auto& name = input_def->Name();
+      constexpr bool check_outer_scope = true;
+      const auto* init = graph->GetInitializer(name, check_outer_scope);
+      if (init != nullptr) {
+        if (committed_weights_.count(name) > 0) {
+          continue;
+        }
+        if (pending_weights_.count(name) > 0) {
+          continue;
+        }
+        auto it = weight_sizes_.find(name);
+        if (it != weight_sizes_.end()) {
+          total += it->second;
+        }
+        pending_weights_.insert(name);
+        pending_weights_by_node_[node.Index()].insert(name);
+      }
+    }
+    return total;
+  }
+
+  void ResetPendingWeights() override {
+    pending_weights_.clear();
+    pending_weights_by_node_.clear();
+  }
+
+  void CommitWeightsForNode(NodeIndex node_index) override {
+    auto it = pending_weights_by_node_.find(node_index);
+    if (it != pending_weights_by_node_.end()) {
+      for (const auto& name : it->second) {
+        pending_weights_.erase(name);
+      }
+      committed_weights_.insert(it->second.begin(), it->second.end());
+      pending_weights_by_node_.erase(it);
+    }
+  }
+
+  void RegisterWeight(const std::string& name, size_t size) {
+    weight_sizes_[name] = size;
+  }
+
+  size_t GetConsumedSizeT() const { return consumed_; }
+
+ private:
+  size_t consumed_ = 0;
+  InlinedHashSet<std::string> committed_weights_;
+  InlinedHashSet<std::string> pending_weights_;
+  InlinedHashMap<NodeIndex, InlinedHashSet<std::string>> pending_weights_by_node_;
+  InlinedHashMap<std::string, size_t> weight_sizes_;
+};
+
+// Two Add nodes that share a single initializer weight_W.
+struct SharedWeightGraph {
+  std::unique_ptr<Model> model;
+  Graph* graph = nullptr;
+  Node* node_a = nullptr;
+  Node* node_b = nullptr;
+
+  static SharedWeightGraph Create() {
+    SharedWeightGraph h;
+    std::unordered_map<std::string, int> dom;
+    dom[kOnnxDomain] = 12;
+    h.model = std::make_unique<Model>(
+        "test_model", false, ModelMetaData(), PathString(),
+        IOnnxRuntimeOpSchemaRegistryList(), dom,
+        std::vector<ONNX_NAMESPACE::FunctionProto>(),
+        DefaultLoggingManager().DefaultLogger());
+    h.graph = &h.model->MainGraph();
+
+    ONNX_NAMESPACE::TypeProto ft;
+    ft.mutable_tensor_type()->set_elem_type(
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+    ft.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(250);
+
+    ONNX_NAMESPACE::TensorProto wp;
+    wp.set_name("weight_W");
+    wp.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+    wp.add_dims(250);
+    for (int i = 0; i < 250; ++i) {
+      wp.add_float_data(0.0f);
+    }
+    h.graph->AddInitializedTensor(wp);
+
+    auto* ia = &h.graph->GetOrCreateNodeArg("input_a", &ft);
+    auto* ib = &h.graph->GetOrCreateNodeArg("input_b", &ft);
+    auto* wa = &h.graph->GetOrCreateNodeArg("weight_W", &ft);
+    auto* oa = &h.graph->GetOrCreateNodeArg("out_a", &ft);
+    auto* ob = &h.graph->GetOrCreateNodeArg("out_b", &ft);
+
+    h.node_a = &h.graph->AddNode("node_A", "Add", "A", {ia, wa}, {oa});
+    h.node_b = &h.graph->AddNode("node_B", "Add", "B", {ib, wa}, {ob});
+
+    auto status = h.graph->Resolve();
+    ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
+    return h;
+  }
+};
+
+// Regression: AccountForAllNodes sums pre-stored per-node costs
+// that already have correct within-pass weight deduplication.
+TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) {
+  auto h = SharedWeightGraph::Create();
+  TestDedupAccountant accountant;
+  accountant.RegisterWeight("weight_W", 1000);
+
+  IndexedSubGraph sub_graph;
+  sub_graph.nodes.push_back(h.node_a->Index());
+  sub_graph.nodes.push_back(h.node_b->Index());
+  sub_graph.SetAccountant(&accountant);
+
+  auto cost_a = accountant.ComputeResourceCount(*h.node_a);
+  sub_graph.AppendNodeCost(cost_a);
+  EXPECT_EQ(std::get<size_t>(cost_a), size_t{1000});
+
+  auto cost_b = accountant.ComputeResourceCount(*h.node_b);
+  sub_graph.AppendNodeCost(cost_b);
+  EXPECT_EQ(std::get<size_t>(cost_b), size_t{0});
+
+  ASSERT_TRUE(sub_graph.IsAccountingEnabled());
+  sub_graph.AccountForAllNodes();
+
+  EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000})
+      << "AccountForAllNodes should sum pre-stored costs (1000 + 0)";
+}
+
+// Verifies that ResetPendingWeights + re-probe produces correct results.
+// After probing (which only writes to pending), resetting pending and
+// re-probing should see the full weight cost again since nothing was committed.
+TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) {
+  auto h = SharedWeightGraph::Create();
+  TestDedupAccountant accountant;
+  accountant.RegisterWeight("weight_W", 1000);
+
+  // Probing pass populates pending weights
+  auto cost_a = accountant.ComputeResourceCount(*h.node_a);
+  EXPECT_EQ(std::get<size_t>(cost_a), size_t{1000});
+  auto cost_b = accountant.ComputeResourceCount(*h.node_b);
+  EXPECT_EQ(std::get<size_t>(cost_b), size_t{0});
+
+  // Discard the pass (simulating capabilities.clear() before second GetCapability)
+  accountant.ResetPendingWeights();
+
+  // Re-probe: weight_W was never committed, so it should be counted again
+  IndexedSubGraph sub_graph;
+  sub_graph.nodes.push_back(h.node_a->Index());
+  sub_graph.SetAccountant(&accountant);
+  auto recomputed_cost = accountant.ComputeResourceCount(*h.node_a);
+  sub_graph.AccountForNode(h.node_a->Index(), recomputed_cost);
+
+  EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000})
+      << "After ResetPendingWeights, re-probe should see full weight cost";
+}
+
+// Each node has a unique initializer. AccountForAllNodes sums both.
+TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) {
+  std::unordered_map<std::string, int> dom;
+  dom[kOnnxDomain] = 12;
+  Model model("test_model", false, ModelMetaData(), PathString(),
+              IOnnxRuntimeOpSchemaRegistryList(), dom,
+              std::vector<ONNX_NAMESPACE::FunctionProto>(),
+              DefaultLoggingManager().DefaultLogger());
+  Graph& graph = model.MainGraph();
+
+  ONNX_NAMESPACE::TypeProto ft;
+  ft.mutable_tensor_type()->set_elem_type(
+      ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  ft.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(100);
+
+  const char* names[] = {"weight_1", "weight_2"};
+  for (const char* wn : names) {
+    ONNX_NAMESPACE::TensorProto tp;
+    tp.set_name(wn);
+    tp.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+    tp.add_dims(100);
+    for (int i = 0; i < 100; ++i) {
+      tp.add_float_data(0.0f);
+    }
+    graph.AddInitializedTensor(tp);
+  }
+
+  auto* input = &graph.GetOrCreateNodeArg("input", &ft);
+  auto* w1 = &graph.GetOrCreateNodeArg("weight_1", &ft);
+  auto* w2 = &graph.GetOrCreateNodeArg("weight_2", &ft);
+  auto* out1 = &graph.GetOrCreateNodeArg("out1", &ft);
+  auto* out2 = &graph.GetOrCreateNodeArg("out2", &ft);
+
+  auto& node1 = graph.AddNode("n1", "Add", "", {input, w1}, {out1});
+  auto& node2 = graph.AddNode("n2", "Add", "", {out1, w2}, {out2});
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  TestDedupAccountant accountant;
+  accountant.RegisterWeight("weight_1", 400);
+  accountant.RegisterWeight("weight_2", 600);
+
+  IndexedSubGraph sub_graph;
+  sub_graph.nodes.push_back(node1.Index());
+  sub_graph.nodes.push_back(node2.Index());
+  sub_graph.SetAccountant(&accountant);
+
+  sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node1));
+  sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node2));
+
+  ASSERT_TRUE(sub_graph.IsAccountingEnabled());
+  sub_graph.AccountForAllNodes();
+
+  EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000})
+      << "No shared weights: should sum all costs (400 + 600)";
+}
+
+// AccountForNode per-node and AccountForAllNodes bulk
produce same result. +TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { + auto h = SharedWeightGraph::Create(); + + // Per-node path + TestDedupAccountant acc1; + acc1.RegisterWeight("weight_W", 1000); + IndexedSubGraph sub1; + sub1.nodes.push_back(h.node_a->Index()); + sub1.nodes.push_back(h.node_b->Index()); + sub1.SetAccountant(&acc1); + sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_a)); + sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_b)); + sub1.AccountForNode(0); + sub1.AccountForNode(1); + size_t per_node = acc1.GetConsumedSizeT(); + + // Bulk path + TestDedupAccountant acc2; + acc2.RegisterWeight("weight_W", 1000); + IndexedSubGraph sub2; + sub2.nodes.push_back(h.node_a->Index()); + sub2.nodes.push_back(h.node_b->Index()); + sub2.SetAccountant(&acc2); + sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_a)); + sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_b)); + sub2.AccountForAllNodes(); + size_t bulk = acc2.GetConsumedSizeT(); + + EXPECT_EQ(per_node, bulk) + << "Per-node and bulk should produce identical results"; + EXPECT_EQ(per_node, size_t{1000}); +} + +// Cross-subgraph dedup: EP1 commits node_A, EP2 probes node_B and +// correctly sees weight_W as already accounted. 
+TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) {
+  auto h = SharedWeightGraph::Create();
+  TestDedupAccountant accountant;
+  accountant.RegisterWeight("weight_W", 1000);
+
+  // EP1 probes and commits node_A
+  IndexedSubGraph sub1;
+  sub1.nodes.push_back(h.node_a->Index());
+  sub1.SetAccountant(&accountant);
+  sub1.AppendNodeCost(accountant.ComputeResourceCount(*h.node_a));
+  sub1.AccountForNode(0);
+  EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000});
+
+  // EP2 probes node_B: weight_W already committed
+  auto cost_b = accountant.ComputeResourceCount(*h.node_b);
+  EXPECT_EQ(std::get<size_t>(cost_b), size_t{0})
+      << "weight_W was committed by EP1, should be deduped for EP2";
+
+  // EP2 commits node_B with cost 0
+  IndexedSubGraph sub2;
+  sub2.nodes.push_back(h.node_b->Index());
+  sub2.SetAccountant(&accountant);
+  sub2.AppendNodeCost(cost_b);
+  sub2.AccountForNode(0);
+
+  EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000})
+      << "Total should still be 1000 - weight_W counted once across both";
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc
index ed2b98e5280b5..656b0ef86289d 100644
--- a/onnxruntime/test/framework/session_state_test.cc
+++ b/onnxruntime/test/framework/session_state_test.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include <filesystem>
 #include
 #include
@@ -9,9 +10,11 @@
 #include "core/framework/execution_providers.h"
 #include "core/framework/graph_partitioner.h"
 #include "core/framework/kernel_registry.h"
+#include "core/framework/layering_annotations.h"
 #include "core/framework/op_kernel.h"
 #include "core/framework/bfc_arena.h"
 #include "core/framework/ep_context_options.h"
+#include "core/framework/resource_accountant.h"
 #include "core/framework/session_state.h"
 #include "core/graph/graph_utils.h"
 #include "core/graph/graph_viewer.h"
@@ -280,7 +283,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
               graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn);
         },
         sess_options.config_options,
-        DefaultLoggingManager().DefaultLogger()));
+        DefaultLoggingManager().DefaultLogger(), nullptr /*layering_index*/));

   ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));

@@ -367,7 +370,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
               cpu_allocator, debug_graph_fn);
         },
         sess_options.config_options,
-        default_logger));
+        default_logger,
+        nullptr /*layering_index*/));

   EXPECT_STATUS_OK(session_state.FinalizeSessionState(model.ModelPath(), krm));

@@ -414,9 +418,50 @@ namespace {

 using ParitionVerifierFn = std::function<void(const Graph&)>;

+// Collect unique node names from a graph and all its subgraphs
+// using the same naming scheme as the resource accountant.
+static void CollectNodeNames(const Graph& graph, std::vector<std::string>& names) {
+  for (const auto& node : graph.Nodes()) {
+    names.push_back(IResourceAccountant::MakeUniqueNodeName(node));
+    for (const auto& [_, subgraph] : node.GetAttributeNameToSubgraphMap()) {
+      CollectNodeNames(*subgraph, names);
+    }
+  }
+}
+
+// Generates a node stats file dynamically from the current graph,
+// assigning each node a fixed cost. Returns the total cost across
+// all nodes so callers can choose a threshold relative to the actual total.
+// This avoids relying on a pre-baked stats file whose node name hashes
+// become stale when graph optimizers change node input/output names.
+static void GenerateDynamicNodeStatsFile(const ORTCHAR_T* model_path,
+                                         const std::filesystem::path& output_path,
+                                         size_t& total_cost,
+                                         size_t cost_per_node = 1024) {
+  const auto& default_logger = DefaultLoggingManager().DefaultLogger();
+  std::shared_ptr<Model> model;
+  ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, default_logger));
+  Graph& graph = model->MainGraph();
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  std::vector<std::string> node_names;
+  CollectNodeNames(graph, node_names);
+
+  std::ofstream ofs(output_path);
+  ASSERT_TRUE(ofs.is_open());
+  ofs << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n";
+  for (const auto& name : node_names) {
+    ofs << name << "," << cost_per_node << ",0,0,0\n";
+  }
+  ofs.close();
+
+  total_cost = node_names.size() * cost_per_node;
+}
+
 void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path,
                                        const SessionOptions& sess_options,
-                                       const ParitionVerifierFn& verifier_fn) {
+                                       const ParitionVerifierFn& verifier_fn,
+                                       const std::string& layering_config = std::string()) {
   const auto& log_manager = DefaultLoggingManager();
   log_manager.SetDefaultLoggerSeverity(onnxruntime::logging::Severity::kVERBOSE);
   const auto& default_logger = log_manager.DefaultLogger();
@@ -431,9 +476,12 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path,
   auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
                                           concurrency::ThreadPoolType::INTRA_OP);

   ExecutionProviders execution_providers;
-  auto tmp_cpu_execution_provider = DefaultCudaExecutionProvider();
-  tmp_cpu_execution_provider->SetLogger(&default_logger);
-  ASSERT_STATUS_OK(execution_providers.Add(kCudaExecutionProvider, std::move(tmp_cpu_execution_provider)));
+  auto tmp_execution_provider = DefaultCudaExecutionProvider();
+  tmp_execution_provider->SetLogger(&default_logger);
+  ASSERT_STATUS_OK(execution_providers.Add(kCudaExecutionProvider, std::move(tmp_execution_provider)));
+  tmp_execution_provider = DefaultCpuExecutionProvider();
+  tmp_execution_provider->SetLogger(&default_logger);
+  ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_execution_provider)));

   KernelRegistryManager krm;
   ASSERT_STATUS_OK(krm.RegisterKernels(execution_providers));
@@ -445,6 +493,16 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path,
   SessionState session_state(model->MainGraph(), execution_providers, tp.get(), nullptr, dtm, edlm,
                              default_logger, profiler, sess_options);

+  LayeringIndex* layering_index = nullptr;
+  std::optional<LayeringIndex> layering_index_storage;
+  if (!layering_config.empty()) {
+    ASSERT_STATUS_OK(LayeringIndex::Create(graph, layering_config, {}, execution_providers,
+                                           default_logger, layering_index_storage));
+    if (layering_index_storage.has_value()) {
+      layering_index = &layering_index_storage.value();
+    }
+  }
+
   // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup
   auto graph_optimizer_registry = std::make_unique<GraphOptimizerRegistry>(&sess_options,
                                                                            execution_providers.Get(onnxruntime::kCpuExecutionProvider),
@@ -455,7 +513,8 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path,
   layout_transformation::DebugGraphFn debug_graph_fn;
   ASSERT_STATUS_OK(
       partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn,
-                            sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal,
+                            sess_options.config_options, default_logger, layering_index,
+                            GraphPartitioner::Mode::kNormal,
                             epctx::ModelGenOptions{}, debug_graph_fn));
@@ -484,16 +543,28 @@ TEST(SessionStateTest, TestResourceAwarePartitioning_NoLimit) {

 TEST(SessionStateTest, TestResourceAwarePartitioning_LargeLimit) {
   constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx");
-  constexpr const char* limit_setting = "10000,tiny_gpt2_beamsearch_node_stats.txt";
+  std::error_code ec;
+  const std::filesystem::path stats_path =
+      std::filesystem::temp_directory_path(ec) / "tiny_gpt2_beamsearch_dynamic_stats_large.txt";
+  ASSERT_FALSE(ec) << "temp_directory_path failed: " << ec.message();
+
+  // Generate node stats dynamically so names always match the current graph.
+  constexpr size_t cost_per_node = 1024;
+  size_t total_cost = 0;
+  GenerateDynamicNodeStatsFile(model_path, stats_path, total_cost, cost_per_node);
+  ASSERT_GT(total_cost, 0U);
+
+  // Use a limit much larger than the total cost so all nodes are assigned to CUDA.
+  size_t large_limit_kb = (total_cost * 2) / 1024 + 1;
+  std::string limit_setting = std::to_string(large_limit_kb) + "," + stats_path.string();

-  // Large limit, all nodes are still assigned
   SessionOptions sess_options;
   sess_options.enable_mem_pattern = false;
   sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
   sess_options.use_deterministic_compute = false;
   sess_options.enable_mem_reuse = false;
   ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry(
-      kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting));
+      kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting.c_str()));

   LoadWithResourceAwarePartitioning(model_path, sess_options,
                                     [](const Graph& graph) {
                                       const auto& graph_nodes = graph.Nodes();
@@ -501,20 +572,36 @@ TEST(SessionStateTest, TestResourceAwarePartitioning_LargeLimit) {
                                         EXPECT_EQ(node.GetExecutionProviderType(), kCudaExecutionProvider);
                                       }
                                     });
+
+  std::error_code remove_ec;
+  std::filesystem::remove(stats_path, remove_ec);
 }

 TEST(SessionStateTest, TestResourceAwarePartitioning_CPUOffloaded) {
   constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx");
-  constexpr const char* limit_setting = "5000,tiny_gpt2_beamsearch_node_stats.txt";
+  std::error_code ec;
+  const std::filesystem::path stats_path =
+      std::filesystem::temp_directory_path(ec) / "tiny_gpt2_beamsearch_dynamic_stats_offload.txt";
+  ASSERT_FALSE(ec) << "temp_directory_path failed: " << ec.message();
+
+  // Generate node stats dynamically so names always match the current graph.
+  constexpr size_t cost_per_node = 1024;
+  size_t total_cost = 0;
+  GenerateDynamicNodeStatsFile(model_path, stats_path, total_cost, cost_per_node);
+  ASSERT_GT(total_cost, 0U);
+
+  // Set the threshold to half the total cost so some nodes must be offloaded to CPU.
+  size_t half_limit_kb = (total_cost / 2) / 1024;
+  ASSERT_GT(half_limit_kb, 0U);
+  std::string limit_setting = std::to_string(half_limit_kb) + "," + stats_path.string();

-  // Large limit, all nodes are still assigned
   SessionOptions sess_options;
   sess_options.enable_mem_pattern = false;
   sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
   sess_options.use_deterministic_compute = false;
   sess_options.enable_mem_reuse = false;
   ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry(
-      kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting));
+      kOrtSessionOptionsResourceCudaPartitioningSettings, limit_setting.c_str()));

   LoadWithResourceAwarePartitioning(model_path, sess_options,
                                     [](const Graph& graph) {
                                       const auto& graph_nodes = graph.Nodes();
@@ -527,6 +614,38 @@ TEST(SessionStateTest, TestResourceAwarePartitioning_CPUOffloaded) {
                                       }
                                       EXPECT_TRUE(cpu_node_found);
                                     });
+
+  std::error_code remove_ec;
+  std::filesystem::remove(stats_path, remove_ec);
+}
+
+TEST(SessionStateTest, TestLayeringPartitioning) {
+  constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/layering/tiny_gpt2_beamsearch_layering.onnx");
+  constexpr const char* layering_setting =
+      "cpu(Embed,Decode);gpu(GptAttention0,GptAttention1,GptAttention2,GptAttention3,GptAttention4)";
+
+  // Set the session options for layering
+  SessionOptions sess_options;
+  sess_options.enable_mem_pattern = false;
+  sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
+  sess_options.use_deterministic_compute = false;
+  sess_options.enable_mem_reuse = false;
+  ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry(
+      kOrtSessionOptionsLayerAssignmentSettings, layering_setting));
+
+  LoadWithResourceAwarePartitioning(model_path, sess_options, [](const Graph& graph) {
+    const auto& graph_nodes = graph.Nodes();
+    for (const auto& node : graph_nodes) {
+      const std::string& name = node.Name();
+      const bool expected_on_cpu = (name.find("EmbedLayer") == 0) || (name == "LayerNorm_10") || (name == "MatMul_1165");
+
+      const std::string& ep = node.GetExecutionProviderType();
+      if (expected_on_cpu) {
+        EXPECT_EQ(ep, kCpuExecutionProvider) << "Node " << name << " expected on CPU but found on " << ep;
+      } else {
+        EXPECT_EQ(ep, kCudaExecutionProvider) << "Node " << name << " expected on CUDA but found on " << ep;
+      }
+    }
+  }, layering_setting);
 }
 #endif  // USE_CUDA
@@ -909,9 +1028,8 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) {
   OrtMemoryInfo mem_info(CPU, OrtDeviceAllocator);
   std::vector<float> float_data(1, 1);
   auto value = std::make_unique<OrtValue>();
-  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(),
-                       TensorShape(std::vector<int64_t>{1}), reinterpret_cast<void*>(float_data.data()),
-                       mem_info, *value);
+  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(), TensorShape(std::vector<int64_t>{1}),
+                       float_data.data(), mem_info, *value);

   ASSERT_STATUS_OK(sess_options.AddInitializer("node_0_input_1", value.get()));

@@ -1379,6 +1497,5 @@ INSTANTIATE_TEST_SUITE_P(SessionStateTests,
                                            PrepackingTestParam{true, false},
                                            PrepackingTestParam{true, true}));
 #endif
-
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc
index 8c5859823ac16..572fb6992ec76 100644
--- a/onnxruntime/test/framework/tensorutils_test.cc
+++ b/onnxruntime/test/framework/tensorutils_test.cc
@@ -728,6 +728,64 @@ TEST_F(PathValidationTest, ValidateExternalDataPathEmptyModelPathWithSymlinkOuts
   EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("escapes working directory"));
 }

+TEST(TensorProtoUtilsTest, GetNodeProtoLayeringAnnotation) {
+  // Case 1: Annotation exists
+  {
+    ONNX_NAMESPACE::NodeProto node_proto;
+    node_proto.set_name("test_node");
+    auto* prop = node_proto.add_metadata_props();
+    prop->set_key(utils::kNodeProtoLayerAnnotation);
+    prop->set_value("foo");
+
+    auto result = utils::GetNodeProtoLayeringAnnotation(node_proto);
+    EXPECT_TRUE(result.has_value());
+    EXPECT_EQ(result.value(), "foo");
+  }
+
+  // Case 2: Annotation missing (empty metadata_props)
+  {
+    ONNX_NAMESPACE::NodeProto node_proto;
+    node_proto.set_name("test_node");
+
+    auto result = utils::GetNodeProtoLayeringAnnotation(node_proto);
+    EXPECT_FALSE(result.has_value());
+  }
+
+  // Case 3: Other metadata exists, but not the annotation
+  {
+    ONNX_NAMESPACE::NodeProto node_proto;
+    node_proto.set_name("test_node");
+    auto* prop = node_proto.add_metadata_props();
+    prop->set_key("some_other_key");
+    prop->set_value("some_value");
+
+    auto result = utils::GetNodeProtoLayeringAnnotation(node_proto);
+    EXPECT_FALSE(result.has_value());
+  }
+
+  // Case 4: Multiple metadata entries, including the annotation
+  {
+    ONNX_NAMESPACE::NodeProto node_proto;
+    node_proto.set_name("test_node");
+
+    auto* prop1 = node_proto.add_metadata_props();
+    prop1->set_key("some_other_key");
+    prop1->set_value("some_value");
+
+    auto* prop2 = node_proto.add_metadata_props();
+    prop2->set_key(utils::kNodeProtoLayerAnnotation);
+    prop2->set_value("bar");
+
+    auto* prop3 = node_proto.add_metadata_props();
+    prop3->set_key("yet_another_key");
+    prop3->set_value("baz");
+
+    auto result = utils::GetNodeProtoLayeringAnnotation(node_proto);
+    EXPECT_TRUE(result.has_value());
+    EXPECT_EQ(result.value(), "bar");
+  }
+}
+
 // Tests for ValidateEmbeddedTensorProtoDataSizeAndShape and embedded initializer size limits

 TEST(TensorProtoDataSizeShapeValidationTest, ValidTensorProtoWithRawData) {
diff --git a/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.onnx b/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.onnx
new file mode 100644
index 0000000000000..57efb4ebe11a3
Binary files /dev/null and b/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.onnx differ
diff --git a/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.txt b/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.txt
new file mode 100644
index 0000000000000..5affbde73e5b3
--- /dev/null
+++ b/onnxruntime/test/testdata/layering/tiny_gpt2_beamsearch_layering.txt
@@ -0,0 +1,55 @@
+Embed:EmbedLayer
+GptAttention0:GptAttention_0
+GptAttention0:Add_295
+GptAttention0:LayerNorm_1
+GptAttention0:FullyConnect_MatMul_0
+GptAttention0:FastGelu_AddBias_0
+GptAttention0:FullyConnect_MatMul_1
+GptAttention0:FullyConnect_Add_1
+GptAttention0:Add_360
+GptAttention1:LayerNorm_2
+GptAttention1:GptAttention_1
+GptAttention1:Add_492
+GptAttention1:FullyConnect_MatMul_2
+GptAttention1:FastGelu_AddBias_1
+GptAttention1:FullyConnect_MatMul_3
+GptAttention1:FullyConnect_Add_3
+GptAttention1:Add_557
+GptAttention2:LayerNorm_4
+GptAttention2:GptAttention_2
+GptAttention2:Add_689
+GptAttention2:LayerNorm_5
+GptAttention2:FullyConnect_MatMul_4
+GptAttention2:FastGelu_AddBias_2
+GptAttention2:FullyConnect_MatMul_5
+GptAttention2:FullyConnect_Add_5
+GptAttention2:Add_754
+GptAttention3:LayerNorm_6
+GptAttention3:GptAttention_3
+GptAttention3:Add_886
+GptAttention3:LayerNorm_7
+GptAttention3:FullyConnect_MatMul_6
+GptAttention3:FastGelu_AddBias_3
+GptAttention3:FullyConnect_MatMul_7
+GptAttention3:FullyConnect_Add_7
+GptAttention3:Add_951
+GptAttention4:LayerNorm_8
+GptAttention4:GptAttention_4
+GptAttention4:Add_1083
+GptAttention4:LayerNorm_9
+GptAttention4:FullyConnect_MatMul_8
+GptAttention4:FastGelu_AddBias_4
+GptAttention4:FullyConnect_MatMul_9
+GptAttention4:FullyConnect_Add_9
+GptAttention4:Add_1148
+Decode:LayerNorm_10
+Decode:MatMul_1165
+
+
+
+
+
+
+
+