diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 96e513c8a7bc9..36ba8db9bdc75 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1252,7 +1252,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnx_test_runner_common onnxruntime_test_utils onnxruntime_common
onnxruntime onnxruntime_flatbuffers onnx_test_data_proto
${onnxruntime_EXTERNAL_LIBRARIES}
- ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
+ absl::flags absl::flags_parse ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
if(NOT WIN32)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
@@ -1272,7 +1272,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
endif()
else()
- target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs})
+ target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs})
endif()
set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
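The two link-line changes above replace the bundled wide-char getopt library with Abseil's flag library for onnxruntime_perf_test. As a minimal, standalone sketch of the absl::flags pattern this enables (the flag names below are illustrative only, not the perf test's actual options):

#include <iostream>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"

// Illustrative flags only; the real onnxruntime_perf_test options differ.
ABSL_FLAG(std::string, model_path, "", "Path to the ONNX model to benchmark.");
ABSL_FLAG(int, repeat, 1, "Number of inference runs to time.");

int main(int argc, char* argv[]) {
  // Parses known flags and returns the remaining positional arguments (argv[0] first).
  std::vector<char*> positional = absl::ParseCommandLine(argc, argv);
  std::cout << "model: " << absl::GetFlag(FLAGS_model_path)
            << " repeat: " << absl::GetFlag(FLAGS_repeat)
            << " positional args: " << positional.size() - 1 << "\n";
  return 0;
}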
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 7df3368ad4e0b..1bb7f219c9a45 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -179,7 +179,12 @@ class IExecutionProvider {
/**
Get the device id of current execution provider
*/
- virtual int GetDeviceId() const { return default_device_.Id(); };
+ virtual int GetDeviceId() const { return default_device_.Id(); }
+
+ /**
+ * Get the OrtDevice the execution provider was registered with.
+ */
+ const OrtDevice& GetDevice() const { return default_device_; }
/**
Get execution provider's configuration options.
diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h
index 536d641b4eef9..fea970b84fd84 100644
--- a/include/onnxruntime/core/framework/ortdevice.h
+++ b/include/onnxruntime/core/framework/ortdevice.h
@@ -150,6 +150,13 @@ struct OrtDevice {
return alignment < other.alignment;
}
+ bool EqualIgnoringAlignment(const OrtDevice& other) const {
+ return device_type == other.device_type &&
+ memory_type == other.memory_type &&
+ vendor_id == other.vendor_id &&
+ device_id == other.device_id;
+ }
+
private:
// Device type.
int32_t device_type : 8;
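For readers skimming the header change above: EqualIgnoringAlignment compares only the device identity fields, while the existing ordering also considers alignment. A self-contained sketch with toy fields (simplified types, not the real OrtDevice bitfields) illustrates the intended semantics:

#include <cstdint>
#include <iostream>

// Toy stand-in for OrtDevice; field widths and types are simplified.
struct Device {
  int32_t device_type;
  int32_t memory_type;
  uint32_t vendor_id;
  int32_t device_id;
  size_t alignment;

  bool EqualIgnoringAlignment(const Device& other) const {
    return device_type == other.device_type &&
           memory_type == other.memory_type &&
           vendor_id == other.vendor_id &&
           device_id == other.device_id;
  }
};

int main() {
  Device a{/*GPU*/ 1, /*DEFAULT*/ 0, /*vendor*/ 0x10de, /*id*/ 0, /*alignment*/ 256};
  Device b{1, 0, 0x10de, 0, /*alignment unknown*/ 0};
  std::cout << std::boolalpha << a.EqualIgnoringAlignment(b) << "\n";  // true: same device
  return 0;
}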
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index bd694f7b3b23c..866892979b749 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1220,7 +1220,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
#endif
#if !defined(ORT_MINIMAL_BUILD)
- /** Gets the GraphProto representation of this Graph only. */
+ /** Gets the GraphProto representation of this Graph only.
+ * This does not remove in-memory external data references from graph initializers.
+ * Use the const overload ToGraphProto() const to get a GraphProto that can be serialized externally.
+ */
const ONNX_NAMESPACE::GraphProto& ToGraphProto();
///
@@ -1439,6 +1442,27 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return Resolve(default_options);
}
+ /// <summary>
+ /// This function converts all the graph TensorProto initializers into OrtValues
+ /// and creates an in-memory external data reference for each OrtValue.
+ /// </summary>
+ ///
+ Status ConvertInitializersIntoOrtValues();
+
+ /**
+ * @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto.
+ *
+ * This function converts specified TensorProto initializers in the graph into OrtValues and
+ * creates in-memory external data references for each OrtValue. It then updates the provided
+ * GraphProto with the modified initializers.
+ *
+ * @param iterators Span of iterators pointing to the initializers, in the order they should be processed
+ * @param output_graph_proto The GraphProto to be updated with the modified initializers
+ * @return Status indicating success or any errors that occurred during conversion
+ */
+ Status RegenerateInitializersAndReplaceInMemory(gsl::span<const InitializedTensorSet::const_iterator> iterators,
+ ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
+
const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept {
return outer_scope_node_arg_names_;
}
@@ -1595,20 +1619,25 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
/// This function is used by ToGraphProto() to ensure in-memory external data references
/// don't leak externally since they are non-standard.
///
- /// It handles two scenarios:
- /// - When GraphSynchronizationNeeded() is false: GraphProto is simply copied
+ /// It is used when GraphSynchronizationNeeded() is false: GraphProto is simply copied
/// from graph_proto_ by ToGraphProto(). This copy includes both main graph
/// and subgraph initializers. This function examines all initializers
/// and inlines any in-memory data references.
- /// - When GraphSynchronizationNeeded() is true: ToGraphProto() generates a new GraphProto
- /// using ToGraphProtoInternal(). This doesn't transfer main graph initializers, which are
- /// copied and inlined by ToGraphProto() itself. This function processes only the subgraph initializers
- /// as needed.
///
/// <param name="output_graph_proto">The GraphProto to process</param>
- /// <param name="process_main">Whether to process the main graph initializers</param>
- /// <returns>Status indicating success or failure</returns>
- Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto, bool process_main) const;
+ /// <returns>Status indicating success or failure</returns>
+ Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
+
+ /// <summary>
+ /// This function replaces all of the initializers within output_graph_proto
+ /// from this Graph instance. All in-memory initializers are regenerated and inlined.
+ /// This is necessary even if graph_proto_ is already up to date, because initializers() may
+ /// contain obsolete initializers that are no longer in use after optimizations and hold stale
+ /// references to OrtValues that may no longer exist (since we append rather than replace).
+ /// </summary>
+ /// <param name="output_graph_proto">Destination GraphProto to receive the updated initializers.</param>
+ /// <returns>Status indicating success or failure.</returns>
+ Status RegenerateInitializersAndReplaceInMemory(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
///
/// This function traverses the graph bottom up and externalizes
diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h
index 306f81df38e48..89467f5238fa9 100644
--- a/include/onnxruntime/core/session/environment.h
+++ b/include/onnxruntime/core/session/environment.h
@@ -106,6 +106,15 @@ class Environment {
return shared_allocators_;
}
+ /**
+ * Returns an AllocatorPtr for a shared IAllocator-based allocator if it matches the memory info.
+ * The OrtMemoryInfo name and whether it is an arena or device allocator are ignored in the lookup, as is the
+ * alignment.
+ * The user calling this function is not expected to know the alignment, and we expect the allocator instance to be
+ * created with a valid alignment for the device.
+ */
+ AllocatorPtr GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const;
+
/**
* Removes registered allocator that was previously registered for sharing between multiple sessions.
*/
@@ -171,7 +180,7 @@ class Environment {
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
- std::mutex mutex_;
+ mutable std::mutex mutex_;
// shared allocators from various sources.
// CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT
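To make the lookup contract above concrete, here is a standalone sketch (toy types and names, not the Environment implementation) of matching a registered shared allocator on device identity only, ignoring the OrtMemoryInfo name, arena flag, and alignment:

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for OrtMemoryInfo / IAllocator.
struct MemoryInfo {
  std::string name;       // ignored by the lookup
  int device_type = 0;
  int device_id = 0;
  bool is_arena = false;  // ignored by the lookup
  size_t alignment = 0;   // ignored by the lookup
};

struct Allocator {
  MemoryInfo info;
};

using AllocatorPtr = std::shared_ptr<Allocator>;

AllocatorPtr GetRegisteredSharedAllocator(const std::vector<AllocatorPtr>& shared_allocators,
                                          const MemoryInfo& mem_info) {
  auto it = std::find_if(shared_allocators.begin(), shared_allocators.end(),
                         [&](const AllocatorPtr& a) {
                           return a->info.device_type == mem_info.device_type &&
                                  a->info.device_id == mem_info.device_id;
                         });
  return it != shared_allocators.end() ? *it : nullptr;
}

int main() {
  auto registered = std::make_shared<Allocator>();
  registered->info = MemoryInfo{"CudaArena", 1, 0, true, 256};
  std::vector<AllocatorPtr> shared{registered};

  MemoryInfo query{"Cuda", 1, 0, false, 0};  // different name/arena/alignment, same device
  std::cout << (GetRegisteredSharedAllocator(shared, query) != nullptr) << "\n";  // 1
  return 0;
}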
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 8ae7535da4413..e4f8cd6df678e 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -666,12 +666,16 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
// Set attributes.
proto.clear_attribute();
- for (const auto& attribute : attributes_) {
+ for (const auto& [name, attribute] : attributes_) {
const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
- *attr = attribute.second; // copy
- if (update_subgraphs && attr->has_g()) {
+ *attr = attribute; // copy
+ if (update_subgraphs && utils::HasGraph(*attr)) {
+ auto find_hit = attr_to_subgraph_map_.find(name);
+ // Force ToGraphProto() const to be called so
+ // that any in-memory TensorProto initializers go back to being inlined
+ const Graph& subgraph = *find_hit->second;
attr->clear_g();
- *attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto();
+ *attr->mutable_g() = subgraph.ToGraphProto();
}
}
@@ -3381,7 +3385,12 @@ Status Graph::Resolve(const ResolveOptions& options) {
return Status::OK(); };
- ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func));
+ return ForThisAndAllSubgraphs(all_subgraphs, finalize_func);
+}
+
+Status Graph::ConvertInitializersIntoOrtValues() {
+ std::vector<Graph*> all_subgraphs;
+ FindAllSubgraphs(all_subgraphs);
auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
@@ -4308,11 +4317,47 @@ Status InlineOrCopyInitializer(const Graph& src_graph, const ONNX_NAMESPACE::Ten
}
return Status::OK();
}
-
} // namespace
-Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto,
- bool process_main) const {
+Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span<const InitializedTensorSet::const_iterator> iterators,
+ ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
+ auto& mutable_initializers = *output_graph_proto.mutable_initializer();
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+ output_graph_proto.clear_sparse_initializer();
+
+ const auto& model_path = ModelPath();
+ const bool has_sparse_initializers = !sparse_tensor_names_.empty();
+ const auto sparse_end = sparse_tensor_names_.end();
+
+ for (const auto& iter : iterators) {
+ const auto& [name, tensor_proto] = *iter;
+ const auto& initializer = *tensor_proto;
+ if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(name)) {
+ ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
+ *mutable_initializers.Add()));
+ } else {
+ auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
+ if (utils::HasExternalDataInMemory(initializer)) {
+ ONNX_NAMESPACE::TensorProto tensor_proto_inlined;
+ ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
+ tensor_proto_inlined));
+ ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto_inlined, model_path, sparse_initializer));
+ } else {
+ ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
+ }
+ }
+ }
+#else
+ for (const auto& iter : iterators) {
+ const auto& [name, tensor_proto] = *iter;
+ ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, *tensor_proto, *mutable_initializers.Add()));
+ }
+#endif
+ return Status::OK();
+}
+
+Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
for (const auto& node : Nodes()) {
if (node.ContainsSubgraph()) {
// Let's find this node in the output_graph_proto
@@ -4343,103 +4388,48 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
"Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ",
node.Name(), " while attempting to recurse into it.");
auto& result_subgraph = *sub_hit->mutable_g();
- ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph, process_main));
+ ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph));
}
}
}
- // When graph_proto is copied from graph_proto, initializers already present in the main graph
- if (parent_graph_ != nullptr || process_main) {
-#if !defined(DISABLE_SPARSE_TENSORS)
- auto* mutable_initializers = output_graph_proto.mutable_initializer();
- const auto& model_path = ModelPath();
- const bool has_sparse_initializers = !sparse_tensor_names_.empty();
- const auto sparse_end = sparse_tensor_names_.end();
-
- // We want to make sure that sparse initializers do not appear
- // as dense duplicates within the initializers list.
- std::optional<InlinedHashSet<std::string>> initializer_to_remove;
- if (has_sparse_initializers) {
- // We need to remove the dense initializers that are sparse tensors
- initializer_to_remove.emplace();
- }
-
- for (auto first = mutable_initializers->begin(), end = mutable_initializers->end(); first != end; ++first) {
- auto& initializer = *first;
- if (utils::HasExternalDataInMemory(initializer)) {
- // If the initializer has external data in memory, we need to inline it.
- ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
- }
- if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find(initializer.name())) {
- auto& sparse_initializer = *output_graph_proto.add_sparse_initializer();
- ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
- initializer_to_remove->insert(initializer.name());
- }
- }
-
- // erase/remove dense initializers that are sparse tensors so no duplicates are present
- if (initializer_to_remove && !initializer_to_remove->empty()) {
- mutable_initializers->erase(std::remove_if(
- mutable_initializers->begin(), mutable_initializers->end(),
- [&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) {
- return initializer_to_remove->count(initializer.name()) > 0;
- }),
- mutable_initializers->end());
- }
-#else
- for (auto& initializer : *output_graph_proto.mutable_initializer()) {
- if (utils::HasExternalDataInMemory(initializer)) {
- // If the initializer has external data in memory, we need to inline it.
- ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer));
- }
+ // Collect iterators for weights that are still present in the name_to_initial_tensor_ map,
+ // preserving their order. This ordering is needed for tests.
+ InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
+ initializers_to_process.reserve(name_to_initial_tensor_.size());
+ for (const auto& tensor_proto : output_graph_proto.initializer()) {
+ auto hit = name_to_initial_tensor_.find(tensor_proto.name());
+ if (hit != name_to_initial_tensor_.end()) {
+ initializers_to_process.push_back(hit);
}
-#endif
}
- return Status::OK();
+
+ output_graph_proto.clear_initializer();
+ return RegenerateInitializersAndReplaceInMemory(initializers_to_process, output_graph_proto);
}
ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
GraphProto result;
if (!GraphProtoSyncNeeded()) {
result = *graph_proto_;
- ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ true));
+ ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result));
} else {
+ // Recursion is handled via Node::ToProto() const -> Graph::ToGraphProto() const (this method)
+ // so below we handle this graph only.
ToGraphProtoInternal(result);
- ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ false));
-
- // Add initializers to parent graph by copy converting them from graph_proto_
- // ToGraphProtoInternal() does not copy initializers for the main graph
- auto* mutable_initializers = result.mutable_initializer();
-
-#if !defined(DISABLE_SPARSE_TENSORS)
- const auto& model_path = ModelPath();
- const bool has_sparse_initializers = !sparse_tensor_names_.empty();
- const auto sparse_end = sparse_tensor_names_.end();
-
- for (const auto& initializer : graph_proto_->initializer()) {
- if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) {
- ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
- *mutable_initializers->Add()));
- } else {
- auto& sparse_initializer = *result.add_sparse_initializer();
- if (utils::HasExternalDataInMemory(initializer)) {
- ONNX_NAMESPACE::TensorProto tensor_proto;
- ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer,
- tensor_proto));
- ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto, model_path, sparse_initializer));
- } else {
- ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer));
- }
+ InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
+ initializers_to_process.reserve(name_to_initial_tensor_.size());
+ for (const auto& tensor_proto : graph_proto_->initializer()) {
+ auto hit = name_to_initial_tensor_.find(tensor_proto.name());
+ if (hit != name_to_initial_tensor_.end()) {
+ initializers_to_process.push_back(hit);
}
}
-#else
- for (const auto& initializer : graph_proto_->initializer()) {
- ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, *mutable_initializers->Add()));
- }
-#endif
- }
+ ORT_THROW_IF_ERROR(RegenerateInitializersAndReplaceInMemory(initializers_to_process,
+ result));
+ }
return result;
}
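The rewritten ToGraphProto() above collects map iterators in the order the initializers appear in the serialized proto, skips entries that optimizations have already removed, and rebuilds the output from that filtered list. A toy, self-contained sketch of that pattern, using standard containers instead of the Graph internals:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  using InitializerMap = std::unordered_map<std::string, int>;
  InitializerMap live = {{"w1", 1}, {"w3", 3}};               // initializers still in use
  std::vector<std::string> proto_order = {"w1", "w2", "w3"};  // order stored in the proto

  // Collect map iterators in proto order; "w2" was optimized away and is skipped.
  std::vector<InitializerMap::const_iterator> to_process;
  to_process.reserve(live.size());
  for (const auto& name : proto_order) {
    auto hit = live.find(name);
    if (hit != live.end()) {
      to_process.push_back(hit);
    }
  }

  // Regenerate the output list from the filtered, ordered iterators.
  for (auto it : to_process) {
    std::cout << it->first << " = " << it->second << "\n";  // w1 then w3
  }
  return 0;
}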
@@ -5235,23 +5225,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
tensor_proto.set_name(std::string(new_name.value()));
}
- // In the constant node, we won't have symbolic dims.
- const auto tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto);
- auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType();
- const size_t size_in_bytes = Tensor::CalculateTensorStorageSize(ml_data, tensor_shape);
-
- if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
- OrtValue ort_value;
- ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), tensor_proto,
- CPUAllocator::DefaultInstance(), ort_value));
-
- constexpr const bool use_tensor_buffer_true = true;
- auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
- use_tensor_buffer_true);
- ORT_RETURN_IF_ERROR(AddInitializedOrtValue(tensor_proto_to_add, ort_value));
- } else {
- AddInitializedTensor(tensor_proto);
- }
+ AddInitializedTensor(tensor_proto);
if (GetNodeArg(tensor_proto.name()) == nullptr) {
TypeProto t{utils::TypeProtoFromTensorProto(tensor_proto)};
diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc
index 616bc1257676f..3f9b58f71bd23 100644
--- a/onnxruntime/core/optimizer/attention_fusion.cc
+++ b/onnxruntime/core/optimizer/attention_fusion.cc
@@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}
- return graph_utils::AddInitializerWithExternalData(graph, initializer);
+ return graph_utils::AddInitializer(graph, initializer);
}
static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
diff --git a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc
index a98d0ea6f978b..86a7a4d6afbf8 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc
@@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph,
"total_count: ", total_count, " values.size(): ", values.size());
utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t));
- return &graph_utils::AddInitializerWithExternalData(graph, const_tensor);
+ return &graph_utils::AddInitializer(graph, const_tensor);
}
NodeArg* InsertNodesForValidIndices(Graph& graph,
diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc
index 3d838d8aacfbb..16e8955cb4486 100644
--- a/onnxruntime/core/optimizer/constant_folding.cc
+++ b/onnxruntime/core/optimizer/constant_folding.cc
@@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
ONNX_NAMESPACE::TensorShapeProto result_shape;
result_shape.add_dim()->set_dim_value(clamped_slice_length);
constant_arg_out->SetShape(result_shape);
- graph_utils::AddInitializerWithExternalData(graph, shape_constant);
+ graph_utils::AddInitializer(graph, shape_constant);
}
return is_concrete_shape; // convert to constant if this is true
@@ -317,11 +317,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph.
auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx];
const Tensor& out_tensor = ort_value.Get<Tensor>();
- constexpr const bool use_tensor_buffer_true = true;
+ constexpr const bool use_tensor_buffer_false = false;
ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(
out_tensor,
constant_arg_out->Name(),
- use_tensor_buffer_true);
+ use_tensor_buffer_false);
ONNX_NAMESPACE::TensorShapeProto result_shape;
for (auto& dim : out_tensor.Shape().GetDims()) {
@@ -329,12 +329,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
constant_arg_out->SetShape(result_shape);
- // The data is too small and has been inlined.
- if (!utils::HasExternalData(out_tensorproto)) {
- ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue()));
- } else {
- ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value));
- }
+ graph.AddInitializedTensor(out_tensorproto);
}
}
}
diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc
index c349adfccce53..6478fa7d29d5b 100644
--- a/onnxruntime/core/optimizer/conv_add_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_add_fusion.cc
@@ -79,7 +79,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name);
new_conv_B_tensor_proto.set_name(new_name);
- NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
+ NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);
} else {
@@ -94,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name());
new_conv_B_tensor_proto.set_name(new_name);
- NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
+ NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
graph_utils::AddNodeInput(node, 2, new_add_B_node_arg);
}
diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc
index 8bf5420baddde..a14639631d7a1 100644
--- a/onnxruntime/core/optimizer/conv_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc
@@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
new_conv_W_tensor_proto.set_name(new_W_name);
new_conv_B_tensor_proto.set_name(new_B_name);
- NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto);
+ NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto);
graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg);
- auto& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
+ auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
if (conv_inputs.size() == 3) {
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);
diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc
index dc50a150537f7..e91a00729e9db 100644
--- a/onnxruntime/core/optimizer/conv_mul_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc
@@ -90,7 +90,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
new_conv_W_tensor_proto.set_name(new_W_name);
// Replace initializers of conv node
- NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto);
+ NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto);
graph_utils::ReplaceNodeInput(conv_node, 1, new_conv_W_node_arg);
if (is_3d) {
@@ -100,7 +100,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
auto new_B_name = graph.GenerateNodeArgName("ConvMulFusion_Mul_B_" + mul_B_tensor_proto->name());
new_conv_B_tensor_proto.set_name(new_B_name);
- NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto);
+ NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
graph_utils::ReplaceNodeInput(conv_node, 2, new_conv_B_node_arg);
}
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
index 7f214e656e0ab..96f75f07e32e1 100644
--- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
@@ -53,7 +53,7 @@ static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index,
auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
new_input_tensor.set_name(new_name);
new_input_tensor.add_dims(1);
- NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor);
+ NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
graph_utils::ReplaceNodeInput(node, index, new_input);
}
diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc
index ad25f95ac1186..f8fd807084d38 100644
--- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc
@@ -474,7 +474,7 @@ static NodeArg* ExtractEmbedding(Graph& graph,
utils::SetRawDataInTensorProto(initializer, data, gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}
- NodeArg& node_arg = graph_utils::AddInitializerWithExternalData(graph, initializer);
+ NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer);
modified = true;
return &node_arg;
}
diff --git a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc
index 388ab14dd51fe..e604c688ee033 100644
--- a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc
+++ b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc
@@ -137,12 +137,8 @@ static void FuseInitializerWithNode(Graph& graph,
graph.RemoveEdge(node.Index(), next_node.Index(), 0, static_cast<int>(next_node_arg_index));
// Add the new converted Tensor in next node as initializer potentially with external data
- ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get<Tensor>(), new_arg_name, true);
- if (!utils::HasExternalData(dst_tensor)) {
- new_data = OrtValue(); // Data is inline
- }
-
- auto& new_arg = graph_utils::AddInitializerWithExternalData(graph, dst_tensor, std::move(new_data));
+ ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get<Tensor>(), new_arg_name, false);
+ auto& new_arg = graph_utils::AddInitializer(graph, dst_tensor);
graph_utils::ReplaceNodeInput(next_node, static_cast<int>(next_node_arg_index), new_arg);
}
diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc
index 3cd06350df95d..bd730683a4c91 100644
--- a/onnxruntime/core/optimizer/gather_fusion.cc
+++ b/onnxruntime/core/optimizer/gather_fusion.cc
@@ -256,7 +256,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
axes_initializer_proto.add_dims(static_cast<int64_t>(1));
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
axes_initializer_proto.add_int64_data(axis);
- NodeArg* axes_arg = &graph_utils::AddInitializerWithExternalData(graph, axes_initializer_proto);
+ NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes",
{split_output_arg, axes_arg}, {original_output_arg});
@@ -272,7 +272,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
- NodeArg* split_initializer_arg = &graph_utils::AddInitializerWithExternalData(graph, split_initializer_proto);
+ NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
const auto split_node_name = graph.GenerateNodeName(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion");
Node& split_node = graph.AddNode(split_node_name, "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
@@ -359,7 +359,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
unsqueeze_axes_initializer_proto.add_dims(static_cast<int64_t>(1));
unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
unsqueeze_axes_initializer_proto.add_int64_data(static_cast<int64_t>(0));
- NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, unsqueeze_axes_initializer_proto);
+ NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto);
for (size_t i = 0; i < range_input_defs.size(); ++i) {
Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze",
@@ -386,7 +386,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
} else {
slice_axes_initializer_proto.add_int32_data(static_cast<int32_t>(axis));
}
- NodeArg* slice_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, slice_axes_initializer_proto);
+ NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto);
Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes",
{gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1],
slice_axes_arg, unsqueeze_outputs[2]},
diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc
index 761fe1854274e..fed72db71332a 100644
--- a/onnxruntime/core/optimizer/matmul_add_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -194,7 +194,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
shape_initializer_proto.add_dims(static_cast<int64_t>(shape.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape.size() * sizeof(int64_t));
- NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto);
+ NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
ONNX_NAMESPACE::TypeProto new_arg_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type());
diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 725cb3fc33f04..367fb42d7928d 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -212,14 +212,14 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
matmul_b.ToProto(new_gemm_b_tensor);
const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name());
new_gemm_b_tensor.set_name(new_gemm_b_name);
- NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_b_tensor);
+ NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor);
// create bias tensorProto for new Gemm node from initializer.
ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor;
bias.ToProto(new_gemm_bias_tensor);
const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias");
new_gemm_bias_tensor.set_name(new_gemm_bias_name);
- NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_bias_tensor);
+ NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor);
Node& gemm_node = graph.AddNode(
graph.GenerateNodeArgName("MatMulBnFusion_Gemm"),
diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc
index 335209dbfadaf..f094a48e10c33 100644
--- a/onnxruntime/core/optimizer/nchwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -437,7 +437,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]);
}
- nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto);
+ nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto);
filters_map->emplace(input_defs[1], nchwc_conv_W_arg);
}
@@ -464,7 +464,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
nchwc_conv_B_tensor_proto.add_dims(nchwc_output_channels);
- nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto);
+ nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto);
aligned_biases_.emplace(input_defs[2], nchwc_conv_B_arg);
}
}
@@ -580,7 +580,7 @@ Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg,
}
shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims);
- shape_arg = &graph_utils::AddInitializerWithExternalData(graph_, shape_tensor_proto);
+ shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto);
}
Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"),
@@ -892,7 +892,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
nchwc_conv_W_tensor_proto.add_dims(1);
nchwc_conv_W_tensor_proto.add_dims(1);
- auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto);
+ auto* nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto);
std::copy_n(bn_B.data(), channels, padded_buffer.data());
@@ -903,7 +903,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
gsl::narrow<size_t>(nchwc_channels) * sizeof(float));
nchwc_conv_B_tensor_proto.add_dims(nchwc_channels);
- auto* nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto);
+ auto* nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto);
// Create the replacement node.
std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_bn_nchwc");
diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
index 42cd31b5bd7b4..42d27de632b91 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
@@ -130,22 +130,22 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log
weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8");
weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims());
utils::SetRawDataInTensorProto(weights_proto_u8, w_temp.data(), static_cast(w_temp.size()));
- input_defs[w_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8);
+ input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8);
ONNX_NAMESPACE::TensorProto weight_zp_proto_u8;
QDQ::Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true);
- input_defs[w_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8);
+ input_defs[w_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8);
ONNX_NAMESPACE::TensorProto r_proto_u8;
r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8");
r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims());
utils::SetRawDataInTensorProto(r_proto_u8, r_temp.data(), static_cast(r_temp.size()));
- input_defs[r_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_proto_u8);
+ input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8);
ONNX_NAMESPACE::TensorProto r_zp_proto_u8;
QDQ::Int8TensorProto2Uint8(r_zp_tensor_proto, r_zp_proto_u8, graph, true);
- input_defs[r_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_zp_proto_u8);
+ input_defs[r_zp_idx] = &graph_utils::AddInitializer(graph, r_zp_proto_u8);
return true;
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc
index 98c818b0c761b..828165e99d840 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc
@@ -61,7 +61,7 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) {
zp_tensor_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
zp_tensor_proto_u8.set_name(graph.GenerateNodeArgName("qdq_s8_to_u8_zp_conversion"));
utils::SetRawDataInTensorProto(zp_tensor_proto_u8, &q_zp_value, sizeof(uint8_t));
- NodeArg* zp_u8_arg = &graph_utils::AddInitializerWithExternalData(graph, zp_tensor_proto_u8);
+ NodeArg* zp_u8_arg = &graph_utils::AddInitializer(graph, zp_tensor_proto_u8);
auto q_output_node_arg_name = graph.GenerateNodeArgName("qdq_s8_to_u8_quant");
NodeArg* q_output_arg = &graph.GetOrCreateNodeArg(q_output_node_arg_name, nullptr);
diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc
index 616144c0ccde0..f094f3c199f2a 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc
@@ -43,12 +43,12 @@ bool ConvertS8WeightToU8(Graph& graph, Node& op_node,
// The weights fits into S7, overflow is not a problem, no need to convert to U8
return false;
}
- input_defs[weights_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8);
+ input_defs[weights_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8);
// Convert weight zero point to uint8
ONNX_NAMESPACE::TensorProto weight_zp_proto_u8;
Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true);
- input_defs[weight_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8);
+ input_defs[weight_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8);
return true;
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index dce69e2913582..34d7ba3c79775 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -439,23 +439,23 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
}
}
- auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true);
- auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true);
+ auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, false);
+ auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, false);
std::optional<ONNX_NAMESPACE::TensorProto> zp_T_tp;
if (zp_dst) {
- zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true));
+ zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, false));
}
auto& input_defs = replacement_node.MutableInputDefs();
- input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, weight_T_tp, std::move(weight_dst)));
+ input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);
- input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, scale_T_tp, std::move(scale_dst)));
+ input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);
if (zp_T_tp) {
- input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, zp_T_tp.value(), std::move(*zp_dst)));
+ input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value()));
replacement_node.MutableInputArgsCount().push_back(1);
}
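The switch from use_tensor_buffer=true to false above means the converted weights are copied into the TensorProto rather than referenced as in-memory external data. A toy illustration of the two representations (simplified structs, not the ONNX or ORT types):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Simplified stand-in for a TensorProto that can either own a copy of the data
// or hold an in-memory "external data" reference to an existing buffer.
struct ToyTensorProto {
  std::vector<uint8_t> raw_data;       // inlined copy (use_tensor_buffer = false)
  const void* external_ptr = nullptr;  // in-memory reference (use_tensor_buffer = true)
  size_t external_size = 0;
};

ToyTensorProto ToProto(const void* data, size_t size_in_bytes, bool use_tensor_buffer) {
  ToyTensorProto proto;
  if (use_tensor_buffer) {
    // The source buffer must stay alive for as long as the proto is used.
    proto.external_ptr = data;
    proto.external_size = size_in_bytes;
  } else {
    // Self-contained: safe to serialize or to outlive the source tensor.
    proto.raw_data.resize(size_in_bytes);
    std::memcpy(proto.raw_data.data(), data, size_in_bytes);
  }
  return proto;
}

int main() {
  float weights[4] = {1.f, 2.f, 3.f, 4.f};
  auto copied = ToProto(weights, sizeof(weights), /*use_tensor_buffer*/ false);
  auto referenced = ToProto(weights, sizeof(weights), /*use_tensor_buffer*/ true);
  std::cout << copied.raw_data.size() << " bytes copied; reference points at source: "
            << (referenced.external_ptr == weights) << "\n";
  return 0;
}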
diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
index aa6f9c5409de7..8caa67f266266 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
@@ -131,14 +131,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph
weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale"));
weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
weight_scale_proto.mutable_float_data()->Add(scale);
- weight_scale_arg = &graph_utils::AddInitializerWithExternalData(graph, weight_scale_proto);
+ weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto);
// Weight zero point initializer.
ONNX_NAMESPACE::TensorProto weight_zp_proto;
weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp"));
weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
weight_zp_proto.mutable_int32_data()->Add(static_cast<int32_t>(zp));
- NodeArg& weight_zp_arg = graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto);
+ NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto);
// Q from float32 to int8.
ONNX_NAMESPACE::TypeProto weight_q_type_proto;
diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc
index efd7022ab764b..07902fde04930 100644
--- a/onnxruntime/core/optimizer/relu_clip_fusion.cc
+++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc
@@ -97,7 +97,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
mutable_next_node->AddAttribute("min", 0.f);
} else {
// Add the initialized tensor to the graph
- auto* replacement_min_nodearg = &graph_utils::AddInitializerWithExternalData(graph, replacement_min);
+ auto* replacement_min_nodearg = &graph_utils::AddInitializer(graph, replacement_min);
// Replace the input def at the appropriate index of the Clip node
auto& mutable_input_defs = mutable_next_node->MutableInputDefs();
diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc
index 36213609f6b61..324905f953eec 100644
--- a/onnxruntime/core/optimizer/reshape_fusion.cc
+++ b/onnxruntime/core/optimizer/reshape_fusion.cc
@@ -438,7 +438,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t));
- auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto);
+ auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto);
// Safely remove concat parent nodes which have only one output
for (int i = 0; i < concat_input_count; ++i) {
@@ -492,7 +492,7 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t));
- NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto);
+ NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
{contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back().get().MutableOutputDefs()[0]});
diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc
index 74121508132dc..5c09e5225ab9c 100644
--- a/onnxruntime/core/optimizer/stft_decomposition.cc
+++ b/onnxruntime/core/optimizer/stft_decomposition.cc
@@ -46,7 +46,7 @@ NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[T
proto.add_dims(shape[i]);
}
utils::SetRawDataInTensorProto(proto, begin, element_count * sizeof(TDataType));
- return &graph_utils::AddInitializerWithExternalData(graph, proto);
+ return &graph_utils::AddInitializer(graph, proto);
}
template
diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc
index a320de2ee7a13..cc7682b2b418d 100644
--- a/onnxruntime/core/optimizer/transformer_memcpy.cc
+++ b/onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -383,21 +383,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
TensorProto new_tensor_proto = *tensor_proto;
*(new_tensor_proto.mutable_name()) = new_def_name;
- // Query any OrtValue existing for the original initializer
- // We are checking outer scope because GetInitializer is called with true, therefore, we potentially
- // have references to parent graphs.
- // We are doing this so the same OrtValue is re-used in subgraphs and no copies made for big items.
- constexpr const bool check_outer_scope_true = true;
- OrtValue ort_value;
- // The initializer can be in memory with OrtValue or it can be a flatbuffer mapped.
- if (utils::HasExternalDataInMemory(new_tensor_proto) &&
- graph_.GetOrtValueInitializer(name, ort_value, check_outer_scope_true)) {
- // Re-use the same ort_value and proto that points to the same buffer
- ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializerWithExternalData(graph_, new_tensor_proto,
- std::move(ort_value)));
- } else {
- ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto));
- }
+ ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto));
replacements.insert(std::make_pair(provider_def, &new_def));
}
diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
index 48ea54434b805..3a95d2a53e8f5 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
@@ -586,10 +586,10 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector& shape) {
@@ -622,7 +622,7 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vector()->Reshape(new_shape);
- }
-
- auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_tensor_proto, ort_value);
+ auto& new_node_arg = graph_utils::AddInitializer(graph, new_tensor_proto);
graph_utils::ReplaceNodeWithInitializer(graph, node, new_node_arg);
// Remove the Unsqueeze node and replace it with the initializer.
diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
index e8d133779f33c..51a8b13cd8261 100644
--- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
+++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
@@ -734,6 +734,10 @@ struct CudaEpFactory : OrtEpFactory {
}
*/
+ // Guard against bad device discovery. The max number of devices we expect to add is num_cuda_devices.
+ // If we are attempting to add more than that, there are duplicates in the `devices` array.
+ max_ep_devices = std::min(max_ep_devices, static_cast<size_t>(num_cuda_devices));
+
int16_t device_id = 0;
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index 286db9070766d..cc9d9f3da1d81 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -123,10 +123,11 @@ void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /
// even for empty tensors, so allocate a dummy byte.
size = std::max(size, static_cast<uint64_t>(1));
if (size > allocated_size) {
- cudaFree(outputPtr);
+ alloc_->Free(alloc_, outputPtr);
outputPtr = nullptr;
allocated_size = 0;
- if (cudaMalloc(&outputPtr, size) == cudaSuccess) {
+ outputPtr = alloc_->Alloc(alloc_, size);
+ if (outputPtr) {
allocated_size = size;
}
}
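The hunk above routes output reallocation through the allocator injected into OutputAllocator (alloc_) instead of calling cudaMalloc/cudaFree directly. A simplified, CPU-only sketch of the grow-only reallocation logic with an OrtAllocator-style Alloc/Free pair (toy types and hypothetical names):

#include <cstdlib>
#include <iostream>

// Toy allocator with OrtAllocator-style function pointers.
struct ToyAllocator {
  void* (*Alloc)(ToyAllocator*, size_t);
  void (*Free)(ToyAllocator*, void*);
};

static void* MallocAlloc(ToyAllocator*, size_t n) { return std::malloc(n); }
static void MallocFree(ToyAllocator*, void* p) { std::free(p); }

struct OutputBuffer {
  ToyAllocator* alloc = nullptr;
  void* ptr = nullptr;
  size_t allocated_size = 0;

  // Grow-only: reuse the existing buffer when it is already large enough.
  void* Reallocate(size_t size) {
    if (size == 0) size = 1;  // even empty tensors get a dummy byte
    if (size > allocated_size) {
      alloc->Free(alloc, ptr);  // freeing nullptr is a no-op for malloc/free
      ptr = nullptr;
      allocated_size = 0;
      ptr = alloc->Alloc(alloc, size);
      if (ptr != nullptr) {
        allocated_size = size;
      }
    }
    return ptr;
  }
};

int main() {
  ToyAllocator a{MallocAlloc, MallocFree};
  OutputBuffer out;
  out.alloc = &a;
  out.Reallocate(16);
  out.Reallocate(8);   // no reallocation: buffer only grows
  out.Reallocate(64);  // grows
  std::cout << out.allocated_size << "\n";  // 64
  a.Free(&a, out.ptr);
  return 0;
}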
@@ -352,193 +353,6 @@ bool ApplyProfileShapesFromProviderOptions(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
- * @param shape_tensor_values holds "shape tensor -> shape values" for the INT32 shape tensor input across this inference run
- * @param shape_tensor_values_int64 holds "shape tensor -> shape values" for the INT64 shape tensor input across this inference run
- */
-Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
- Ort::KernelContext ctx,
- nvinfer1::ITensor* input,
- ShapeRangesMap& shape_ranges,
- const std::unordered_map<std::string, size_t>& input_indexes,
- std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,
- std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,
- cudaStream_t stream,
- bool* engine_update) {
- for (size_t i = 0; i < trt_profiles.size(); i++) {
- const std::string& input_name = input->getName();
- nvinfer1::Dims dims = input->getDimensions();
- int nb_dims = dims.nbDims;
-
- size_t input_index = 0;
- const auto& iter = input_indexes.find(input_name);
- if (iter != input_indexes.end()) {
- input_index = iter->second;
- }
-
- auto input_tensor = ctx.GetInput(input_index);
- auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
- const auto tensor_shapes = tensor_info.GetShape();
- auto& shape_ranges_per_input = shape_ranges[input_name];
-
- auto trt_profile = trt_profiles[i];
-
- // If there are multiple profiles, for second and rest of profiles, simply copy the min/max/opt profile values from the first profile.
- // Following "if statement" won't be executed since TRT EP currently only allows single profile for non-explicit profiles case.
- if (i > 0) {
- if (input->isShapeTensor()) {
- // shape tensor
- int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
- std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
- for (int j = 0; j < shape_size; j++) {
- shapes_min[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN));
- shapes_max[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX));
- shapes_opt[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT));
- }
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
- } else {
- // execution tensor
- nvinfer1::Dims dims_min, dims_opt, dims_max;
- dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN);
- dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX);
- dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT);
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
- }
- continue;
- }
-
- // Create shape profile
- if (input->isShapeTensor()) {
- // Get shape values for shape tensor input
- const auto tensor_type = tensor_info.GetElementType();
- // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
- int shape_size = dims.nbDims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
- // For setting TRT optimization profile. (Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10)
- std::vector values(shape_size);
-
- switch (tensor_type) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
- auto buffer = std::make_unique<int32_t[]>(shape_size);
- auto status = GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
- }
- shape_tensor_values[input_name].resize(shape_size);
- for (int j = 0; j < shape_size; ++j) {
- shape_tensor_values[input_name][j] = buffer[j];
- values[j] = buffer[j];
- }
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
- auto buffer = std::make_unique<int64_t[]>(shape_size);
- auto status = GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
- }
- shape_tensor_values_int64[input_name].resize(shape_size);
- for (int j = 0; j < shape_size; ++j) {
- shape_tensor_values_int64[input_name][j] = buffer[j];
- values[j] = static_cast<int32_t>(buffer[j]);
- }
- break;
- }
- default: {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
- }
- }
-
- // Update shape ranges
- std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
- int shape_range_size = static_cast<int>(shape_ranges_per_input.size());
- if (shape_size == shape_range_size) {
- // If shape size matches, check/update shape range
- for (int j = 0; j < shape_size; ++j) {
- auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile
- shapes_min[j] = static_cast(shape_range[0]);
- shapes_max[j] = static_cast(shape_range[1]);
- shapes_opt[j] = static_cast(shape_range[2]);
-
- const auto& tensor_shape_value = values[j];
- // Update shape range lower bound
- if (tensor_shape_value < shape_range[0]) {
- shape_range[0] = tensor_shape_value;
- shapes_min[j] = tensor_shape_value;
- *engine_update = true;
- }
- // Update shape range upper bound
- if (tensor_shape_value > shape_range[1]) {
- shape_range[1] = tensor_shape_value;
- shape_range[2] = tensor_shape_value;
- shapes_max[j] = tensor_shape_value;
- shapes_opt[j] = tensor_shape_value;
- *engine_update = true;
- }
- }
- } else {
- // If shape size doesn't match, initialize shape_range with the new shape value
- shape_ranges_per_input.clear();
- for (int j = 0; j < shape_size; ++j) {
- const auto& tensor_shape_value = values[j];
- std::vector<std::vector<int64_t>> profile_vector;
- std::vector<int64_t> shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value};
- profile_vector.push_back(shape_vector); // only one profile needed
- shape_ranges_per_input[j] = profile_vector;
- shapes_min[j] = tensor_shape_value;
- shapes_opt[j] = tensor_shape_value;
- shapes_max[j] = tensor_shape_value;
- }
- *engine_update = true;
- }
-
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
- trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
- } else { // Execution tensor
- nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
- for (int j = 0, end = nb_dims; j < end; ++j) {
- const auto& tensor_shape = tensor_shapes[j];
- if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) {
- auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile
- dims_min.d[j] = static_cast(shape_range[0]);
- dims_max.d[j] = static_cast(shape_range[1]);
- dims_opt.d[j] = static_cast(shape_range[2]);
-
- // Update minimum dimension
- if (tensor_shape < shape_range[0]) {
- shape_range[0] = tensor_shape;
- dims_min.d[j] = static_cast(tensor_shape);
- *engine_update = true;
- }
- // Update maximum dimension
- if (tensor_shape > shape_range[1]) {
- shape_range[1] = tensor_shape;
- shape_range[2] = tensor_shape;
- dims_max.d[j] = static_cast(tensor_shape);
- dims_opt.d[j] = static_cast(tensor_shape);
- *engine_update = true;
- }
- }
- }
-
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
- trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
- }
- }
- return Status::OK();
-}
-
#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \
case DATA_TYPE: { \
auto input_tensor_ptr = input_tensor.GetTensorData(); \
@@ -554,6 +368,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \
+ skip_input_binding_allowed = false; \
if (input_tensor_ptr != nullptr && elem_cnt > 0) { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \
data = scratch_buffers.back().get(); \
@@ -568,6 +383,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \
+ data_ptr = output_tensor_ptr; \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
buffers[output_name] = output_tensor_ptr; \
} else { \
@@ -580,6 +396,8 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \
+ data_ptr = output_tensor_ptr; \
+ skip_output_binding_allowed = false; \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \
buffers[output_name] = scratch_buffers.back().get(); \
@@ -628,7 +446,8 @@ Status BindContextInput(Ort::KernelContext& ctx,
std::unordered_map>& shape_tensor_values_int64,
std::vector>& scratch_buffers,
OrtAllocator* alloc,
- cudaStream_t stream) {
+ cudaStream_t stream,
+ bool& skip_input_binding_allowed) {
auto input_tensor = ctx.GetInput(input_index);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shapes = tensor_info.GetShape();
@@ -647,7 +466,7 @@ Status BindContextInput(Ort::KernelContext& ctx,
if (trt_engine->isShapeInferenceIO(input_name)) {
// Bind "shape tensor" input buffer
-
+ skip_input_binding_allowed = false; // Shape tensor input binding cannot be skipped
// The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast(tensor_shapes[0]);
switch (tensor_type) {
@@ -775,19 +594,20 @@ Status BindContextOutput(Ort::KernelContext& ctx,
DDSOutputAllocatorMap& dds_output_allocator_map,
std::vector>& scratch_buffers,
OrtAllocator* alloc,
- std::unordered_map& buffers) {
+ std::unordered_map& buffers,
+ nvinfer1::Dims& dims,
+ void*& data_ptr,
+ bool& skip_output_binding_allowed) {
// Get output shape
- nvinfer1::Dims dims = trt_context->getTensorShape(output_name);
+ dims = trt_context->getTensorShape(output_name);
int nb_dims = dims.nbDims;
bool is_DDS = false;
- std::vector output_shapes(nb_dims);
for (int j = 0, end = nb_dims; j < end; ++j) {
// data-dependent shape
if (dims.d[j] == -1) {
is_DDS = true;
break;
}
- output_shapes[j] = dims.d[j];
}
auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end();
@@ -800,16 +620,19 @@ Status BindContextOutput(Ort::KernelContext& ctx,
// Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3.
if (is_DDS || known_DDS) {
if (!known_DDS) {
- auto allocatorPtr = std::make_unique();
+ auto allocatorPtr = std::make_unique<OutputAllocator>(alloc);
trt_context->setOutputAllocator(output_name, allocatorPtr.get());
dds_output_allocator_map[output_name] = std::move(allocatorPtr);
+ dims.nbDims = -1; // Set to -1 to indicate that the shape is not known at this point.
+ data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding.
}
} else {
- output_tensors[i] = ctx.GetOutput(output_index, output_shapes);
+ output_tensors[i] = ctx.GetOutput(output_index, dims.d, nb_dims);
auto& output_tensor = output_tensors[i];
const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
switch (output_type) {
+ // the macros below set the data_ptr and skip_output_binding_allowed variables
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
@@ -840,7 +663,6 @@ Status BindContextOutput(Ort::KernelContext& ctx,
* we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support.
*/
Status BindKernelOutput(Ort::KernelContext& ctx,
- OrtMemoryInfo* /*mem_info*/,
DDSOutputAllocatorMap& allocator_map,
char const* output_name,
size_t output_index,
@@ -903,31 +725,6 @@ NvExecutionProvider::PerThreadContext::~PerThreadContext() {
trt_context_map_.clear();
}
-/*
- * Returns true if the shape ranges maintained by the PerThreadContext is different from the shape ragnes maintained by TRT EP, meaning the
- * engine is being updated and the execution context maintained by the PerThreadContext should be updated as well. Otherwise, returns false.
- *
- */
-bool NvExecutionProvider::PerThreadContext::CompareProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges) {
- if (shape_ranges.size() > 0) {
- if (input_shape_ranges_[fused_node] != shape_ranges) {
- LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] The shape ranges maintained by the PerThreadContext is different from the shape ranges maintained by TRT EP. \
- This means the engine is updated and will need to update the execution context as well.";
- return true;
- }
- }
- return false;
-}
-
-/*
- * Updates the shape ranges maintained by the PerThreadContext.
- * As long as the execution context maintained by the PerThreadContext is updated, the associated shape ranges should be updated as well.
- *
- */
-void NvExecutionProvider::PerThreadContext::UpdateProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges) {
- input_shape_ranges_[fused_node] = shape_ranges;
-}
-
void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fused_node) {
auto it = trt_context_map_.find(fused_node);
if (it != trt_context_map_.end()) {
@@ -1081,7 +878,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
engine_decryption_lib_path_ = info.engine_decryption_lib_path;
}
force_sequential_engine_build_ = info.force_sequential_engine_build;
- context_memory_sharing_enable_ = info.context_memory_sharing_enable;
sparsity_enable_ = info.sparsity_enable;
auxiliary_streams_ = info.auxiliary_streams;
profile_min_shapes = info.profile_min_shapes;
@@ -1225,7 +1021,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
<< ", nv_engine_decryption_enable: " << engine_decryption_enable_
<< ", nv_engine_decryption_lib_path: " << engine_decryption_lib_path_
<< ", nv_force_sequential_engine_build: " << force_sequential_engine_build_
- << ", nv_context_memory_sharing_enable: " << context_memory_sharing_enable_
<< ", nv_sparsity_enable: " << sparsity_enable_
<< ", nv_auxiliary_streams: " << auxiliary_streams_
<< ", nv_cuda_graph_enable: " << cuda_graph_enable_
@@ -1298,9 +1093,15 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
}
std::vector NvExecutionProvider::CreatePreferredAllocators() {
+ OrtArenaCfg arena_cfg(0, static_cast<int>(ArenaExtendStrategy::kSameAsRequested),
+ -1, -1, -1, -1);
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); },
- narrow(device_id_));
+ narrow(device_id_),
+ true,
+ arena_cfg,
+ // make it stream aware
+ true);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
@@ -2244,6 +2045,96 @@ common::Status NvExecutionProvider::Compile(const std::vector
return Status::OK();
}
+/**
+ * @brief Determines whether I/O binding is required for TensorRT execution.
+ *
+ * This function optimizes TensorRT inference performance by determining when tensor
+ * input/output binding operations can be skipped. Binding is an expensive operation
+ * that involves setting up tensor pointers in the TensorRT execution context, so
+ * avoiding unnecessary rebinding can significantly improve inference throughput.
+ *
+ * The function implements a three-tier decision logic:
+ * 1. First run: Always requires binding to establish initial tensor mappings
+ * 2. Subsequent runs with optimization allowed: Only rebind if tensors have changed
+ * 3. Subsequent runs without optimization: Always rebind for safety
+ *
+ * @tparam TRTState The TensorRT state type (TensorrtFuncState or TensorrtShortFuncState)
+ * @param trt_state Pointer to the TensorRT execution state containing tensor cache
+ * and configuration flags
+ * @param ctx ONNX Runtime kernel context providing access to current input tensors
+ *
+ * @return true if I/O binding is required (tensors changed or safety conditions apply),
+ * false if binding can be safely skipped (optimization enabled and tensors unchanged)
+ *
+ * @note This function modifies trt_state by:
+ * - Setting is_first_run to false after first execution
+ * - Caching current tensor parameters in input_tensors vector
+ * - Updating cached tensors when changes are detected
+ *
+ * @warning The skip_io_binding_allowed flag must be carefully managed as incorrect
+ * usage can lead to inference with stale tensor bindings and incorrect results.
+ */
+template <typename TRTState>
+static bool IsIOBindingRequired(TRTState* const trt_state, const Ort::KernelContext& ctx) {
+ // Check if input tensors have changed since the last run
+ // If so, we need to bind input tensors again
+ bool require_io_binding = false;
+
+ if (trt_state->is_first_run) {
+ // If this is the first run, we always bind input tensors
+ require_io_binding = true;
+ auto input_tensor_count = ctx.GetInputCount();
+ auto output_tensor_count = ctx.GetOutputCount();
+ trt_state->input_tensors.resize(input_tensor_count);
+ trt_state->output_tensors.resize(output_tensor_count);
+ for (size_t input_index = 0; input_index < input_tensor_count; ++input_index) {
+ const auto& input_tensor = ctx.GetInput(input_index);
+ const auto& tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+
+ trt_state->input_tensors[input_index] = TensorParams{input_tensor.GetTensorRawData(), tensor_info.GetShape()};
+ }
+ trt_state->is_first_run = false;
+ } else if (trt_state->skip_io_binding_allowed) {
+ // If skip_io_binding_allowed is true, we can skip binding if input tensors are the same as before
+ auto input_tensor_count = ctx.GetInputCount();
+ for (size_t input_index = 0; input_index < input_tensor_count; ++input_index) {
+ const auto& input_tensor = ctx.GetInput(input_index);
+ const auto& tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+
+ TensorParams ip_tensor{input_tensor.GetTensorRawData(), tensor_info.GetShape()};
+
+ if (ip_tensor != trt_state->input_tensors[input_index]) {
+ require_io_binding = true;
+ trt_state->input_tensors[input_index] = ip_tensor;
+ }
+ }
+ } else {
+ // If this is not the first run and skip_io_binding_allowed is false, we need to bind input tensors
+ require_io_binding = true;
+ }
+
+ if (!require_io_binding) {
+ // no need to bind inputs, but check outputs as well
+ auto output_tensor_count = ctx.GetOutputCount();
+
+ for (size_t output_index = 0; output_index < output_tensor_count; ++output_index) {
+ const auto& prev_output_tensor = trt_state->output_tensors[output_index];
+
+ if (prev_output_tensor.dims.nbDims != -1) {
+ const auto& new_output_tensor = ctx.GetOutput(output_index, prev_output_tensor.dims.d, prev_output_tensor.dims.nbDims);
+
+ // different output tensor data means we need to bind outputs again
+ if (prev_output_tensor.data != new_output_tensor.GetTensorRawData()) {
+ require_io_binding = true;
+ break;
+ }
+ }
+ }
+ }
+
+ return require_io_binding;
+}
+
Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
const Node& fused_node,
std::unordered_map& input_map,
@@ -2349,21 +2240,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
ShapeRangesMap input_explicit_shape_ranges;
ShapeRangesMap input_implicit_shape_ranges;
- auto tensor_is_dynamic = [&](nvinfer1::ITensor* tensor) -> bool {
- if (tensor->isShapeTensor()) {
- return true;
- } else {
- nvinfer1::Dims dims = tensor->getDimensions();
- // Execution tensor
- for (int j = 0, end = dims.nbDims; j < end; ++j) {
- if (dims.d[j] == -1) {
- return true;
- }
- }
- }
- return false;
- };
-
bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false
if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) {
has_explicit_profile = true;
@@ -2375,7 +2251,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
} else {
for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
auto input = trt_network->getInput(i);
- has_dynamic_shape |= tensor_is_dynamic(input);
+ has_dynamic_shape |= checkTrtTensorIsDynamic(input);
}
if (has_dynamic_shape) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] No explicit optimization profile was specified. "
@@ -2574,31 +2450,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
// Build context
// Note: Creating an execution context from an engine is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- if (context_memory_sharing_enable_) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
- size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
- if (mem_size > max_ctx_mem_size_) {
- max_ctx_mem_size_ = mem_size;
- }
- trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
- } else {
- trt_context = std::unique_ptr(trt_engine->createExecutionContext());
- }
+ trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
if (!trt_context) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"Nv EP could not build execution context for fused node: " + fused_node.Name());
}
+ bool is_dynamic_shape_context = false;
// Create input to index map
for (int i = 0; i < num_inputs; ++i) {
auto input = trt_network->getInput(i);
const std::string& input_name = input->getName();
+ is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(input_name.c_str()));
const auto& iter = input_map.find(input_name);
if (iter != input_map.end()) {
input_indexes[input_name] = iter->second;
@@ -2639,10 +2502,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
input_shape_ranges_[context->node_name], &tensorrt_mu_, trt_node_name_with_precision,
engine_cache_enable_, cache_path_,
runtime_.get(), profiles_[context->node_name],
- context_memory_sharing_enable_, &max_ctx_mem_size_,
engine_decryption_enable_, engine_decryption_, engine_encryption_,
detailed_build_log_, sparsity_enable_,
- auxiliary_streams_, cuda_graph_enable_, cache_prefix_, cache_suffix};
+ auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_, cache_suffix};
*state = p.release();
return 0;
};
@@ -2666,25 +2528,20 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
const std::unordered_map& output_indexes = (trt_state->output_info)[0];
const std::unordered_map& output_types = (trt_state->output_info)[1];
auto fused_node_name = trt_state->fused_node_name;
- // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles.
- // The info is used for both shape tensor and execution tensor:
- // tensor name->(dimension->[min, max, opt])
- auto& shape_ranges = trt_state->input_shape_ranges;
+
std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input
auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
auto trt_engine = trt_state->engine->get();
auto trt_context = trt_state->context->get();
auto trt_profiles = trt_state->profiles;
- auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
- int num_inputs = static_cast(input_indexes.size());
int num_outputs = static_cast(output_indexes.size());
std::unordered_set input_names;
- OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
- narrow(device_id_));
- OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device);
if (alloc_ == nullptr) {
+ OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
+ narrow(device_id_));
+ OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device);
Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
}
OrtAllocator* alloc = alloc_;
@@ -2698,68 +2555,13 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed");
}
- // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
- // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
- // Prepare cache name
- std::string cache_path = "";
- // Customize cache prefix if assigned
- if (!cache_prefix_.empty()) {
- cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix;
- } else {
- cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
- }
-
- // Enable hardware compatility mode if assigned
- std::string cache_hw_compat = "_sm" + compute_capability_;
-
- // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
- // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
- const std::string cache_path_prefix = cache_path + cache_hw_compat;
- std::string engine_cache_path = cache_path_prefix + ".engine";
- const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
- const std::string profile_cache_path = cache_path_prefix + ".profile";
-
- // If weight-stripped engine is enabled and refitted engine cache is not present,
- // TRT EP will use the engine cache with ".stripped.engine" appended to the end.
- const std::filesystem::path engine_cache_fs_path = engine_cache_path;
- if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
- engine_cache_path = cache_path_prefix + ".stripped.engine";
- weight_stripped_engine_refit_ = true;
- }
-
- // Check and update shape ranges for dynamic shape inputs.
- for (int i = 0, end = num_inputs; i < end; ++i) {
- auto input = trt_state->network->get()->getInput(i);
- const std::string& input_name = input->getName();
- input_names.insert(input_name);
-
- // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
- // TRT EP will help determine the min/max/opt profile values based on current input tensor value.
- if (shape_ranges.find(input_name) != shape_ranges.end()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP failed to parse input tensor and generate optimization profiles.");
- }
- }
-
- if (weight_stripped_engine_refit_) {
- auto status = RefitEngine(model_path_,
- onnx_model_folder_path_,
- engine_cache_path,
- false /* path check for security */,
- onnx_model_bytestream_,
- onnx_model_bytestream_size_,
- trt_engine,
- false /* serialize refitted engine to disk */,
- detailed_build_log_);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
- }
- }
-
// Check before using trt_engine
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found.");
}
+ bool require_io_binding = IsIOBindingRequired(trt_state, ctx);
+
// Get input and output binding names
int total_bindings = trt_engine->getNbIOTensors();
std::vector input_binding_names, output_binding_names;
@@ -2776,23 +2578,25 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
/*
* Set input shapes and bind input buffers
*/
- std::vector> scratch_buffers;
- for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
- char const* input_name = input_binding_names[i];
-
- size_t input_index = 0;
- const auto iter = input_indexes.find(input_name);
- if (iter != input_indexes.end()) {
- input_index = iter->second;
- }
- auto input_tensor = ctx.GetInput(input_index);
- auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
- const auto tensor_shapes = tensor_info.GetShape();
+ auto& scratch_buffers = trt_state->scratch_buffers;
+ if (require_io_binding) {
+ scratch_buffers.clear();
+ bool skip_input_binding_allowed = true;
+ for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
+ char const* input_name = input_binding_names[i];
+
+ size_t input_index = 0;
+ const auto iter = input_indexes.find(input_name);
+ if (iter != input_indexes.end()) {
+ input_index = iter->second;
+ }
- auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream, skip_input_binding_allowed);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
}
+ trt_state->skip_io_binding_allowed = skip_input_binding_allowed;
}
/*
@@ -2806,44 +2610,51 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
std::unordered_map output_dim_sizes;
output_dim_sizes.reserve(num_outputs);
- for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
- char const* output_name = output_binding_names[i];
+ if (require_io_binding) {
+ bool skip_output_binding_allowed = true;
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
- size_t output_index = 0;
- const auto& index_iter = output_indexes.find(output_name);
- if (index_iter != output_indexes.end()) {
- output_index = index_iter->second;
- }
+ size_t output_index = 0;
+ const auto& index_iter = output_indexes.find(output_name);
+ if (index_iter != output_indexes.end()) {
+ output_index = index_iter->second;
+ }
- size_t output_type = 0;
- const auto type_iter = output_types.find(output_name);
- if (type_iter != output_types.end()) {
- output_type = type_iter->second;
- }
+ size_t output_type = 0;
+ const auto type_iter = output_types.find(output_name);
+ if (type_iter != output_types.end()) {
+ output_type = type_iter->second;
+ }
+
+ nvinfer1::Dims dims;
+ void* data_ptr = nullptr;
+
+ Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+ dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
- Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
- dds_output_allocator_map, scratch_buffers, alloc, buffers);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims};
}
+
+ trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed;
}
// Set execution context memory
- if (trt_state->context_memory_sharing_enable) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
+ if (require_io_binding) {
size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
- if (mem_size > *max_context_mem_size_ptr) {
- *max_context_mem_size_ptr = mem_size;
+ if (trt_state->is_dynamic_shape) {
+ mem_size = trt_context->updateDeviceMemorySizeForShapes();
+ }
+ if (trt_state->context_memory_size != mem_size) {
+ LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size;
+ trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/);
+ trt_state->context_memory_size = mem_size;
+ trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size);
}
- trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get());
}
-
// Start CUDA graph capture.
// Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
// current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
@@ -2894,7 +2705,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
- auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);
+ auto status = BindKernelOutput(ctx, dds_output_allocator_map, output_name, output_index, output_type, stream);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
@@ -2961,33 +2772,19 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
//
// Note: Creating an execution context from an engine is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- if (context_memory_sharing_enable_) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
- size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
- if (mem_size > max_ctx_mem_size_) {
- max_ctx_mem_size_ = mem_size;
- }
- trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
-
- } else {
- trt_context = std::unique_ptr(trt_engine->createExecutionContext());
- }
+ trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
if (!trt_context) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"Nv EP could not build execution context for fused node: " + fused_node.Name());
}
+ bool is_dynamic_shape_context = false;
// Create input/output to index maps
for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) {
auto const& name = trt_engine->getIOTensorName(i);
auto const& mode = trt_engine->getTensorIOMode(name);
if (mode == nvinfer1::TensorIOMode::kINPUT) {
+ is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(name));
const auto& iter = input_map.find(name);
if (iter != input_map.end()) {
input_indexes[name] = iter->second;
@@ -3027,9 +2824,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
&contexts_[context->node_name],
input_info_[context->node_name],
output_info_[context->node_name],
- context_memory_sharing_enable_,
- &max_ctx_mem_size_,
- &tensorrt_mu_};
+ &tensorrt_mu_,
+ is_dynamic_shape_context};
*state = p.release();
return 0;
};
@@ -3056,15 +2852,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
auto trt_engine = trt_state->engine->get();
auto trt_context = trt_state->context->get();
- auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
int num_outputs = static_cast(output_indexes.size());
std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input
- OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
- narrow(device_id_));
- OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device);
if (alloc_ == nullptr) {
+ OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
+ narrow(device_id_));
+ OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device);
Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
}
OrtAllocator* alloc = alloc_;
@@ -3078,6 +2873,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found.");
}
+ bool require_io_binding = IsIOBindingRequired(trt_state, ctx);
+
// Get input and output binding names
int total_bindings = trt_engine->getNbIOTensors();
std::vector input_binding_names, output_binding_names;
@@ -3094,20 +2891,25 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
/*
* Set input shapes and bind input buffers
*/
- std::vector> scratch_buffers;
- for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
- char const* input_name = input_binding_names[i];
-
- size_t input_index = 0;
- const auto iter = input_indexes.find(input_name);
- if (iter != input_indexes.end()) {
- input_index = iter->second;
- }
+ auto& scratch_buffers = trt_state->scratch_buffers;
+ if (require_io_binding) {
+ scratch_buffers.clear();
+ bool skip_input_binding_allowed = true;
+ for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
+ char const* input_name = input_binding_names[i];
+
+ size_t input_index = 0;
+ const auto iter = input_indexes.find(input_name);
+ if (iter != input_indexes.end()) {
+ input_index = iter->second;
+ }
- Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream, skip_input_binding_allowed);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
}
+ trt_state->skip_io_binding_allowed = skip_input_binding_allowed;
}
/*
@@ -3121,44 +2923,52 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
std::unordered_map output_dim_sizes;
output_dim_sizes.reserve(num_outputs);
- for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
- char const* output_name = output_binding_names[i];
+ if (require_io_binding) {
+ bool skip_output_binding_allowed = true;
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
- size_t output_index = 0;
- const auto& index_iter = output_indexes.find(output_name);
- if (index_iter != output_indexes.end()) {
- output_index = index_iter->second;
- }
+ size_t output_index = 0;
+ const auto& index_iter = output_indexes.find(output_name);
+ if (index_iter != output_indexes.end()) {
+ output_index = index_iter->second;
+ }
- size_t output_type = 0;
- const auto type_iter = output_types.find(output_name);
- if (type_iter != output_types.end()) {
- output_type = type_iter->second;
- }
+ size_t output_type = 0;
+ const auto type_iter = output_types.find(output_name);
+ if (type_iter != output_types.end()) {
+ output_type = type_iter->second;
+ }
+
+ nvinfer1::Dims dims;
+ void* data_ptr = nullptr;
+
+ Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+ dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
- Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
- dds_output_allocator_map, scratch_buffers, alloc, buffers);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims};
}
+
+ trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed;
}
// Set execution context memory
- if (trt_state->context_memory_sharing_enable) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
+ if (require_io_binding) {
size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
- if (mem_size > *max_context_mem_size_ptr) {
- *max_context_mem_size_ptr = mem_size;
+ if (trt_state->is_dynamic_shape) {
+ mem_size = trt_context->updateDeviceMemorySizeForShapes();
+ }
+ if (trt_state->context_memory_size != mem_size) {
+ LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size;
+ trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/);
+ // trt_state->context_memory = IAllocator::MakeUniquePtr(alloc, mem_size, false /*use_reserve*/, stream);
+ trt_state->context_memory_size = mem_size;
+ trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size);
}
- trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get());
}
-
// Start CUDA graph capture.
// Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
// current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
@@ -3209,7 +3019,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
- auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);
+ auto status = BindKernelOutput(ctx, dds_output_allocator_map, output_name, output_index, output_type, stream);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
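
The skip-binding path added above is easiest to see in isolation. Below is a minimal standalone sketch, using hypothetical FakeTensor/FakeState stand-ins rather than the EP's real TensorrtFuncState, of the three-tier decision that IsIOBindingRequired documents: always bind on the first run; on later runs skip rebinding only when skipping is allowed and every cached input pointer and shape is unchanged; otherwise rebind.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for an ORT input tensor's identity: raw data pointer plus shape (hypothetical type).
struct FakeTensor {
  const void* data;
  std::vector<int64_t> shape;
  bool operator!=(const FakeTensor& o) const { return data != o.data || shape != o.shape; }
};

// Stand-in for the caching fields added to TensorrtFuncState in this change.
struct FakeState {
  bool is_first_run = true;
  bool skip_io_binding_allowed = false;
  std::vector<FakeTensor> cached_inputs;
};

// Mirrors the three-tier decision: first run always binds; afterwards, if skipping is
// allowed, rebind only when a cached input's pointer or shape changed.
bool IsBindingRequired(FakeState& s, const std::vector<FakeTensor>& inputs) {
  if (s.is_first_run) {
    s.cached_inputs = inputs;
    s.is_first_run = false;
    return true;
  }
  if (!s.skip_io_binding_allowed) {
    return true;
  }
  bool rebind = false;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i] != s.cached_inputs[i]) {
      s.cached_inputs[i] = inputs[i];
      rebind = true;
    }
  }
  return rebind;
}

int main() {
  float a[4] = {};
  FakeState state;
  std::vector<FakeTensor> run1{{a, {2, 2}}};
  std::cout << IsBindingRequired(state, run1) << "\n";  // 1: first run always binds
  state.skip_io_binding_allowed = true;                 // as set after a binding pass with no cast/shape-tensor inputs
  std::cout << IsBindingRequired(state, run1) << "\n";  // 0: same pointers and shapes, binding skipped
  std::vector<FakeTensor> run2{{a, {4, 1}}};
  std::cout << IsBindingRequired(state, run2) << "\n";  // 1: shape changed, rebind required
  return 0;
}

Compiled with any C++17 compiler, the sketch prints 1, 0, 1 for the three runs.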
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
index 7a0c47d28c81d..83b89a2e9d1fb 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -78,6 +78,9 @@ using unique_pointer = std::unique_ptr;
//
class OutputAllocator : public nvinfer1::IOutputAllocator {
public:
+ OutputAllocator() = delete;
+ OutputAllocator(OrtAllocator* allocator) : alloc_(allocator) {}
+
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override;
void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;
@@ -95,10 +98,11 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
}
~OutputAllocator() override {
- cudaFree(outputPtr);
+ alloc_->Free(alloc_, outputPtr);
}
private:
+ OrtAllocator* alloc_;
void* outputPtr{nullptr};
uint64_t allocated_size = 0;
std::vector output_shapes;
@@ -110,6 +114,45 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
*/
using ShapeRangesMap = std::unordered_map>>>;
+/**
+ * @brief Container for a tensor's data pointer and shape.
+ *
+ */
+struct TensorParams {
+ const void* data{nullptr};
+ nvinfer1::Dims dims;
+
+ TensorParams() = default;
+
+ TensorParams(const void* data_ptr, const std::vector<int64_t>& shape) {
+ // Initialize data and dims from a raw data pointer and a shape vector
+ data = data_ptr;
+
+ dims.nbDims = static_cast<int32_t>(shape.size());
+ for (int i = 0; i < dims.nbDims; ++i) {
+ dims.d[i] = static_cast<int64_t>(shape[i]);
+ }
+ }
+
+ TensorParams(const void* data_ptr, nvinfer1::Dims& shape) {
+ // Initialize data and dims from a raw data pointer and an existing nvinfer1::Dims
+ data = data_ptr;
+
+ dims = shape;
+ }
+
+ bool operator!=(const TensorParams& other) const {
+ if (data != other.data || dims.nbDims != other.dims.nbDims)
+ return true;
+
+ for (int i = 0; i < dims.nbDims; ++i) {
+ if (dims.d[i] != other.dims.d[i])
+ return true;
+ }
+ return false;
+ }
+};
+
// Information to construct kernel function state.
struct TensorrtFuncState {
AllocateFunc test_allocate_func = nullptr;
@@ -130,8 +173,6 @@ struct TensorrtFuncState {
std::string engine_cache_path;
nvinfer1::IRuntime* runtime = nullptr;
std::vector profiles;
- bool context_memory_sharing_enable = false;
- size_t* max_context_mem_size_ptr = nullptr;
bool engine_decryption_enable = false;
int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
int (*engine_encryption)(const char*, char*, size_t) = nullptr;
@@ -139,8 +180,16 @@ struct TensorrtFuncState {
bool sparsity_enable = false;
int auxiliary_streams = -1;
bool cuda_graph_enable = 0;
+ bool is_dynamic_shape = false;
std::string cache_prefix;
std::string cache_suffix;
+ std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
+ std::vector<TensorParams> input_tensors;
+ std::vector<TensorParams> output_tensors;
+ bool is_first_run = true; // Indicates if this is the first run of the engine
+ bool skip_io_binding_allowed = false; // Indicates if input/output binding can be skipped
+ IAllocatorUniquePtr<void> context_memory = nullptr;
+ size_t context_memory_size = 0;
};
// Minimum information to construct kernel function state for direct engine load code path
@@ -153,9 +202,15 @@ struct TensorrtShortFuncState {
std::unique_ptr* context = nullptr;
std::vector> input_info;
std::vector> output_info;
- bool context_memory_sharing_enable = false;
- size_t* max_context_mem_size_ptr = nullptr;
std::mutex* tensorrt_mu_ptr = nullptr;
+ bool is_dynamic_shape = false;
+ std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
+ std::vector<TensorParams> input_tensors;
+ std::vector<TensorParams> output_tensors;
+ bool is_first_run = true; // Indicates if this is the first run of the engine
+ bool skip_io_binding_allowed = false; // Indicates if input/output binding can be skipped
+ IAllocatorUniquePtr<void> context_memory = nullptr;
+ size_t context_memory_size = 0;
};
// Holds important information for building valid ORT graph.
@@ -251,9 +306,7 @@ class NvExecutionProvider : public IExecutionProvider {
std::mutex tensorrt_mu_;
int device_id_;
std::string compute_capability_;
- bool context_memory_sharing_enable_ = false;
size_t max_ctx_mem_size_ = 0;
- IAllocatorUniquePtr context_memory_ = nullptr;
mutable char model_path_[4096] = {}; // Reserved for max path length
bool engine_decryption_enable_ = false;
int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
@@ -341,8 +394,6 @@ class NvExecutionProvider : public IExecutionProvider {
nvinfer1::IExecutionContext& GetTensorRTContext(std::string fused_node);
bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context);
void ResetTensorRTContext(std::string fused_node);
- bool CompareProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges);
- void UpdateProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges);
void InitCUDAGraph();
void SetGraphStream(cudaStream_t stream);
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
index 2a67f3c3bec4d..4d6c6fe116076 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
@@ -34,7 +34,6 @@ struct NvExecutionProviderInfo {
bool engine_decryption_enable{false};
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
- bool context_memory_sharing_enable{false};
std::string timing_cache_path{""};
bool detailed_build_log{false};
bool sparsity_enable{false};
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
index 22e5eea6924de..ea586ba445ba2 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
@@ -683,4 +683,29 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string
}
return "";
}
+
+/*
+ * Checks if there is an element with value `-1` in nvinfer1::Dims
+ */
+static bool checkTrtDimIsDynamic(nvinfer1::Dims dims) {
+ for (int j = 0, end = dims.nbDims; j < end; ++j) {
+ if (dims.d[j] == -1) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/*
+ * Checks if an nvinfer1::ITensor signals a dynamic shape,
+ * either because its dimensions are dynamic or because it is a shape tensor
+ */
+static bool checkTrtTensorIsDynamic(nvinfer1::ITensor* tensor) {
+ if (tensor->isShapeTensor()) {
+ return true;
+ } else {
+ // Execution tensor
+ return checkTrtDimIsDynamic(tensor->getDimensions());
+ }
+}
} // namespace onnxruntime
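
For reference, a minimal standalone usage sketch of the dynamic-dimension check above; the helper is re-declared locally so the snippet compiles on its own, and only the TensorRT header is assumed.

#include <NvInfer.h>
#include <iostream>

// Local copy of checkTrtDimIsDynamic from nv_execution_provider_utils.h, for illustration only.
static bool checkTrtDimIsDynamic(nvinfer1::Dims dims) {
  for (int j = 0, end = dims.nbDims; j < end; ++j) {
    if (dims.d[j] == -1) {
      return true;
    }
  }
  return false;
}

int main() {
  nvinfer1::Dims dims{};
  dims.nbDims = 3;
  dims.d[0] = 1;
  dims.d[1] = -1;  // -1 marks a dimension whose size is only known at runtime
  dims.d[2] = 224;
  std::cout << std::boolalpha << checkTrtDimIsDynamic(dims) << "\n";  // prints: true
  return 0;
}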
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
index e236cccaaaa77..d23d50549b2c5 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -557,6 +557,67 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
return ORT_VERSION;
}
+ /**
+ * @brief Checks if a given OrtHardwareDevice is a supported NVIDIA GPU.
+ *
+ * This function verifies if the provided hardware device corresponds to a physical
+ * NVIDIA GPU that meets the minimum compute capability requirements for this execution provider.
+ *
+ * The check is performed by:
+ * 1. Extracting the LUID (Locally Unique Identifier) from the device's metadata.
+ * 2. Converting the string LUID to a 64-bit integer.
+ * 3. Iterating through all available CUDA devices on the system.
+ * 4. For each CUDA device, constructing its 64-bit LUID from its properties.
+ * 5. Comparing the LUIDs. If a match is found, it checks if the device's
+ * compute capability is at least 8.0 (Ampere or newer).
+ *
+ * @param device The OrtHardwareDevice to check.
+ * @return True if the device is a supported NVIDIA GPU, false otherwise.
+ */
+ bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) {
+ const auto& metadata_entries = device.metadata.Entries();
+ const auto it = metadata_entries.find("LUID");
+ if (it == metadata_entries.end()) {
+ return false;
+ }
+
+ uint64_t target_luid;
+ try {
+ target_luid = std::stoull(it->second);
+ } catch (const std::exception&) {
+ return false;
+ }
+
+ int device_count = 0;
+ if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
+ return false;
+ }
+
+ for (int i = 0; i < device_count; ++i) {
+ cudaDeviceProp prop;
+ if (cudaGetDeviceProperties(&prop, i) != cudaSuccess) {
+ continue;
+ }
+
+ // The LUID is an 8-byte value, valid on Windows when luidDeviceNodeMask is non-zero.
+ // We reconstruct the 64-bit integer representation from the raw bytes.
+ if (prop.luidDeviceNodeMask == 0) {
+ continue;
+ }
+
+ // Ensure the LUID is 8 bytes and reinterpret it directly as a uint64_t for comparison.
+ static_assert(sizeof(prop.luid) == sizeof(uint64_t), "cudaDeviceProp::luid should be 8 bytes");
+ uint64_t current_luid = *reinterpret_cast<const uint64_t*>(prop.luid);
+
+ if (current_luid == target_luid) {
+ // Ampere architecture or newer is required.
+ return prop.major >= 8;
+ }
+ }
+
+ return false;
+ }
+
// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
// An EP created with this factory is expected to be able to execute a model with *all* supported
// hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among
@@ -579,11 +640,12 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
int16_t device_id = 0;
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
+
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&
- factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
+ factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id &&
+ factory->IsOrtHardwareDeviceSupported(device)) {
OrtKeyValuePairs* ep_options = nullptr;
OrtKeyValuePairs* ep_metadata = nullptr;
-
factory->ort_api.CreateKeyValuePairs(&ep_options);
factory->ort_api.CreateKeyValuePairs(&ep_metadata);
factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str());
diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc
index a22d21d8d798b..bdeea726a2cf5 100644
--- a/onnxruntime/core/providers/webgpu/shader_helper.cc
+++ b/onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha
ss << ",";
}
- auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : "";
- ss << "\n " << alignment << name << ": ";
+ // The actual variable type for the uniform variable depends on the data type (T) and length (N).
+ //
+ // For T in [i32, u32, f32]:
+ // - If N == 1, the type is simply i32, u32, or f32.
+ // - If 1 < N <= 4, the type is vecN<i32>, vecN<u32>, or vecN<f32> where N is the length.
+ // - If N > 4, the type is array<vec4<T>, ceil(N / 4)>.
+ //
+ // For T is f16:
+ // - If N == 1 or N == 2, the type is u32.
+ // - If 2 < N <= 8, the type is vecX<u32> where X is ceil(N / 2).
+ // - If N > 8, the type is array<vec4<u32>, X> where X is ceil(N / 8).
+ //
+ // Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent
+ // 2 f16 values.
+
+ if (data_type == ProgramUniformVariableDataType::Float16) {
+ data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32
+ length = (length + 1) / 2; // each u32 can hold 2 f16 values
+ }
+ ss << "\n " << name << ": ";
if (length > 4) {
- if (data_type == ProgramUniformVariableDataType::Float16) {
- size_t array_size = (length + 7) / 8;
- ss << "array, " << array_size << ">";
- } else {
- size_t array_size = (length + 3) / 4;
- ss << "array, " << array_size << ">";
- }
+ size_t array_size = (length + 3) / 4;
+ ss << "array, " << array_size << ">";
} else if (length > 1) {
ss << "vec" << length << "<" << data_type << ">";
} else {
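
The type mapping in the comment above can be reproduced by a small standalone sketch (a hypothetical helper, not ORT code): f16 uniforms are first re-typed as u32 with the length halved, then the generic length rules pick a scalar, vecN or array<vec4<T>, N> declaration.

#include <iostream>
#include <string>

// Hypothetical helper reproducing the declared-type rule for uniform variables.
std::string UniformWgslType(std::string t, size_t length) {
  if (t == "f16") {
    t = "u32";                  // f16 data is stored in u32 words
    length = (length + 1) / 2;  // each u32 packs two f16 values
  }
  if (length > 4) {
    return "array<vec4<" + t + ">, " + std::to_string((length + 3) / 4) + ">";
  }
  if (length > 1) {
    return "vec" + std::to_string(length) + "<" + t + ">";
  }
  return t;
}

int main() {
  std::cout << UniformWgslType("f32", 1) << "\n";  // f32
  std::cout << UniformWgslType("i32", 3) << "\n";  // vec3<i32>
  std::cout << UniformWgslType("f32", 6) << "\n";  // array<vec4<f32>, 2>
  std::cout << UniformWgslType("f16", 2) << "\n";  // u32
  std::cout << UniformWgslType("f16", 6) << "\n";  // vec3<u32>
  std::cout << UniformWgslType("f16", 9) << "\n";  // array<vec4<u32>, 2>
  return 0;
}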
diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h
index 2aba2a59d157f..78c98ab26f5b8 100644
--- a/onnxruntime/core/providers/webgpu/shader_variable.h
+++ b/onnxruntime/core/providers/webgpu/shader_variable.h
@@ -17,18 +17,34 @@ template || std::is_same_v>>
std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) {
- // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20.
- if (var.rfind("uniforms.", 0) == 0) {
- if (rank > 4) {
- if constexpr (std::is_integral_v) {
- if (is_f16) {
- return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]");
+ if (var.starts_with("uniforms.")) {
+ if (is_f16) {
+ if (rank > 8) {
+ // array<vec4<u32>, N>
+ if constexpr (std::is_integral_v<TIdx>) {
+ return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]");
} else {
- return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
+ return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]");
+ }
+ } else if (rank > 2) {
+ // vecN<u32>
+ if constexpr (std::is_integral_v<TIdx>) {
+ return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 2, "])[", idx % 2, "]");
+ } else {
+ return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]");
}
} else {
- if (is_f16) {
- return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]");
+ // u32
+ if constexpr (std::is_integral_v<TIdx>) {
+ return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, ")[", idx % 2, "]");
+ } else {
+ return MakeStringWithClassicLocale("bitcast>(", var, ")[(", idx, ") % 2]");
+ }
+ }
+ } else {
+ if (rank > 4) {
+ if constexpr (std::is_integral_v<TIdx>) {
+ return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
} else {
return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]");
}
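
A standalone sketch (hypothetical helper mirroring the integral-index f16 branches of GetElementAt above) of how one f16 element is addressed inside the u32-packed uniform: pick the vec4<u32> with idx / 8, the u32 inside it with (idx % 8) / 2, bitcast that u32 to vec2<f16>, and select the half with (idx % 8) % 2.

#include <iostream>
#include <string>

// Hypothetical helper mirroring the integral-index f16 branches of GetElementAt.
std::string F16UniformElement(const std::string& var, size_t idx, size_t rank) {
  if (rank > 8) {  // uniform declared as array<vec4<u32>, N>
    return "bitcast<vec2<f16>>(" + var + "[" + std::to_string(idx / 8) + "][" +
           std::to_string((idx % 8) / 2) + "])[" + std::to_string((idx % 8) % 2) + "]";
  }
  if (rank > 2) {  // uniform declared as vecN<u32>
    return "bitcast<vec2<f16>>(" + var + "[" + std::to_string(idx / 2) + "])[" +
           std::to_string(idx % 2) + "]";
  }
  // uniform declared as a single u32
  return "bitcast<vec2<f16>>(" + var + ")[" + std::to_string(idx % 2) + "]";
}

int main() {
  std::cout << F16UniformElement("uniforms.shape", 5, 6) << "\n";
  // bitcast<vec2<f16>>(uniforms.shape[2])[1]
  std::cout << F16UniformElement("uniforms.shape", 10, 12) << "\n";
  // bitcast<vec2<f16>>(uniforms.shape[1][1])[0]
  return 0;
}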
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 4bd79a627df22..a9557f7b9aa87 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
continue;
}
- bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
-
- size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)];
+ // Calculate the size and alignment of the uniform variable.
+ //
// https://www.w3.org/TR/WGSL/#alignof
- size_t base_alignment = is_f16
- ? (length > 4 ? 16 : length > 2 ? 8
- : length * element_size)
- : (length > 2 ? 16 : length * element_size);
- size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16;
-
- current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment;
+ //
+ // For f16:
+ // - length > 8 : array<vec4<u32>, N> (align 16) (size 16 * N, N = ceil(length / 8))
+ // - length == 7 or 8: vec4<u32> (align 16) (size 16)
+ // - length == 5 or 6: vec3<u32> (align 16) (size 12)
+ // - length == 3 or 4: vec2<u32> (align 8) (size 8)
+ // - length == 1 or 2: u32 (align 4) (size 4)
+ //
+ // For other types (i32, u32, f32):
+ // - length > 4 : array<vec4<T>, N> (align 16) (size 16 * N, N = ceil(length / 4))
+ // - length == 4 : vec4<T> (align 16) (size 16)
+ // - length == 3 : vec3<T> (align 16) (size 12)
+ // - length == 2 : vec2<T> (align 8) (size 8)
+ // - length == 1 : T (align 4) (size 4)
+ //
+
+ const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
+
+ size_t variable_alignment = 4; // default alignment for scalar types
+ size_t variable_size = 4; // default size for scalar types
+
+ if (is_f16) {
+ if (length > 6) {
+ variable_alignment = 16;
+ variable_size = 16 * ((length + 7) / 8);
+ } else if (length > 4) {
+ variable_alignment = 16;
+ variable_size = 12;
+ } else if (length > 2) {
+ variable_alignment = 8;
+ variable_size = 8;
+ }
+ } else {
+ if (length > 3) {
+ variable_alignment = 16;
+ variable_size = 16 * ((length + 3) / 4);
+ } else if (length > 2) {
+ variable_alignment = 16;
+ variable_size = 12;
+ } else if (length > 1) {
+ variable_alignment = 8;
+ variable_size = 8;
+ }
+ }
+ current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment;
uniform_and_offsets.emplace_back(uniform, current_offset);
- // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where
- // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4).
- // For float16 type, when length > 4, the uniform variable is of type array,N>, where
- // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4).
- size_t element_per_struct = is_f16 ? 8 : 4;
- current_offset +=
- length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size;
+ current_offset += variable_size;
}
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
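
A worked example of the packing rule described in the comment above, as a standalone sketch (hypothetical helper, not the ORT implementation): each uniform field is aligned to its WGSL alignment and then advances the running offset by its WGSL size.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

struct SizeAlign {
  size_t size;
  size_t align;
};

// Hypothetical helper returning the WGSL size/alignment for one uniform field.
SizeAlign UniformSizeAlign(bool is_f16, size_t length) {
  if (is_f16) {
    if (length > 6) return {16 * ((length + 7) / 8), 16};  // vec4<u32> or array<vec4<u32>, N>
    if (length > 4) return {12, 16};                       // vec3<u32>
    if (length > 2) return {8, 8};                         // vec2<u32>
    return {4, 4};                                         // u32
  }
  if (length > 3) return {16 * ((length + 3) / 4), 16};    // vec4<T> or array<vec4<T>, N>
  if (length > 2) return {12, 16};                         // vec3<T>
  if (length > 1) return {8, 8};                           // vec2<T>
  return {4, 4};                                           // scalar T
}

int main() {
  size_t offset = 0;
  // e.g. an f32 scalar, then an f16 uniform of length 5, then an i32 uniform of length 3
  std::vector<std::pair<bool, size_t>> uniforms = {{false, 1}, {true, 5}, {false, 3}};
  for (const auto& [is_f16, length] : uniforms) {
    const auto [size, align] = UniformSizeAlign(is_f16, length);
    offset = (offset + align - 1) / align * align;  // align the field start
    std::cout << "offset " << offset << ", size " << size << "\n";
    offset += size;
  }
  // prints: offset 0, size 4 / offset 16, size 12 / offset 32, size 12
  return 0;
}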
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 2b553aecbca6c..dfb2e33f8cb32 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -72,21 +72,23 @@ ProviderInfo_CUDA& GetProviderInfo_CUDA();
#endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
namespace {
-// Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison
+// Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison.
static bool AreOrtMemoryInfosEquivalent(
const OrtMemoryInfo& left, const OrtMemoryInfo& right,
- bool match_name = true) {
+ bool match_name = true,
+ bool ignore_alignment = false) {
return left.mem_type == right.mem_type &&
- left.device == right.device &&
+ (ignore_alignment ? left.device.EqualIgnoringAlignment(right.device) : left.device == right.device) &&
(!match_name || strcmp(left.name, right.name) == 0);
}
std::vector::const_iterator FindExistingAllocator(const std::vector& allocators,
const OrtMemoryInfo& mem_info,
- bool match_name = true) {
+ bool match_name = true,
+ bool ignore_alignment = false) {
auto ite = std::find_if(std::begin(allocators),
std::end(allocators),
- [&mem_info, match_name](const AllocatorPtr& alloc_ptr) {
+ [&mem_info, match_name, ignore_alignment](const AllocatorPtr& alloc_ptr) {
// We want to do the equality checking of 2 OrtMemoryInfos sans the OrtAllocatorType field.
// This is because we want to avoid registering two allocators for the same device that just
// differ on OrtAllocatorType.
@@ -96,7 +98,8 @@ std::vector::const_iterator FindExistingAllocator(const std::vecto
// OrtDeviceAllocator (which is the only accepted value while registering a custom allocator).
// If we allowed this, it could potentially cause a lot of confusion as to which shared allocator
// to use for that device and we want to avoid having any ugly logic around this.
- return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, match_name);
+ return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info,
+ match_name, ignore_alignment);
});
return ite;
@@ -428,8 +431,25 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ
}
Environment::~Environment() {
- // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed
+ // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed.
+ // this is because any entry in shared_allocators_ wrapping an OrtAllocator from a plugin EP owns the OrtAllocator
+ // instance and will call Release on it. If the plugin EP has been freed, the Release will fail.
shared_allocators_.clear();
+
+#if !defined(ORT_MINIMAL_BUILD)
+ // unregister any remaining EP libraries so they're cleaned up in a deterministic way.
+ while (!ep_libraries_.empty()) {
+ auto it = ep_libraries_.begin();
+ ORT_IGNORE_RETURN_VALUE(UnregisterExecutionProviderLibrary(it->first));
+ }
+#endif
+}
+
+AllocatorPtr Environment::GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const {
+ std::lock_guard lock{mutex_};
+
+ auto it = FindExistingAllocator(shared_allocators_, mem_info, /*match_name*/ false, /*ignore_alignment*/ true);
+ return it != shared_allocators_.end() ? *it : nullptr;
}
Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) {
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index f4f76a389030e..c0900c5ad28a0 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1421,6 +1421,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
}
}
+ // We choose to convert initializers into OrtValues before partitioning here so that plug-in EPs can
+ // take advantage of the initializers being in OrtValue format and do not have to deal with protobuf.
+ //
+ // The initializer's data is transferred to an OrtValue. The original TensorProto is replaced
+ // with a TensorProto that has the same data type, shape and name. However, its external data
+ // is used in a non-standard way: the location is set to the string constant utils::kTensorProtoMemoryAddressTag,
+ // the file offset is set to the address of the OrtValue's data buffer, and the length is set to the size of the
+ // OrtValue's data buffer. Because this external location is non-standard, ONNX code cannot handle it, so we
+ // perform the conversion as late as possible, but before partitioning, so that type and shape inference accesses
+ // the initializers before they are converted to OrtValues.
+ //
+ // If transformations are applied later, they do not introduce any in-memory initializers; type and shape
+ // inference runs only on newly added nodes, and any new initializers will be converted at session
+ // finalization time.
+ //
+ // The conversion is performed using the following steps (within ConvertInitializersIntoOrtValues())
+ // constexpr const bool use_tensor_buffer_true = true;
+ // auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(),
+ // use_tensor_buffer_true);
+ // ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
+
+ ORT_RETURN_IF_ERROR_SESSIONID_(graph.ConvertInitializersIntoOrtValues());
+
// Do partitioning based on execution providers' capabilities.
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn,
session_options_.config_options, *session_logger_,
@@ -1984,13 +2007,15 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
// For now, this function only checks for invalid combination of DML EP with other EPs.
// TODO: extend this function to check for other invalid combinations of EPs.
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
- // DML EP is only allowed with CPU EP
+ // DML EP is not allowed with other GPU or NPU EPs.
+ // The historical reason for this is unknown. The previous restriction that it could only be used with
+ // the CPU EP is relaxed to support scenarios where the alternative EPs are CPU based (e.g. OpenVINO).
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
if (has_dml_ep) {
- const auto& ep_list = execution_providers_.GetIds();
- for (const auto& ep : ep_list) {
- if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
- return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
+ for (const auto& ep : execution_providers_) {
+ if (ep->Type() != kDmlExecutionProvider && ep->GetDevice().Type() != OrtDevice::CPU) {
+ return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can only be used with CPU-based EPs.");
+ }
}
}
return Status::OK();
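For reference, the relaxed check means the DML EP can now be paired with any EP whose registered OrtDevice is CPU, not just the CPU EP itself. A minimal Python sketch of the intent, assuming a build that ships both the DML and OpenVINO providers (the model path is a placeholder):

```python
import onnxruntime as ort

# Previously only ["DmlExecutionProvider", "CPUExecutionProvider"] was accepted.
# With the relaxed check, a CPU-based EP such as OpenVINO (targeting CPU) can be
# combined with DML as well.
sess = ort.InferenceSession(
    "model.onnx",
    providers=[
        "DmlExecutionProvider",
        "OpenVINOExecutionProvider",
        "CPUExecutionProvider",
    ],
)
print(sess.get_providers())
```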
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index e8e51db13bcd3..64c4ada07f28f 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -21,7 +21,7 @@
import onnxruntime
-def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
+def get_ort_device_type(device_type: str) -> int:
if device_type == "cuda":
return C.OrtDevice.cuda()
elif device_type == "cann":
@@ -32,8 +32,10 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
return C.OrtDevice.dml()
elif device_type == "webgpu":
return C.OrtDevice.webgpu()
- elif device_type == "ort":
- return C.get_ort_device(device_index).device_type()
+ elif device_type == "gpu":
+ return C.OrtDevice.gpu()
+ elif device_type == "npu":
+ return C.OrtDevice.npu()
else:
raise Exception("Unsupported device type: " + device_type)
@@ -765,7 +767,7 @@ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_p
self._iobinding.bind_input(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
@@ -812,7 +814,7 @@ def bind_output(
self._iobinding.bind_output(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
@@ -823,7 +825,7 @@ def bind_output(
self._iobinding.bind_output(
name,
C.OrtDevice(
- get_ort_device_type(device_type, device_id),
+ get_ort_device_type(device_type),
C.OrtDevice.default_memory(),
device_id,
),
@@ -889,7 +891,7 @@ def _get_c_value(self) -> C.OrtValue:
return self._ortvalue
@classmethod
- def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
+ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue:
"""
Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
@@ -897,6 +899,7 @@ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device
:param numpy_obj: The Numpy object to construct the OrtValue from
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
+ :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
"""
# Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
# is backed directly by the data buffer of the numpy object and so the numpy object
@@ -904,11 +907,7 @@ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device
return cls(
C.OrtValue.ortvalue_from_numpy(
numpy_obj,
- C.OrtDevice(
- get_ort_device_type(device_type, device_id),
- C.OrtDevice.default_memory(),
- device_id,
- ),
+ OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(),
),
numpy_obj if device_type.lower() == "cpu" else None,
)
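A small usage sketch of the updated factory method; the vendor id is illustrative (0x10DE is the NVIDIA PCI vendor id) and a suitable EP for the target device must be available in the build:

```python
import numpy as np
import onnxruntime as ort

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# Old style: predefined device name, no vendor id (defaults to cpu here).
cpu_value = ort.OrtValue.ortvalue_from_numpy(x)

# New style: generic "gpu"/"npu" device type plus a PCI vendor id.
gpu_value = ort.OrtValue.ortvalue_from_numpy(x, "gpu", device_id=0, vendor_id=0x10DE)
print(gpu_value.device_name())
```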
@@ -929,7 +928,7 @@ def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_ty
@classmethod
def ortvalue_from_shape_and_type(
- cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
+ cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1
) -> OrtValue:
"""
Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
@@ -938,7 +937,11 @@ def ortvalue_from_shape_and_type(
:param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16).
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
+ :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
"""
+
+ device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device()
+
# Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
# This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
if isinstance(element_type, int):
@@ -946,11 +949,7 @@ def ortvalue_from_shape_and_type(
C.OrtValue.ortvalue_from_shape_and_onnx_type(
shape,
element_type,
- C.OrtDevice(
- get_ort_device_type(device_type, device_id),
- C.OrtDevice.default_memory(),
- device_id,
- ),
+ device,
)
)
@@ -958,11 +957,7 @@ def ortvalue_from_shape_and_type(
C.OrtValue.ortvalue_from_shape_and_type(
shape,
element_type,
- C.OrtDevice(
- get_ort_device_type(device_type, device_id),
- C.OrtDevice.default_memory(),
- device_id,
- ),
+ device,
)
)
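Similarly, a hedged sketch of allocating an uninitialized OrtValue directly on a device described by a generic device type plus vendor id (the shape, element type and AMD vendor id 0x1002 are arbitrary examples):

```python
import numpy as np
import onnxruntime as ort

# Allocate a 2x3 float32 tensor on a generic "gpu" device identified by its PCI vendor id.
value = ort.OrtValue.ortvalue_from_shape_and_type(
    [2, 3], np.float32, "gpu", device_id=0, vendor_id=0x1002
)
print(value.shape(), value.data_type())
```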
@@ -1085,14 +1080,27 @@ def _get_c_device(self):
return self._ort_device
@staticmethod
- def make(ort_device_name, device_id):
- return OrtDevice(
- C.OrtDevice(
- get_ort_device_type(ort_device_name, device_id),
- C.OrtDevice.default_memory(),
- device_id,
+ def make(ort_device_name, device_id, vendor_id=-1):
+ if vendor_id < 0:
+ # backwards compatibility with predefined OrtDevice names
+ return OrtDevice(
+ C.OrtDevice(
+ get_ort_device_type(ort_device_name),
+ C.OrtDevice.default_memory(),
+ device_id,
+ )
+ )
+ else:
+ # generic. use GPU or NPU for ort_device_name and provide a vendor id.
+ # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id)
+ return OrtDevice(
+ C.OrtDevice(
+ get_ort_device_type(ort_device_name),
+ C.OrtDevice.default_memory(),
+ vendor_id,
+ device_id,
+ )
)
- )
def device_id(self):
return self._ort_device.device_id()
@@ -1100,6 +1108,9 @@ def device_id(self):
def device_type(self):
return self._ort_device.device_type()
+ def device_vendor_id(self):
+ return self._ort_device.vendor_id()
+
class SparseTensor:
"""
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index 958c9fc46bcd8..590e1ef3cdbdb 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -99,6 +99,44 @@ TensorShape GetShape(const py::array& arr) {
return shape;
}
+AllocatorPtr GetSharedAllocator(const OrtDevice& device) {
+ auto& env = GetOrtEnv()->GetEnvironment();
+
+ OrtMemoryInfo mem_info("ignored", OrtDeviceAllocator, device);
+ return env.GetRegisteredSharedAllocator(mem_info);
+}
+
+MemCpyFunc CreateDataTransferMemCpy([[maybe_unused]] const OrtDevice& src_device,
+ [[maybe_unused]] const OrtDevice& dst_device) {
+#if defined(ORT_MINIMAL_BUILD)
+ // plugin EPs are not supported in a minimal build so there won't be any data transfers registered
+ return nullptr;
+#else
+
+ auto& env = GetOrtEnv()->GetEnvironment();
+ const DataTransferManager& data_transfer_manager = env.GetDataTransferManager();
+ const IDataTransfer* data_transfer = data_transfer_manager.GetDataTransfer(src_device, dst_device);
+ if (!data_transfer) {
+ return nullptr;
+ }
+
+ const auto copy_func = [src_device, dst_device, data_transfer](void* dst, const void* src, size_t bytes) {
+ OrtMemoryInfo src_memory_info("ignored", OrtDeviceAllocator, src_device);
+ OrtMemoryInfo dst_memory_info("ignored", OrtDeviceAllocator, dst_device);
+
+ // real shape doesn't matter as the Tensor instances here are temporary in order to be able to call CopyTensor.
+ // we set the shape to `bytes` and the data type to uint8_t to copy the correct number of bytes.
+ TensorShape shape = {narrow<int64_t>(bytes)};
+ Tensor src_tensor{DataTypeImpl::GetType<uint8_t>(), shape, const_cast<void*>(src), src_memory_info};
+ Tensor dst_tensor{DataTypeImpl::GetType<uint8_t>(), shape, dst, dst_memory_info};
+
+ ORT_THROW_IF_ERROR(data_transfer->CopyTensor(src_tensor, dst_tensor));
+ };
+
+ return copy_func;
+#endif
+}
+
void CpuToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
memcpy(dst, src, num_bytes);
}
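GetSharedAllocator and CreateDataTransferMemCpy only return something once a plugin EP (or an explicitly registered custom allocator) has populated the environment's shared allocators and data transfer manager. A heavily hedged end-to-end sketch of the effect from Python; register_execution_provider_library, the registration name, library path and vendor id are assumptions and may differ in your build:

```python
import numpy as np
import onnxruntime as ort

# Assumed plugin-EP registration API; the name and path are placeholders.
ort.register_execution_provider_library("example_ep", "/path/to/example_ep_plugin.so")

x = np.ones((4, 4), dtype=np.float32)

# The plugin EP's shared allocator backs the device buffer (GetSharedAllocator), and the
# copy function built by CreateDataTransferMemCpy performs the CPU -> device copy.
dev_value = ort.OrtValue.ortvalue_from_numpy(x, "gpu", device_id=0, vendor_id=0xABCD)

# Device -> CPU goes back through the same data-transfer lookup.
assert np.array_equal(dev_value.numpy(), x)
```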
@@ -158,9 +196,10 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes);
}
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetCudaToHostMemCpyFunction() {
- static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
- {OrtDevice::GPU, CudaToCpuMemCpy}};
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction() {
+ static std::unordered_map<OrtDevice, MemCpyFunc> map{
+ {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy},
+ };
return &map;
}
@@ -215,9 +254,10 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_MIGraphX().MIGraphXMemcpy_DeviceToHost(dst, src, num_bytes);
}
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction() {
- static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
- {OrtDevice::GPU, MIGraphXToCpuMemCpy}};
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction() {
+ static std::unordered_map<OrtDevice, MemCpyFunc> map{
+ {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy},
+ };
return &map;
}
@@ -334,9 +374,10 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
- static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
- {OrtDevice::GPU, DmlToCpuMemCpy}};
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
+ static std::unordered_map<OrtDevice, MemCpyFunc> map{
+ {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy},
+ };
return &map;
}
@@ -352,9 +393,10 @@ void CannToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_CANN().cannMemcpy_DeviceToHost(dst, src, num_bytes);
}
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetCannToHostMemCpyFunction() {
- static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
- {OrtDevice::NPU, CannToCpuMemCpy}};
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCannToHostMemCpyFunction() {
+ static std::unordered_map<OrtDevice, MemCpyFunc> map{
+ {OrtDevice{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::HUAWEI, 0}, CannToCpuMemCpy},
+ };
return &map;
}
@@ -402,9 +444,10 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes);
}
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetRocmToHostMemCpyFunction() {
- static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
- {OrtDevice::GPU, RocmToCpuMemCpy}};
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction() {
+ static std::unordered_map<OrtDevice, MemCpyFunc> map{
+ {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy},
+ };
return &map;
}
@@ -581,7 +624,7 @@ using OrtPybindSingleUseAllocatorPtr = std::shared_ptr<OrtPybindSingleUseAllocator>;
static void CopyDataToTensor(PyArrayObject* darray, int npy_type, std::unique_ptr<Tensor>& p_tensor,
- MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) {
+ const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) {
CopyDataToTensor(darray, npy_type, *p_tensor, mem_cpy_to_device);
}
-void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device) {
+void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, const MemCpyFunc& mem_cpy_to_device) {
CopyDataToTensor(reinterpret_cast<PyArrayObject*>(py_array.ptr()), npy_type, tensor, mem_cpy_to_device);
}
@@ -656,7 +699,7 @@ void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, M
// The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope
static std::unique_ptr<Tensor> CreateTensor(const AllocatorPtr& alloc, const std::string& name_input,
PyArrayObject* pyObject, bool use_numpy_data_memory = true,
- MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) {
+ const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) {
PyArrayObject* darray = PyArray_GETCONTIGUOUS(pyObject);
ORT_ENFORCE(darray != nullptr, "The object must be a contiguous array for input '", name_input, "'.");
@@ -746,7 +789,8 @@ static void CreateSequenceOfTensors(AllocatorPtr alloc, const std::string& name_
// as the backing data buffer for the ORT Tensor where applicable (for numeric tensors)
// The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope
static void CreateTensorMLValue(const AllocatorPtr& alloc, const std::string& name_input, PyArrayObject* pyObject,
- OrtValue* p_mlvalue, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) {
+ OrtValue* p_mlvalue, bool use_numpy_data_memory = true,
+ const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) {
auto p_tensor = CreateTensor(alloc, name_input, pyObject, use_numpy_data_memory, mem_cpy_to_device);
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
@@ -994,9 +1038,10 @@ static void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc,
// Setting `use_numpy_data_memory` to `true` will ensure that the underlying numpy array buffer is directly used
// as the backing data buffer for the ORT Tensor where applicable (for numeric tensors)
// The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope
-void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc, const std::string& name_input,
- const py::object& value, OrtValue* p_mlvalue, bool accept_only_numpy_array,
- bool use_numpy_data_memory, MemCpyFunc mem_cpy_to_device) {
+void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc,
+ const std::string& name_input, const py::object& value, OrtValue* p_mlvalue,
+ bool accept_only_numpy_array, bool use_numpy_data_memory,
+ const MemCpyFunc& mem_cpy_to_device) {
onnx::TypeProto type_proto;
if (PyObjectCheck_NumpyArray(value.ptr())) {
// The most frequent case: input comes as an array.
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
index e9bafea2ed1b5..7b65c0aae45c1 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -42,22 +42,27 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type);
MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type);
-using MemCpyFunc = void (*)(void*, const void*, size_t);
-
+using MemCpyFunc = std::function<void(void*, const void*, size_t)>;
using DataTransferAlternative = std::variant<const DataTransferManager*, MemCpyFunc>;
+// helpers to get allocator and IDataTransfer from Environment for plugin EP
+AllocatorPtr GetSharedAllocator(const OrtDevice& device);
+MemCpyFunc CreateDataTransferMemCpy(const OrtDevice& src_device, const OrtDevice& dst_device);
+
void CpuToCpuMemCpy(void*, const void*, size_t);
-void CopyDataToTensor(const pybind11::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
+void CopyDataToTensor(const pybind11::array& py_array, int npy_type, Tensor& tensor,
+ const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy);
pybind11::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions);
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions);
-pybind11::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, const DataTransferManager* data_transfer_manager);
+pybind11::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value,
+ const DataTransferManager* data_transfer_manager);
pybind11::object AddNonTensorAsPyObj(const OrtValue& val,
const DataTransferManager* data_transfer_manager,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions);
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions);
OrtMemoryInfo GetMemoryInfoPerDeviceType(const OrtDevice& ort_device);
@@ -69,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes);
void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetCudaToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction();
bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id);
@@ -87,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes);
void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction();
#endif
@@ -97,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes);
void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction();
AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id);
@@ -109,7 +114,7 @@ void CpuToCannMemCpy(void* dst, const void* src, size_t num_bytes);
void CannToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetCannToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCannToHostMemCpyFunction();
bool IsCannDeviceIdValid(const onnxruntime::logging::Logger& logger, int id);
@@ -127,17 +132,18 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes);
void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
-const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetRocmToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction();
#endif
void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc,
const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue,
- bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
+ bool accept_only_numpy_array = false, bool use_numpy_data_memory = true,
+ const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy);
pybind11::object GetPyObjFromTensor(const OrtValue& rtensor,
const DataTransferManager* data_transfer_manager = nullptr,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
// The below two functions are used to convert OrtValue to numpy arrays
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index d1d4d6f3cdad5..7234543eb14de 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -23,42 +23,57 @@ std::unique_ptr<OrtValue> OrtValueFromShapeAndType(const std::vector<int64_t>& s
MLDataType element_type,
const OrtDevice& device) {
AllocatorPtr allocator;
+
if (strcmp(GetDeviceName(device), CPU) == 0) {
allocator = GetAllocator();
- } else if (strcmp(GetDeviceName(device), CUDA) == 0) {
+ } else {
+#if !defined(ORT_MINIMAL_BUILD)
+ // prefer a shared allocator from the environment.
+ // these are provided by plugin EPs or custom allocators explicitly registered by the user.
+ allocator = GetSharedAllocator(device);
+#endif
+
+ if (!allocator) {
+ if (strcmp(GetDeviceName(device), CUDA) == 0) {
#ifdef USE_CUDA
- if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
- throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
- }
- allocator = GetCudaAllocator(device.Id());
+ if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
+ throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
+ }
+
+ allocator = GetCudaAllocator(device.Id());
#else
- throw std::runtime_error(
- "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
- "Please use the CUDA package of OnnxRuntime to use this feature.");
+ throw std::runtime_error(
+ "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
+ "Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
- } else if (strcmp(GetDeviceName(device), HIP) == 0) {
+ } else if (strcmp(GetDeviceName(device), HIP) == 0) {
#if USE_ROCM
- if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
- throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
- }
- allocator = GetRocmAllocator(device.Id());
+ if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
+ throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
+ }
+
+ allocator = GetRocmAllocator(device.Id());
#elif USE_MIGRAPHX
- allocator = GetMIGraphXAllocator(device.Id());
+ allocator = GetMIGraphXAllocator(device.Id());
#else
- throw std::runtime_error(
- "Can't allocate memory on the AMD device using this package of OnnxRuntime. "
- "Please use the ROCm package of OnnxRuntime to use this feature.");
+ throw std::runtime_error(
+ "Can't allocate memory on the AMD device using this package of OnnxRuntime. "
+ "Please use the ROCm package of OnnxRuntime to use this feature.");
#endif
- } else if (strcmp(GetDeviceName(device), DML) == 0) {
+ } else if (strcmp(GetDeviceName(device), DML) == 0) {
#if USE_DML
- allocator = GetDmlAllocator(device.Id());
+ allocator = GetDmlAllocator(device.Id());
#else
- throw std::runtime_error(
- "Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
- "Please use the DirectML package of OnnxRuntime to use this feature.");
+ throw std::runtime_error(
+ "Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
+ "Please use the DirectML package of OnnxRuntime to use this feature.");
#endif
- } else {
- throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
+ }
+ }
+
+ if (!allocator) {
+ throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
+ }
}
auto ml_value = std::make_unique<OrtValue>();
@@ -90,7 +105,8 @@ void addOrtValueMethods(pybind11::module& m) {
if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
- // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
+ // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
+ // in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
} else
@@ -103,8 +119,10 @@ void addOrtValueMethods(pybind11::module& m) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
- // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
- CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCudaMemCpy);
+ // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
+ // in CUDA
+ CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
+ true, false, CpuToCudaMemCpy);
} else
#endif
#ifdef USE_ROCM
@@ -115,22 +133,34 @@ void addOrtValueMethods(pybind11::module& m) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
- // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
- CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
+ // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
+ // in ROCM
+ CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
+ true, false, CpuToRocmMemCpy);
} else
#endif
#if USE_MIGRAPHX
if (device.Vendor() == OrtDevice::VendorIds::AMD) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
- // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in MIGraphX
- CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToMIGraphXMemCpy);
+ // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
+ // in MIGraphX
+ CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
+ true, false, CpuToMIGraphXMemCpy);
} else
#endif
{
- throw std::runtime_error(
- "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
- "Please use the CUDA package of OnnxRuntime to use this feature.");
+ // see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP
+ auto allocator = GetSharedAllocator(device);
+ auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr;
+ if (cpu_to_device_copy_fn) {
+ CreateGenericMLValue(nullptr, allocator, "", array_on_cpu, ml_value.get(), true, false,
+ cpu_to_device_copy_fn);
+ } else {
+ throw std::runtime_error(
+ "Can't allocate memory on the device using this package of OnnxRuntime. "
+ "Please use the appropriate package of OnnxRuntime for your hardware to use this feature.");
+ }
}
} else if (device.Type() == OrtDevice::NPU && device.Vendor() == OrtDevice::VendorIds::HUAWEI) {
#ifdef USE_CANN
@@ -214,8 +244,16 @@ void addOrtValueMethods(pybind11::module& m) {
} else
#endif
{
- throw std::runtime_error(
- "Unsupported GPU device: Cannot find the supported GPU device.");
+ // see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP
+ auto allocator = GetSharedAllocator(device);
+ auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr;
+ if (cpu_to_device_copy_fn) {
+ onnxruntime::python::CopyDataToTensor(py_values, values_type, *(ml_value->GetMutable<Tensor>()),
+ cpu_to_device_copy_fn);
+ } else {
+ throw std::runtime_error(
+ "Unsupported GPU device: Cannot find the supported GPU device.");
+ }
}
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index acf0681cf8752..03ad0185d1394 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -205,7 +205,7 @@ void AppendLoraParametersAsInputs(const RunOptions& run_options,
template <typename T>
static py::object AddNonTensor(const OrtValue& val,
const DataTransferManager* /*data_transfer_manager*/,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* /*mem_cpy_to_host_functions*/) {
+ const std::unordered_map<OrtDevice, MemCpyFunc>* /*mem_cpy_to_host_functions*/) {
return py::cast(val.Get<T>());
}
@@ -265,39 +265,65 @@ pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value, cons
// pretty much does what a DataTransferManager does - copy data from device(s) to the host
py::object GetPyObjFromTensor(const OrtValue& ort_value,
const DataTransferManager* data_transfer_manager,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions) {
ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors");
const auto& tensor = ort_value.Get<Tensor>();
+ const auto& device = tensor.Location().device;
+
if (tensor.IsDataTypeString()) {
- ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
+ ORT_ENFORCE(device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
// Create a numpy array of strings (python objects) by copy/converting them
py::array result = StringTensorToNumpyArray(tensor);
return py::cast(result);
}
- const auto device_type = tensor.Location().device.Type();
+ const auto device_type = device.Type();
// Create an numpy array on top of the OrtValue memory, no copy
if (device_type == OrtDevice::CPU) {
py::array result = PrimitiveTensorToNumpyOverOrtValue(ort_value);
return py::cast(result);
}
- if (!data_transfer_manager && !mem_cpy_to_host_functions) {
- throw std::runtime_error(
- "GetPyObjFromTensor: Either data transfer manager or a "
- "function to copy data to the host is needed to convert non-CPU tensor to numpy array");
- }
-
py::array result;
if (data_transfer_manager != nullptr) {
result = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager);
} else {
- auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
- ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
- "Unable to locate a function that can copy data to the host from the device");
- result = PrimitiveTensorToNumpyFromDevice(ort_value, mem_cpy_to_host->second);
+ bool copied = false;
+ if (mem_cpy_to_host_functions) {
+ auto it = std::find_if(mem_cpy_to_host_functions->begin(), mem_cpy_to_host_functions->end(),
+ [&device](const auto& entry) {
+ const auto& copy_device = entry.first;
+ // We're ignoring OrtDevice.Id() currently for historical reasons.
+ // The key to mem_cpy_to_host_functions was previously the device type (CPU/GPU/NPU).
+ // This changed to be OrtDevice to get the vendor id.
+ // Arguably it would be better to also match on the device id, but that was not possible
+ // previously, so to preserve existing behavior we keep the old logic and expect the
+ // copy function to handle the device id correctly.
+ return device.Type() == copy_device.Type() &&
+ device.MemType() == copy_device.MemType() &&
+ device.Vendor() == copy_device.Vendor();
+ });
+
+ if (it != mem_cpy_to_host_functions->end()) {
+ result = PrimitiveTensorToNumpyFromDevice(ort_value, it->second);
+ copied = true;
+ }
+ }
+
+ if (!copied) {
+ // see if we have a shared data transfer function from a plugin EP
+ auto device_to_cpu_copy_func = CreateDataTransferMemCpy(device, OrtDevice{});
+ if (device_to_cpu_copy_func) {
+ result = PrimitiveTensorToNumpyFromDevice(ort_value, device_to_cpu_copy_func);
+ } else {
+ throw std::runtime_error(
+ "GetPyObjFromTensor: Either data transfer manager or a "
+ "function to copy data to the host is needed to convert non-CPU tensor to numpy array");
+ }
+ }
}
+
return py::cast(result);
}
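In Python terms, this lookup is what backs OrtValue.numpy() (and fetching session outputs that live on a device) for non-CPU tensors: first the per-EP copy-function map, now keyed by OrtDevice type/memory type/vendor, then a plugin-EP data transfer as a fallback. A small sketch, reusing a device-resident OrtValue like the ones created in the earlier examples:

```python
# dev_value is a non-CPU OrtValue, e.g. created via ortvalue_from_numpy(..., "gpu", ...).
host_array = dev_value.numpy()

# If neither a built-in copy function nor a plugin-EP data transfer matches the device,
# the "Either data transfer manager or a function to copy data..." error above is raised.
print(host_array.dtype, host_array.shape)
```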
@@ -373,7 +399,7 @@ py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, co
template <>
py::object AddNonTensor<TensorSeq>(const OrtValue& val,
const DataTransferManager* data_transfer_manager,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions) {
const auto& seq_tensors = val.Get<TensorSeq>();
py::list py_list;
for (const auto& ort_value : seq_tensors) {
@@ -389,7 +415,7 @@ py::object AddNonTensor(const OrtValue& val,
py::object AddNonTensorAsPyObj(const OrtValue& val,
const DataTransferManager* data_transfer_manager,
- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
+ const std::unordered_map<OrtDevice, MemCpyFunc>* mem_cpy_to_host_functions) {
// Should be in sync with core/framework/datatypes.h
auto val_type = val.Type();
if (val_type->IsTensorSequenceType()) {
@@ -429,7 +455,7 @@ py::object AddNonTensorAsPyObj(const OrtValue& val,
}
py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
- const std::unordered_map* mem_cpy_to_host_functions) {
+ const std::unordered_map