diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index c84d34cfd3cbe..0660cc874ffb7 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -265,19 +265,23 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out); } +using AllocatorPtr = std::shared_ptr; +using AllocatorMap = std::map; + class CPUAllocator : public IAllocator { public: explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IAllocator(memory_info) {} + // Creates a function local static and returns a shared pointer to it. + // Re-use in all places where we need a standalone CPUAllocator instance + static AllocatorPtr DefaultInstance(); + CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} void* Alloc(size_t size) override; void Free(void* p) override; }; -using AllocatorPtr = std::shared_ptr; -using AllocatorMap = std::map; - void* AllocatorDefaultAlloc(size_t size); void AllocatorDefaultFree(void* p); void* AllocatorDefaultAllocAligned(size_t size, size_t alignment); diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index a071f3182faad..0ed427dfb7695 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -18,7 +18,7 @@ class SparseTensor; class TensorSeq; } // namespace onnxruntime -#endif +#endif // SHARED_PROVIDER /** Represents both tensors and non-tensors. @@ -37,8 +37,8 @@ struct OrtValue { type_ = type; } - void Init(void* pData, onnxruntime::MLDataType type, const std::function& deleter) { - data_.reset(pData, deleter); + void Init(void* pData, onnxruntime::MLDataType type, std::function deleter) { + data_.reset(pData, std::move(deleter)); type_ = type; } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 35b568e3f8e28..3b2fe20bd19a5 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -567,6 +567,13 @@ class Node { friend class Graph; Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {} + protected: +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // internal only method to allow selected classes to directly alter the input/output definitions and arg counts + // made protected to facilitate testing + Definitions& MutableDefinitions() noexcept; +#endif + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); @@ -588,9 +595,6 @@ class Node { #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // internal only method to allow selected classes to directly alter the input/output definitions and arg counts - Definitions& MutableDefinitions() noexcept; - // internal only method to allow selected classes to directly alter the links between nodes. Relationships& MutableRelationships() noexcept; @@ -721,11 +725,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /** Replaces the initializer tensor with the same name as the given initializer tensor. The replacement initializer tensor must have the same type and shape as the existing initializer tensor. + The new_initializer is expected to be either small or have external data reference stored in OrtValue. Note: This currently has linear time complexity. 
There is room for improvement but it would likely require changes to how initializer tensors are stored and tracked. */ - common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer); + common::Status ReplaceInitializedTensor(const ONNX_NAMESPACE::TensorProto& new_initializer, const OrtValue& ort_value); #if !defined(DISABLE_EXTERNAL_INITIALIZERS) /** This function takes externally provided data for initializers with external data @@ -745,6 +750,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Add an initializer tensor to the Graph. */ void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto); + + /// + /// Add initializer to the Graph. This method takes a tensor proto that contains + /// a data pointer to ort_value. For small tensors (LT utils::kSmallTensorExternalDataThreshold), + /// the data would still be contained within tensor_proto, and + /// OrtValue would be unallocated in this case, and not added to ortvalue_initializers_. + /// + /// tensor proto with external data pointing to OrtValue. + /// value that contains the initializer tensor. This may + /// be unallocated for small tensors. + Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const OrtValue& ort_value_initializer); #endif /** Remove the initializer tensor with the provided name from the Graph. */ @@ -769,7 +786,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. */ - bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + bool GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope = false) const; /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1645,8 +1662,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs. common::Status SetOuterScopeNodeArgs(const std::unordered_set& outer_scope_node_args); - // Implementation for initializer replacement - Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); + /// + /// Replace initializer with new_initializer. 
+  ///
+  /// new_initializer - the replacement initializer
+  /// ort_value - ort_value with data, may be empty
+  /// must_replace_external - true when we replace the initializer with external data
+  /// with an OrtValue supplied by the customer, in which case we enforce that the original
+  /// initializer must have external data
+  ///
+  Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer,
+                                      OrtValue ort_value, bool must_replace_external);

   template <typename StringRange>  // range-initializer returning std::string
   std::vector<NodeArg*> CreateNodeArgs(const StringRange& names,
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index 6f07ead935f4a..7535b704cd4f0 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -57,8 +57,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
     const IExecutionProvider& execution_provider /*required by constant folding*/,
     const logging::Logger& logger,
     const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
-    concurrency::ThreadPool* intra_op_thread_pool = nullptr,
-    std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
+    concurrency::ThreadPool* intra_op_thread_pool = nullptr);

 #endif  // !defined(ORT_MINIMAL_BUILD)

@@ -89,8 +88,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
     const IExecutionProvider& cpu_execution_provider,
     const logging::Logger& logger,
     const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
-    concurrency::ThreadPool* intra_op_thread_pool = nullptr,
-    std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
+    concurrency::ThreadPool* intra_op_thread_pool = nullptr);

 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 09bce9828aa33..5dc0a2efcbe93 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -215,7 +215,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   } else {
     size_t total_size = static_cast<size_t>(sequence_length) * static_cast<size_t>(batch_beam_size);
     size_t total_size_bytes = total_size * sizeof(int);
-    AllocatorPtr buffer_allocator = std::make_shared<CPUAllocator>();
+    AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance();
     // TODO: not need extra buffer. Copy directly to input_ids_data instead like the user_cuda above.
auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index bf866d67ffc0d..ad778fb7ef907 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -167,7 +167,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); - AllocatorPtr buffer_allocator = std::make_shared(); + AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index 5307bb32de7d0..d5606f42b2dd1 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -257,7 +257,7 @@ OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) { deleter(p); }; - ort_value.Init(p_tensor.release(), DataTypeImpl::GetType(), deleter); + ort_value.Init(p_tensor.release(), DataTypeImpl::GetType(), std::move(deleter)); return ort_value; } diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index a58f5ee27b754..1014ddc2f2b69 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -99,6 +99,11 @@ void* AllocatorDefaultAlloc(size_t size) { return AllocatorDefaultAllocAligned(size, alignment); } +AllocatorPtr CPUAllocator::DefaultInstance() { + static AllocatorPtr instance = std::make_shared(); + return instance; +} + void* CPUAllocator::Alloc(size_t size) { const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment()); return AllocatorDefaultAllocAligned(size, alignment); diff --git a/onnxruntime/core/framework/endian_utils.cc b/onnxruntime/core/framework/endian_utils.cc index 8b61aad769ae9..236dda6b2e9e4 100644 --- a/onnxruntime/core/framework/endian_utils.cc +++ b/onnxruntime/core/framework/endian_utils.cc @@ -48,6 +48,16 @@ void SwapByteOrderCopy(size_t element_size_in_bytes, } } +void SwapByteOrderInplace(size_t element_size_in_bytes, gsl::span bytes) { + ORT_ENFORCE(element_size_in_bytes > 0, "Expecting a positive element size"); + ORT_ENFORCE(bytes.size_bytes() % element_size_in_bytes == 0, "Expecting a match"); + if (element_size_in_bytes > 1) { + for (size_t offset = 0, lim = bytes.size_bytes(); offset < lim; offset += element_size_in_bytes) { + std::reverse(bytes.begin() + offset, bytes.begin() + offset + element_size_in_bytes); + } + } +} + namespace detail { Status CopyLittleEndian(size_t element_size_in_bytes, diff --git a/onnxruntime/core/framework/endian_utils.h b/onnxruntime/core/framework/endian_utils.h index 6f084d058d007..c0792302a7141 100644 --- a/onnxruntime/core/framework/endian_utils.h +++ b/onnxruntime/core/framework/endian_utils.h @@ -31,6 +31,21 @@ void SwapByteOrderCopy(size_t element_size_in_bytes, gsl::span source_bytes, gsl::span destination_bytes); +/** + * Swaps the byte order of the elements in the given byte span in place. 
+ *
+ * This is a low-level function - please be sure to pass in valid arguments.
+ * In particular:
+ * - bytes should have a size that is a multiple of element_size_in_bytes.
+ * - element_size_in_bytes should be greater than zero.
+ *
+ * @param element_size_in_bytes The size of an individual element, in bytes.
+ * @param bytes The byte span whose elements are swapped in place.
+ */
+void SwapByteOrderInplace(size_t element_size_in_bytes,
+                          gsl::span<std::byte> bytes);
+
 namespace detail {

 /**
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 9a2991ab02730..2081b8c3c9344 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -17,6 +17,7 @@
 #include "core/framework/resource_accountant.h"
 #include "core/graph/function.h"
 #include "core/graph/function_utils.h"
+#include "core/graph/graph_utils.h"
 #include "core/graph/graph_viewer.h"
 #include "core/graph/model.h"
 #include "core/graph/model_saving_options.h"
@@ -902,9 +903,9 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
   }

   // handle initializers
-  for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) {
-    if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) {
-      ep_graph.AddInitializedTensor(*initialized_tensor.second);
+  for (const auto& [name, _] : graph.GetAllInitializedTensors()) {
+    if (ep_graph.GetNodeArg(name) != nullptr) {
+      graph_utils::MakeInitializerCopyIfNotExist(graph, ep_graph, name);
     }
   }

diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 6362a3169f3a3..7d0026cc35558 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -300,17 +300,13 @@ const std::vector<AllocPlanPerValue>& SessionState::GetPerValueAllocPlan() const
   return p_seq_exec_plan_->allocation_plan;
 }

-Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d,
+Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value,
                                           bool constant, bool sparse) {
   auto p = initialized_tensors_.insert({ort_value_index, ort_value});
   if (!p.second)
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index,
                            ". 
Do you have duplicated calls to SessionState::AddInitializedTensor function?"); - if (d != nullptr && d->f != nullptr) { - deleter_for_initialized_tensors_.insert_or_assign(ort_value_index, *d); - } - if (constant) { constant_initialized_tensors_.insert({ort_value_index, ort_value}); } @@ -1620,16 +1616,16 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string Status { - ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, &d, constant, sparse)); + ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, constant, sparse)); if (remove_initializers) { graph_.RemoveInitializedTensor(name); } return Status::OK(); }, logger_, data_transfer_mgr_, external_data_loader_mgr_, *p_seq_exec_plan_, session_options, - memory_profile_func, name_to_buffered_tensor_, graph_.GetPrepacked())); + memory_profile_func, graph_.GetPrepacked())); #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // Record Weight allocation info on device diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 964c059e529f9..9823cbf88f621 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -18,7 +18,6 @@ #include "core/common/logging/logging.h" #include "core/common/profiler.h" #include "core/framework/allocation_planner.h" -#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/external_data_loader_manager.h" #include "core/framework/execution_providers.h" @@ -102,9 +101,6 @@ class SessionState { AllocatorMap* parent_allocators = nullptr); ~SessionState() { - for (auto& kvp : deleter_for_initialized_tensors_) { - kvp.second.f(kvp.second.param); - } } // Graph viewer. CreateGraphInfo must have been called previously. @@ -143,12 +139,11 @@ class SessionState { /** * Adds an initialized tensor (weight) so that it can be used by the * execution frame to setup the appropriate OrtValue vectors. - * This function will take a shallow copy of d if d is not NULL. * If 'constant' is true the tensor value cannot be overridden by an input at runtime. * If 'sparse' is true the tensor value represents a densified weight that was initially stored in the model * as sparse tensor. */ - Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant, bool sparse); + Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, bool constant, bool sparse); /** * Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the @@ -310,10 +305,6 @@ class SessionState { const InlinedHashSet* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif - std::unordered_map>* GetMutableBufferedTensors() { - return &name_to_buffered_tensor_; - } - Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -509,7 +500,6 @@ class SessionState { // This data structure is for uninitializing string tensors and // munmap memory region and close file descriptor - InlinedHashMap deleter_for_initialized_tensors_; InlinedVector weights_buffers_; std::optional p_seq_exec_plan_; @@ -607,12 +597,6 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. 
bool has_device_stream_enabled_ep_ = false; #endif - - // Holds the tensors which provide memory buffer for TensorProtos - // Use case: in optimizer, transform a TensorProto to a new TensorProto whose the memory buffer is - // allocated by CPU instead by protobuf's arena. Arena style memory allocators do not fully release - // a instance's memory which may result large memory consumption, which is a tradeoff for speed. - std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index cacd772b61d76..8f0713fcd7cb1 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -47,181 +47,138 @@ static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const T return Status::OK(); } -// deleter for external data tensors managed by an OrtValue; manages the release of -// the tensor's data buffer (which points to the external data) and the tensor itself -struct ExtDataValueDeleter { - OrtCallback ext_delete_cb; - Tensor* p_tensor; - void operator()(void*) noexcept { - if (ext_delete_cb.f) { - ext_delete_cb.f(ext_delete_cb.param); - } - - delete p_tensor; - } -}; - -// given a tensor proto with external data return an OrtValue with a tensor for -// that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter. -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. -static common::Status ExtDataTensorProtoToTensor(const Env& env, - const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter, - PrepackedWeightsForGraph& prepacked_for_graph, - Tensor* buffered_tensor = nullptr) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); - - void* ext_data_buf = nullptr; - SafeInt ext_data_len = 0; - ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter, - buffered_tensor, &prepacked_for_graph)); - if constexpr (endian::native != endian::little) { - if (!proto_path.empty() && (proto_path.compare(onnxruntime::utils::kTensorProtoMemoryAddressTag) != 0)) { - utils::ConvertRawDataInTensorProto(const_cast(&tensor_proto), ext_data_buf, ext_data_len); - } - } - - // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be - // avoided if the Tensor class implements the do-nothing behavior when given a - // nullptr for the allocator argument - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); - tensor = Tensor(type, tensor_shape, ext_data_buf, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - - return common::Status::OK(); -} - -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. +/** + * @brief Deserializes a TensorProto into an OrtValue. 
+ * + * This function handles the complexities of deserializing a tensor, including + * managing memory allocation, handling external data, and transferring data + * between different devices (e.g., CPU to GPU). It can use a pre-allocated + * memory buffer or an allocator to manage the tensor's memory. + * + * @param env The environment object, providing access to logging and other services. + * @param proto_path The file path of the ONNX model, used for resolving external data paths. + * @param tensor_proto The TensorProto message to deserialize. + * @param memory_buffer Optional. A raw memory buffer that is pre-allocated for the tensor. + * If provided, `alloc` must be null. + * @param alloc Optional. An allocator to use for allocating the tensor's memory. + * If provided, `memory_buffer` must be null. + * @param default_cpu_alloc The default CPU allocator, used for intermediate buffers if needed + * (e.g., when copying from CPU to another device). + * @param[out] ort_value The OrtValue to be populated with the deserialized tensor data. + * @param data_transfer_mgr The manager responsible for copying tensor data between different memory locations/devices. + * @param external_data_loader_mgr The manager for handling custom external data loaders. + * @param prepacked_for_graph Reference to an object managing prepacked weights for the graph. + * @param use_device_allocator_for_initializers A flag indicating whether to use the device-specific allocator + * directly for initializers, potentially bypassing arenas. + * @return common::Status indicating success or failure of the deserialization process. + * Returns an error status if both `memory_buffer` and `alloc` are provided or if both are null (unless external data on CPU allows mmap), + * if string tensors are attempted to be copied to non-CPU devices, or if any underlying + * data loading, allocation, or copying operation fails. + */ static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, + const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* memory_buffer, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, PrepackedWeightsForGraph& prepacked_for_graph, - bool use_device_allocator_for_initializers = false, - Tensor* buffered_tensor = nullptr) { - if (bool(alloc) == (m != nullptr)) { + bool use_device_allocator_for_initializers = false) { + if (bool(alloc) == (memory_buffer != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } - ORT_RETURN_IF(buffered_tensor && !utils::HasExternalData(tensor_proto), - "With buffered tensor, tensor proto must use external location and point to buffered tensor"); - - // Get shape and type of the tensor, and allocate the empty tensor TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - std::unique_ptr p_tensor; + Tensor tensor; - auto& memory_info = (alloc != nullptr) ? 
alloc->Info() : m->GetAllocInfo(); - auto device_type = memory_info.device.Type(); + // Get shape and type of the tensor, and allocate the empty tensor + static const auto default_cpu_device = OrtDevice(); + const auto& memory_info = (alloc != nullptr) ? alloc->Info() : memory_buffer->GetAllocInfo(); + const auto device = memory_info.device; if (utils::HasExternalData(tensor_proto)) { auto external_data_loader = external_data_loader_mgr.GetExternalDataLoader(memory_info); if (external_data_loader) { - // if custom external data loader is used, always allocate memory on device - p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - + // if custom external data loader is used, always allocate memory on device + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); ORT_RETURN_IF_ERROR(utils::LoadExtDataToTensorFromTensorProto(env, proto_path, tensor_proto, - *external_data_loader, *p_tensor)); + *external_data_loader, tensor)); - Tensor::InitOrtValue(std::move(*p_tensor), ort_value); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); - } else if (device_type == OrtDevice::CPU) { + } else if (device == default_cpu_device) { // for external initializer on CPU we will use mmap for large initializers so don't need to allocate memory in advance - p_tensor = std::make_unique(); // NB: The file containing external data for the tensor is mmap'd. If the tensor will be used on CPU we can - // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called - // TensorProtoToTensor it would copy the data, causing unnecessary overhead - OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, prepacked_for_graph, - buffered_tensor)); - - ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; - MLDataType ml_tensor_type = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor_type, deleter); + // utilize the mmap'd buffer directly. + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path, tensor_proto, + ort_value, + &prepacked_for_graph)); return common::Status::OK(); - } else { // non-cpu tensor - if (tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + } else { // non-cpu tensor or tensor in a cpu accessible memory + if (utils::HasString(tensor_proto)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators"); } // deserialize to CPU first for non-CPU allocator, then copy to device // for external initializer load on non-CPU device: - // 1. allocate memory on device - p_tensor - // 2. load initializer into CPU memory - p_deserialize_tensor, + // 1. allocate memory on device - tensor + // 2. load initializer into CPU memory - deserialized_value, // we will use mmap so no need to allocate memory on CPU in advance - // 3. copy tensor from CPU to device - p_deserialize_tensor -> p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + // 3. 
copy tensor from CPU to device - deserialized_value -> tensor -> ort_value + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - std::unique_ptr p_deserialize_tensor = std::make_unique(type, TensorShape(), default_cpu_alloc); + OrtValue deserialized_value; + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path, tensor_proto, + deserialized_value, + &prepacked_for_graph)); - OrtCallback ext_data_deleter; - std::optional scoped_ort_callback_invoker; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, prepacked_for_graph, - buffered_tensor)); - scoped_ort_callback_invoker.emplace(ext_data_deleter); - // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. - - return CopyTensorFromCPUToDevice(data_transfer_mgr, p_deserialize_tensor, p_tensor, ort_value); + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_value.Get(), + std::move(tensor), ort_value); } } else { - // for internal initializer, always allocate memory on device - p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + // for internal initializer, always allocate memory on device - tensor + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - if (device_type == OrtDevice::CPU) { + if (device == default_cpu_device) { // deserialize directly to CPU tensor - ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_tensor)); - auto ml_tensor = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); } else { // non-cpu tensor - if (tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + if (utils::HasString(tensor_proto)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators"); } // deserialize to CPU first for non-CPU allocator, then copy // for internal initializer - // 1. allocate memory on CPU - p_deserialize_tensor - // 2. deserialize tensor_probo into a preallocated tensor (p_deserialize_tensor) - // 3. copy tensor from CPU to device - p_deserialize_tensor -> p_tensor - std::unique_ptr p_deserialize_tensor; - ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, default_cpu_alloc, p_deserialize_tensor)); - - ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); - // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. - - return CopyTensorFromCPUToDevice(data_transfer_mgr, p_deserialize_tensor, p_tensor, ort_value); + // 1. allocate memory on CPU - deserialized_tensor + // 2. deserialize tensor_proto into a preallocated tensor (deserialized_tensor) + // 3. 
copy tensor from CPU to device - deserialized_tensor -> tensor (allocated above) -> ort_value + Tensor deserialized_tensor; + ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, + default_cpu_alloc, deserialized_tensor)); + + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, deserialized_tensor)); + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_tensor, std::move(tensor), ort_value); } } } -common::Status AllocateTensor(const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, +common::Status AllocateTensor(const onnxruntime::MemBuffer* memory_buffer, + Tensor& tensor, const onnxruntime::DataTypeImpl* const& type, onnxruntime::TensorShape& tensor_shape, bool use_device_allocator_for_initializers, const onnxruntime::AllocatorPtr& alloc) { - if (m != nullptr) { - p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); - if (m->GetLen() < p_tensor->SizeInBytes()) { + if (memory_buffer != nullptr) { + tensor = Tensor{type, tensor_shape, memory_buffer->GetBuffer(), memory_buffer->GetAllocInfo()}; + if (memory_buffer->GetLen() < tensor.SizeInBytes()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", - p_tensor->SizeInBytes(), ", Got ", m->GetLen()); + tensor.SizeInBytes(), ", Got ", memory_buffer->GetLen()); } } else { - return AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, alloc, p_tensor); + return AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, alloc, tensor); } return common::Status::OK(); } @@ -231,37 +188,36 @@ common::Status AllocateTensorOnDeviceOrMemory( onnxruntime::TensorShape& tensor_shape, const onnxruntime::DataTypeImpl* const& type, const onnxruntime::AllocatorPtr& alloc, - std::unique_ptr& p_tensor) { + Tensor& tensor) { if (use_device_allocator_for_initializers) { void* tensor_buffer = nullptr; ORT_RETURN_IF_ERROR(AllocateBufferUsingDeviceAllocatorFromShapeAndType(tensor_shape, type, alloc, tensor_buffer)); - p_tensor = std::make_unique(type, tensor_shape, tensor_buffer, alloc); + tensor = Tensor{type, tensor_shape, tensor_buffer, alloc}; } else { // If the provided allocator is an arena-based allocator, the call to Alloc() will tap into memory from the arena // (may expand it if there isn't a chunk that can be allotted to the memory request). // If the provided allocator is non-arena based, the device specific Alloc() call will be used to allocate the necessary memory. - p_tensor = std::make_unique(type, tensor_shape, alloc); + tensor = Tensor{type, tensor_shape, alloc}; } return common::Status::OK(); } common::Status CopyTensorFromCPUToDevice( const onnxruntime::DataTransferManager& data_transfer_mgr, - std::unique_ptr& p_deserialize_tensor, - std::unique_ptr& p_tensor, + const Tensor& deserialized_tensor, + Tensor&& tensor, OrtValue& ort_value) { - Status copy_status = data_transfer_mgr.CopyTensor(*p_deserialize_tensor, *p_tensor); + Status copy_status = data_transfer_mgr.CopyTensor(deserialized_tensor, tensor); if (!copy_status.IsOK()) { if (copy_status.ErrorMessage().empty()) { // The windows execution provider does not return any error message today for CopyTensor since it is // not implemented yet. That's the reason we're adding our own error message so that we can debug better. 
return Status(copy_status.Category(), copy_status.Code(), - "Failed to copy tensor to " + p_tensor->Location().ToString()); + "Failed to copy tensor to " + tensor.Location().ToString()); } return copy_status; } else { - auto ml_tensor = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); } } @@ -279,7 +235,6 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, PrepackedWeightsForGraph& prepacked_for_graph) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -298,13 +253,13 @@ common::Status SaveInitializedTensors( if (!ort_value_name_idx_map.GetIdx(name, ort_value_index).IsOK()) { retval = false; } else { - const auto& planned_mem_info = exec_plan.GetLocation(ort_value_index); + const auto& planned_mem_device = exec_plan.GetLocation(ort_value_index); const auto& user_mem_info = it->second->Get().Location(); - retval = user_mem_info.device == planned_mem_info; + retval = user_mem_info.device == planned_mem_device; if (!retval) { LOGS(logger, WARNING) << "Cannot use user supplied initializer with name: (" << name << ") because the ORT planned memory location device " - << planned_mem_info.ToString() + << planned_mem_device.ToString() << " ) is different from what is supplied (" << user_mem_info.ToString() << ")"; } } @@ -319,7 +274,7 @@ common::Status SaveInitializedTensors( InlinedHashSet user_supplied_initializer_ids; // set containing the ort value ids of all user supplied initializers id_to_initialized_tensor.reserve(initialized_tensor_set.size()); - user_supplied_initializer_ids.reserve(initialized_tensor_set.size()); + user_supplied_initializer_ids.reserve(session_options.initializers_to_share_map.size()); for (const auto& entry : initialized_tensor_set) { int ort_value_index; @@ -330,6 +285,8 @@ common::Status SaveInitializedTensors( id_to_initialized_tensor[ort_value_index] = entry.second; } + static const auto default_cpu_device = OrtDevice(); + // tensors requiring a specific allocation order are traced first, to ensure they are allocated in order // NB1: vector with init allocation order may contain a subset of all tensors (or none at all) // NB2: only skip tracing and planning memory when data is external (i.e mmap) and on CPU. @@ -339,10 +296,21 @@ common::Status SaveInitializedTensors( const auto entry = initialized_tensors_to_allocate.find(ort_value_index); ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(), "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors"); - if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) { - // can not trace string tensor - ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor"); - ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second)); + const auto* tensor_proto = entry->second; + + // We trace to allocate a single buffer using the planner. This reduces fragmentation. + // We do not trace the following values because it would add to the memory consumption. + // - Values that are on OrtDevice() (default CPU). + // - Values that are external and mapped from disk. 
We let the OS manage the memory. + // - we do not trace values that are in memory because they may be sitting on top of the user allocated + // memory. + const bool trace_allocation = (exec_plan.GetLocation(ort_value_index) != default_cpu_device) || + !utils::HasExternalData(*tensor_proto); + + if (trace_allocation) { + // can not trace string tensor, and they exist only on CPU + ORT_ENFORCE(!utils::HasString(*tensor_proto), "Can not trace string tensor"); + ORT_RETURN_IF_ERROR(planner.Trace(ort_value_index, tensor_proto)); } initialized_tensors_to_allocate.erase(entry); } @@ -352,7 +320,7 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { continue; } - if (entry.second->data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + if (utils::HasString(*entry.second)) { // do not trace string tensor continue; } @@ -374,8 +342,6 @@ common::Status SaveInitializedTensors( << i.second << " bytes for " << i.first.ToString() << std::endl; } - OrtCallback deleter{nullptr, nullptr}; - // 3. create weight tensors based on weights buffer for (const auto& entry : id_to_initialized_tensor) { // We check for cancellation for every initializer since mapping from disk can be costly @@ -397,39 +363,50 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; - - } else if (graph.GetOrtValueInitializer(name, ort_value)) { - // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); - std::optional m; + std::optional memory_buffer; AllocatorPtr alloc; // TODO: if the tensor need be copied, does it have enough room? - ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, m, alloc)); - bool use_device_allocator_for_initializers = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; - - Tensor* p_tensor = nullptr; - auto buffered_tensors_iter = buffered_tensors.find(name); - if (buffered_tensors_iter != buffered_tensors.end()) { - p_tensor = buffered_tensors_iter->second.get(); - } - - Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - prepacked_for_graph, - use_device_allocator_for_initializers, p_tensor); - if (!st.IsOK()) { - std::ostringstream oss; - oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); - return Status(st.Category(), st.Code(), oss.str()); - } - - if (p_tensor != nullptr) { - // p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here. - ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release()); - buffered_tensors.erase(buffered_tensors_iter); + ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc)); + const bool use_device_allocator_for_initializers = + session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + + // Check if we already have an OrtValue for this initializer on CPU + if (OrtValue ort_value_from_graph; + graph.GetOrtValueInitializer(name, ort_value_from_graph)) { + const auto& memory_info = (alloc != nullptr) ? 
alloc->Info() : memory_buffer->GetAllocInfo(); + if (memory_info.device == default_cpu_device) { + // This is on CPU use directly from the graph + ort_value = std::move(ort_value_from_graph); + } else { + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum( + tensor_proto.data_type()) + ->GetElementType(); + Tensor tensor; + ORT_RETURN_IF_ERROR(AllocateTensor((memory_buffer) ? &*memory_buffer : nullptr, tensor, type, + tensor_shape, use_device_allocator_for_initializers, + alloc)); + ORT_RETURN_IF_ERROR(CopyTensorFromCPUToDevice(data_transfer_mgr, + ort_value_from_graph.Get(), + std::move(tensor), ort_value)); + } + } else { + // We need to deserialize the tensor proto into an OrtValue + // using the preallocated buffer or allocator. + + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (memory_buffer.has_value()) ? &*memory_buffer : nullptr, alloc, + default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, + prepacked_for_graph, + use_device_allocator_for_initializers); + if (!st.IsOK()) { + std::ostringstream oss; + oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); + return Status(st.Category(), st.Code(), oss.str()); + } } } @@ -442,9 +419,9 @@ common::Status SaveInitializedTensors( const bool constant = graph.IsConstantInitializer(name, /* check_outer_scope */ false); #if !defined(DISABLE_SPARSE_TENSORS) const bool sparse = graph.GetGraph().IsSparseInitializer(name); - ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, sparse)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, constant, sparse)); #else - ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, false)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, constant, false)); #endif } diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 17400c45e5f32..3428b38b389a8 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -36,7 +36,7 @@ class Logger; namespace session_state_utils { using SaveTensorFunction = std::function; + bool constant, bool sparse)>; using MemoryProfileFunction = std::function; common::Status SaveInitializedTensors( @@ -51,12 +51,11 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, PrepackedWeightsForGraph& prepacked_for_graph); common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, + const onnxruntime::MemBuffer* memory_buffer, + Tensor& p_tensor, const onnxruntime::DataTypeImpl* const& type, onnxruntime::TensorShape& tensor_shape, bool use_device_allocator_for_initializers, @@ -67,12 +66,12 @@ common::Status AllocateTensorOnDeviceOrMemory( onnxruntime::TensorShape& tensor_shape, const onnxruntime::DataTypeImpl* const& type, const onnxruntime::AllocatorPtr& alloc, - std::unique_ptr& p_tensor); + Tensor& p_tensor); common::Status CopyTensorFromCPUToDevice( const onnxruntime::DataTransferManager& data_transfer_mgr, - std::unique_ptr& p_deserialize_tensor, - std::unique_ptr& p_tensor, + const Tensor& deserialized_tensor, + Tensor&& tensor, OrtValue& ort_value); common::Status SaveInputOutputNamesToNodeMapping(const 
GraphViewer& graph, diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index 92c264e57279c..ad88149c89b81 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -47,12 +47,13 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { } else { buffer = alloc->Alloc(peak_size); } - weights_buffers_.push_back(BufferUniquePtr(buffer, BufferDeleter(alloc))); + + auto buffer_ptr = BufferUniquePtr(buffer, BufferDeleter(std::move(alloc))); auto kvp = buffers_.insert(std::make_pair(location, buffer)); if (!kvp.second) { - alloc->Free(buffer); return Status(common::ONNXRUNTIME, common::FAIL, "duplicated location"); } + weights_buffers_.push_back(std::move(buffer_ptr)); planned_memory_sizes_in_byte[location] += peak_size; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 94a2a6677358e..ee815b5e722cc 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -22,7 +22,6 @@ #include "core/framework/tensor.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/allocator.h" -#include "core/framework/callback.h" #include "core/framework/data_types.h" #include "core/platform/path_lib.h" #include "core/framework/to_tensor_proto_element_type.h" @@ -172,13 +171,21 @@ DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::vector& unpacked_tensor) { - std::basic_string external_file_path; + PathString external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; ORT_RETURN_IF_ERROR( GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); unpacked_tensor.resize(tensor_byte_size); + + if (external_file_path == kTensorProtoMemoryAddressTag) { + // The external data is in the same memory as the tensor proto. + // The offset is the address of the data. 
+ std::memcpy(unpacked_tensor.data(), reinterpret_cast(file_offset), tensor_byte_size); + return Status::OK(); + } + ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( external_file_path.c_str(), file_offset, @@ -216,7 +223,7 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo tensor->SizeInBytes(), ", Got ", m->GetLen()); } } else { - tensor = std::make_unique(type, tensor_shape, alloc); + tensor = std::make_unique(type, tensor_shape, std::move(alloc)); } ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensor)); @@ -230,16 +237,55 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { +bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { + if (HasExternalData(ten_proto)) { + // Retrieve the external data info + for (const auto& entry : ten_proto.external_data()) { + if (entry.key() == "location") { + PathString location = ToWideString(entry.value()); + return location == kTensorProtoMemoryAddressTag; + } + } + } + + return false; // No external data in memory +} + +Status TensorProtoWithExternalDataToTensorProto( + const ONNX_NAMESPACE::TensorProto& ten_proto, + const std::filesystem::path& model_path, + ONNX_NAMESPACE::TensorProto& new_tensor_proto) { + // Check if the input tensor has external data + ORT_RETURN_IF_NOT(HasExternalData(ten_proto), "Input tensor does not have external data."); + + // Copy the metadata from the source tensor to the new tensor + ONNX_NAMESPACE::TensorProto result; + result.set_name(ten_proto.name()); + result.set_data_type(ten_proto.data_type()); + result.mutable_dims()->CopyFrom(ten_proto.dims()); + + // Load the external data into memory + std::vector unpacked_data; + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(ten_proto, model_path, unpacked_data)); + + // Set the raw data in the new tensor + result.set_raw_data(unpacked_data.data(), unpacked_data.size()); + + new_tensor_proto = std::move(result); + + return Status::OK(); +} + Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size, ExternalDataInfo::PrepackedInfos* prepacked_infos) { - ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), + ORT_RETURN_IF_NOT(HasExternalData(tensor_proto), "Tensor does not have external data to read from."); - ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), + ORT_RETURN_IF(!HasDataType(tensor_proto) || HasString(tensor_proto), "External data type cannot be UNDEFINED or STRING."); std::unique_ptr external_data_info; @@ -247,10 +293,10 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const auto& location = external_data_info->GetRelPath(); - external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) - : (tensor_proto_dir / location); + external_file_path = location == kTensorProtoMemoryAddressTag ? 
std::filesystem::path(location) + : (tensor_proto_dir / location); - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); + ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); const size_t external_data_length = external_data_info->GetLength(); ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, "TensorProto: ", tensor_proto.name(), @@ -270,38 +316,22 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -void ConvertRawDataInTensorProto(TensorProto* tensor, - void* ext_data_buf, - size_t ext_data_len) { +void ConvertRawDataInTensorProto(TensorProto& tensor) { size_t element_size = 1; - char* bytes = NULL; + void* bytes = NULL; size_t num_elements = 0; - if (ext_data_buf && !ext_data_len) { - return; - } - switch (tensor->data_type()) { + + switch (tensor.data_type()) { case TensorProto_DataType_FLOAT: - bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); - num_elements = tensor->float_data_size(); + bytes = tensor.mutable_float_data()->mutable_data(); + num_elements = tensor.float_data_size(); element_size = sizeof(float); break; - case TensorProto_DataType_INT32: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(int32_t); - break; - - case TensorProto_DataType_UINT32: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(uint32_t); - break; - case TensorProto_DataType_UINT8: case TensorProto_DataType_INT8: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); + bytes = tensor.mutable_int32_data()->mutable_data(); + num_elements = tensor.int32_data_size(); element_size = sizeof(uint8_t); break; @@ -309,56 +339,52 @@ void ConvertRawDataInTensorProto(TensorProto* tensor, case TensorProto_DataType_INT16: case TensorProto_DataType_FLOAT16: case TensorProto_DataType_BFLOAT16: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(uint16_t); + case TensorProto_DataType_INT32: + bytes = tensor.mutable_int32_data()->mutable_data(); + num_elements = tensor.int32_data_size(); + // We are setting this to int32_t size because we need to swap all 4 bytes + // to represent 16 bits within 32 bits correctly on a LE/BE system. 
+ element_size = sizeof(int32_t); break; + // uint32_t is stored in uint64_t + case TensorProto_DataType_UINT32: case TensorProto_DataType_UINT64: - bytes = reinterpret_cast(tensor->mutable_uint64_data()->mutable_data()); - num_elements = tensor->uint64_data_size(); + bytes = tensor.mutable_uint64_data()->mutable_data(); + num_elements = tensor.uint64_data_size(); element_size = sizeof(uint64_t); break; - case TensorProto_DataType_DOUBLE: - bytes = reinterpret_cast(tensor->mutable_double_data()->mutable_data()); - num_elements = tensor->double_data_size(); - element_size = sizeof(double); - break; - case TensorProto_DataType_INT64: - bytes = reinterpret_cast(tensor->mutable_int64_data()->mutable_data()); - num_elements = tensor->int64_data_size(); + bytes = tensor.mutable_int64_data()->mutable_data(); + num_elements = tensor.int64_data_size(); element_size = sizeof(int64_t); break; + case TensorProto_DataType_DOUBLE: + bytes = tensor.mutable_double_data()->mutable_data(); + num_elements = tensor.double_data_size(); + element_size = sizeof(double); + break; + case TensorProto_DataType_COMPLEX64: - bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); - num_elements = tensor->float_data_size(); + bytes = tensor.mutable_float_data()->mutable_data(); + num_elements = tensor.float_data_size(); element_size = sizeof(float); break; } - if (tensor->has_raw_data()) { - num_elements = (tensor->raw_data().size()) / element_size; - bytes = const_cast(tensor->mutable_raw_data()->c_str()); - } if (element_size == 1) { return; } - if (ext_data_buf) { - ORT_ENFORCE(ext_data_len % element_size == 0); - num_elements = ext_data_len / element_size; - bytes = reinterpret_cast(ext_data_buf); - } - for (size_t i = 0; i < num_elements; ++i) { - char* start_byte = bytes + i * element_size; - char* end_byte = start_byte + element_size - 1; - for (size_t count = 0; count < element_size / 2; ++count) { - std::swap(*start_byte++, *end_byte--); - } + + if (tensor.has_raw_data()) { + num_elements = tensor.raw_data().size() / element_size; + bytes = tensor.mutable_raw_data()->data(); } - return; + + gsl::span span = gsl::make_span(reinterpret_cast(bytes), num_elements * element_size); + SwapByteOrderInplace(element_size, span); } #if !defined(ORT_MINIMAL_BUILD) @@ -984,26 +1010,10 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif -class AutoDelete { - public: - OrtCallback d{nullptr, nullptr}; - AutoDelete() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(AutoDelete); - ~AutoDelete() { - if (d.f != nullptr) { - d.f(d.param); - } - } -}; - -static void DeleteCharArray(void* param) noexcept { - auto arr = reinterpret_cast(param); - delete[] arr; -} #if !defined(__wasm__) static Status GetFileContent(const Env& env, const std::filesystem::path& file_path, FileOffsetType offset, - size_t length, void*& raw_buffer, OrtCallback& deleter) { + size_t length, IAllocatorUniquePtr& external_data) { // query length if it is 0 if (length == 0) { // The return type of std::filesystem::file_size is uintmax_t which could be bigger than size_t @@ -1015,8 +1025,9 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Env::MappedMemoryPtr mapped_memory{}; auto status = env.MapFileIntoMemory(file_path.native().c_str(), offset, length, mapped_memory); if (status.IsOK()) { - deleter = mapped_memory.get_deleter().callback; - raw_buffer = mapped_memory.release(); + 
IAllocatorUniquePtr raw_buffer(mapped_memory.release(), + mapped_memory.get_deleter()); + external_data.swap(raw_buffer); return Status::OK(); } } @@ -1026,22 +1037,24 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p ORT_RETURN_IF_ERROR( env.ReadFileIntoBuffer(file_path.native().c_str(), offset, length, gsl::make_span(buffer.get(), length))); - deleter = OrtCallback{DeleteCharArray, buffer.get()}; - raw_buffer = buffer.release(); + IAllocatorUniquePtr raw_buffer(buffer.release(), [](void* p) { delete[] reinterpret_cast(p); }); + external_data.swap(raw_buffer); return Status::OK(); } #endif -Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - Tensor* buffered_tensor, - PrepackedWeightsForGraph* prepacked_info) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); +Status GetExtDataFromTensorProto(const Env& env, + const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + OrtValue& ort_value, PrepackedWeightsForGraph* prepacked_info) { + ORT_ENFORCE(HasExternalData(tensor_proto), "TensorProto for: ", + tensor_proto.name(), "Expected to have external data"); + std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); } + std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; @@ -1049,20 +1062,24 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo if (prepacked_info != nullptr) { prepacked_infos.emplace(); } + ORT_RETURN_IF_ERROR( GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, (prepacked_info != nullptr) ? 
&*prepacked_infos : nullptr)); + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + MLDataType ml_tensor_type = DataTypeImpl::GetType(); + const auto& name = tensor_proto.name(); + if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data - ext_data_buf = reinterpret_cast(file_offset); - ext_data_len = raw_data_safe_len; - if (buffered_tensor) { - ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, - reinterpret_cast(buffered_tensor)}; - } else { - ext_data_deleter = OrtCallback{nullptr, nullptr}; - } + void* ext_data_buf = reinterpret_cast(file_offset); + auto tensor = Tensor{type, tensor_shape, ext_data_buf, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)}; + ORT_RETURN_IF(raw_data_safe_len != tensor.SizeInBytes(), "Weight: ", name, + " kTensorProtoMemoryAddressTag address points to length: ", static_cast(raw_data_safe_len), + " while shape has bytes size: ", tensor.SizeInBytes()); + Tensor::InitOrtValue(std::move(tensor), ort_value); } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1071,19 +1088,27 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo " are out of bounds or can not be read in full (>4GB)."); auto buffer = std::make_unique(raw_data_safe_len); - ext_data_deleter = OrtCallback{DeleteCharArray, buffer.get()}; - ext_data_buf = buffer.release(); - ext_data_len = raw_data_safe_len; - ORT_RETURN_IF_ERROR(LoadWebAssemblyExternalData(env, external_data_file_path, file_offset, - ext_data_len, + raw_data_safe_len, ExternalDataLoadType::CPU, - ext_data_buf)); + buffer.get())); + + auto p_tensor = std::make_unique(type, tensor_shape, buffer.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + + std::function deleter = [ext_data = buffer.get()](void* t) { + delete reinterpret_cast(t); + delete[] ext_data; + }; + + ort_value.Init(p_tensor.release(), ml_tensor_type, std::move(deleter)); + buffer.release(); + #else - // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to - // manually check file size first. + // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to + // manually check file size first. 
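Both the ConvertRawDataInTensorProto rewrite above and the big-endian branch further below funnel the conversion through a single span-based helper (SwapByteOrderInplace). The standalone sketch below shows the element-wise byte reversal that helper is assumed to perform; the function name SwapElementBytesInplace and the use of std::span instead of gsl::span are illustrative choices, not ORT code.

#include <cstddef>
#include <span>
#include <utility>

// Reverse the byte order of every element in a contiguous buffer.
// Assumes bytes.size() is a multiple of element_size; 1-byte types are left untouched.
inline void SwapElementBytesInplace(size_t element_size, std::span<char> bytes) {
  if (element_size <= 1) return;
  for (size_t offset = 0; offset + element_size <= bytes.size(); offset += element_size) {
    char* first = bytes.data() + offset;
    char* last = first + element_size - 1;
    for (size_t i = 0; i < element_size / 2; ++i) {
      std::swap(*first++, *last--);
    }
  }
}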
std::uintmax_t file_length = std::filesystem::file_size(external_data_file_path); SafeInt end_of_read(file_offset); @@ -1092,9 +1117,35 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo "External initializer: ", tensor_proto.name(), " offset: ", file_offset, " size to read: ", static_cast(raw_data_safe_len), " given file_length: ", file_length, " are out of bounds or can not be read in full."); - ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), file_offset, raw_data_safe_len, - ext_data_buf, ext_data_deleter)); - ext_data_len = raw_data_safe_len; + + IAllocatorUniquePtr ext_data_buf; + ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path, file_offset, raw_data_safe_len, + ext_data_buf)); + + // Data on disk is little endian + if constexpr (endian::native != endian::little) { + if (type->Size() > 1) { + gsl::span data_span{reinterpret_cast(ext_data_buf.get()), raw_data_safe_len}; + SwapByteOrderInplace(type->Size(), data_span); + } + } + + auto p_tensor = std::make_unique(type, tensor_shape, ext_data_buf.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + ORT_RETURN_IF(raw_data_safe_len != p_tensor->SizeInBytes(), "Weight: ", name, + " External file content has length: ", static_cast(raw_data_safe_len), + " while shape has bytes size: ", p_tensor->SizeInBytes()); + + // Will destroy ext_data as a member of the functor + // can not move the unique_ptr as it is not copyable + std::function deleter = [ext_data = ext_data_buf.get(), + d = ext_data_buf.get_deleter()](void* t) { + delete reinterpret_cast(t); + d(ext_data); + }; + + ort_value.Init(p_tensor.release(), ml_tensor_type, std::move(deleter)); + ext_data_buf.release(); if (prepacked_info != nullptr && !prepacked_infos->empty()) { for (const auto& [key, blobs] : *prepacked_infos) { @@ -1109,12 +1160,11 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo ORT_RETURN_IF(blob_offset < 0 || static_cast(end_of_blob) > file_length, "Pre-packed blob: ", key, " offset: ", blob_offset, " file_length: ", file_length, " is out of bounds and can not read in full"); - void* data_ptr; - OrtCallback data_deleter; - ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), blob_offset, blob_length, - data_ptr, data_deleter)); - IAllocatorUniquePtr data_ptr_unique{data_ptr, OrtCallbackInvoker(data_deleter)}; - prepacked_weights.buffers_.push_back(std::move(data_ptr_unique)); + + IAllocatorUniquePtr data_ptr; + ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path, blob_offset, blob_length, + data_ptr)); + prepacked_weights.buffers_.push_back(std::move(data_ptr)); prepacked_weights.buffer_sizes_.push_back(blob_length); } if (!blobs.empty()) { @@ -1132,7 +1182,7 @@ Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem: const ONNX_NAMESPACE::TensorProto& tensor_proto, const IExternalDataLoader& ext_data_loader, Tensor& tensor) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); + ORT_ENFORCE(HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -1171,9 +1221,26 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { // Validate tensor compatibility TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + + if (std::any_of(tensor_shape.GetDims().begin(), 
tensor_shape.GetDims().end(), + [](int64_t dim) { + return dim < 0; + })) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "tensor can't contain negative dims"); + } + + if (HasExternalData(tensor_proto)) { + OrtValue ort_value; + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, ort_value)); + const auto& ext_tensor = ort_value.Get(); + MakeCpuTensorCopy(ext_tensor, tensor); + return Status::OK(); + } + if (tensor_shape != tensor.Shape()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "TensorProtoToTensor() tensor shape mismatch!"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProtoToTensor() tensor shape mismatch!"); } + const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { @@ -1181,15 +1248,12 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); } + // Below we handle the case where TensorProto contains data in itself + // find raw data in proto buf void* raw_data = nullptr; SafeInt raw_data_len = 0; - AutoDelete deleter_for_file_data; - OrtCallback& d = deleter_for_file_data.d; - - if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); - } else if (utils::HasRawData(tensor_proto)) { + if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data // into a writeable buffer. However, it requires extra memory which may exceed the limit for certain tests. @@ -1200,25 +1264,19 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa raw_data_len = tensor_proto.raw_data().size(); } - if (nullptr != raw_data && utils::IsPrimitiveDataType(source_type)) { + if (nullptr != raw_data && utils::HasString(tensor_proto)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor can not have raw data"); } // unpacking tensor_proto data to preallocated tensor void* preallocated = tensor.MutableDataRaw(); - int64_t tensor_size = 1; - { - for (auto i : tensor_proto.dims()) { - if (i < 0) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims"); - } - tensor_size *= i; - } - } + const int64_t tensor_size = tensor_shape.Size(); + // tensor_size could be zero. 
see test_slice_start_out_of_bounds\test_data_set_0\output_0.pb - if (static_cast(tensor_size) > SIZE_MAX) { + if (narrow(tensor_size) > SIZE_MAX) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "size overflow"); } + switch (tensor_proto.data_type()) { CASE_PROTO(FLOAT, float); CASE_PROTO(DOUBLE, double); @@ -1256,6 +1314,31 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa return Status::OK(); } +common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { + ORT_RETURN_IF_NOT(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); + auto proto_data_type = tensor_proto.data_type(); + + auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, + CPUAllocator::DefaultInstance()); + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, model_path, tensor_proto, w)); + + tensor = std::move(w); + return Status::OK(); +} + +Status GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) { + if (HasExternalDataInMemory(tensor_proto)) { + result = std::make_unique(); + return TensorProtoWithExternalDataToTensorProto(tensor_proto, {}, *result); + } + + result.reset(); + return Status::OK(); +} + Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, OrtValue& value) { return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); @@ -1318,15 +1401,7 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, } tensor_proto.set_data_type(tensor.GetElementType()); - if (tensor.IsDataTypeString()) { - auto* mutable_string_data = tensor_proto.mutable_string_data(); - auto f = tensor.Data(); - auto end = f + tensor.Shape().Size(); - for (; f < end; ++f) { - *mutable_string_data->Add() = *f; - } - } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { - // The logic aligns with + if (use_tensor_buffer && tensor.SizeInBytes() > kSmallTensorExternalDataThreshold) { // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. 
Invalid tensor."); @@ -1341,7 +1416,16 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, offset, tensor.SizeInBytes(), tensor_proto); } else { - utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); + if (tensor.IsDataTypeString()) { + auto* mutable_string_data = tensor_proto.mutable_string_data(); + auto f = tensor.Data(); + auto end = f + tensor.Shape().Size(); + for (; f < end; ++f) { + *mutable_string_data->Add() = *f; + } + } else { + SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); + } } return tensor_proto; @@ -1413,6 +1497,15 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n return ConstantNodeProtoToTensorProto(node, model_path, tensor, node.output(0)); } +void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { + if (src_tensor.IsDataTypeString()) { + auto src_span = src_tensor.DataAsSpan(); + std::copy(src_span.begin(), src_span.end(), dst_tensor.MutableDataAsSpan().begin()); + } else { + std::memcpy(dst_tensor.MutableDataRaw(), src_tensor.DataRaw(), src_tensor.SizeInBytes()); + } +} + #if !defined(DISABLE_SPARSE_TENSORS) static Status CopySparseData(size_t n_sparse_elements, const ONNX_NAMESPACE::TensorProto& indices, @@ -1847,7 +1940,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, // TODO, if std::vector does not use a custom allocator, the default std::allocator will // allocation the memory aligned to std::max_align_t, need look into allocating // forced aligned memory (align as 16 or larger)for unpacked_tensor - if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) { + if (HasExternalData(initializer)) { ORT_RETURN_IF_ERROR(ReadExternalDataForTensor( initializer, model_path.parent_path(), diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 79eae48c10411..658b8aba3b48a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -40,19 +40,13 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, SafeInt& tensor_byte_size, ExternalDataInfo::PrepackedInfos* prepacked_infos = nullptr); /** - * This function is used to convert the endianess of Tensor data. - * If ext_data_buf is provided, then this buffer content's endianess - * will be changed. + * This function is used to convert the endianess of TensorProto data. + * * Mostly, will be used in big endian system to support the model file * generated on little endian system. * @param tensor_proto given initializer tensor - * @param ext_data_buf optional externl data buffer - * @param ext_data_len optional externl data buffer lengeh - * @returns None */ -void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* tensor_proto, - void* ext_data_buf = NULL, - size_t ext_data_len = 0); +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto); /** * Wrapper function for set_raw_data. 
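As a caller-side illustration of the split above: when TensorToTensorProto is asked to use the tensor buffer, anything larger than the small-tensor threshold comes back as a TensorProto whose external-data entry points at the tensor's memory, while small and string tensors keep their data inline. A hedged sketch follows; ConvertWeight and the exact include set are assumptions, the utility calls are the ones declared in this header.

#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"
#include <string>

void ConvertWeight(const onnxruntime::Tensor& weight, const std::string& name) {
  using namespace onnxruntime;
  constexpr bool use_tensor_buffer = true;
  ONNX_NAMESPACE::TensorProto proto = utils::TensorToTensorProto(weight, name, use_tensor_buffer);

  if (utils::HasExternalDataInMemory(proto)) {
    // Large tensor: proto only carries a kTensorProtoMemoryAddressTag reference,
    // so `weight` (or an OrtValue wrapping it) must outlive `proto`.
  } else {
    // Small or string tensor: the data was copied inline into raw_data/string_data.
  }
}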
@@ -68,7 +62,7 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, T1* raw_ using namespace ONNX_NAMESPACE; tensor_proto.set_raw_data(raw_data, raw_data_len); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&tensor_proto); + utils::ConvertRawDataInTensorProto(tensor_proto); } } @@ -102,6 +96,17 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); +/// +/// This function checks if the tensor_proto has external data in memory. +/// If it does, it converts it to a result with data inline, otherwise it does nothing. +/// The function returns a unique_ptr to make it compatible with EPs code. +/// +/// source proto +/// result, can be nullptr if no data in memory, still a success +/// Status +Status GetTensorProtoWithDataIfInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto, + std::unique_ptr& result); + /** * deserialize a TensorProto into a preallocated memory buffer on CPU. * \param tensor_proto_path A local file path of where the 'input' was loaded from. @@ -137,6 +142,30 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); +/** + * @brief Pre-allocates empty tensor and deserializes a TensorProto into it + * @param env + * @param model_path + * @param tensor_proto source data + * @param tensor destination empty tensor + * @return + */ +common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + Tensor& tensor); + +/// The threshold for small tensors. If the size of the tensor is LE to this value, +/// The data will stay in the TensorProto. Otherwise, the data will be moved to a Tensor instance +/// and TensorProto will contain a kTensorProtoMemoryAddressTag reference as a result of +/// TensorToTensorProto() below. This is because shape inferencing code in onnx for +/// like Reshape parses weights data and it needs to be in the TensorProto. +/// The value of 127 was chosen empirically to be the smallest value that is required +/// for onnx shape inference to work correctly. The value also takes into account the overhead +/// imposed by having external data. The external data requires location/offset/filename so for +/// small values it is better to keep the data inline in the TensorProto, even if they are not used +/// in shape inferencing, it is cheaper to inline them. +constexpr const size_t kSmallTensorExternalDataThreshold = 127; // 127 bytes + /** * @brief Creates a TensorProto from a Tensor. * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. @@ -173,18 +202,20 @@ address of the memory containing the data. */ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_ADDR_/*"); -// Given a tensor proto with external data obtain a pointer to the data and its length. -// The ext_data_deleter argument is updated with a callback that owns/releases the data. -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. +/// +/// Creates a OrtValue with a tensor on top of the external data. 
+/// If tensor_proto points to a memory address, the OrtValue will be created with a tensor +/// that does not own the memory since the memory is already owned by some other entity. +/// +/// +/// model path +/// tensor proto containing external data +/// output ort value +/// optional pre-packed weight data output container +/// Status common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter, - Tensor* buffered_tensor = nullptr, - PrepackedWeightsForGraph* prepacked_for_graph = nullptr); + OrtValue& ort_value, PrepackedWeightsForGraph* prepacked_info = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. common::Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, @@ -207,6 +238,13 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor); +/// +/// Creates a new CPU-based tensor and copies the data from the source tensor. +/// +/// +/// +void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor); + #if !defined(DISABLE_SPARSE_TENSORS) // Convert a SparseTensorProto to a dense TensorProto // If the SparseTensorProto contains external data then it loads the data and converts to dense tensor proto @@ -454,6 +492,25 @@ inline bool HasName(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) { return !op_proto.name().empty(); } +/// +/// Quick check if this tensor proto has external data in memory. +/// +/// tensor_proto +/// true if tensor_proto has external data and it is in memory +bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto); + +/// +/// This function converts TensorProto with external data to TensorProto with inline data. 
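A hedged usage sketch for the OrtValue-based GetExtDataFromTensorProto declared above: the external bytes (memory-mapped file, WebAssembly buffer, or in-memory address) end up owned or wrapped by the OrtValue, replacing the old buffer/length/deleter triple. LoadExternalWeight and the include list are assumptions for illustration.

#include "core/common/common.h"
#include "core/framework/ort_value.h"
#include "core/framework/tensorprotoutils.h"
#include <filesystem>

onnxruntime::common::Status LoadExternalWeight(const onnxruntime::Env& env,
                                               const std::filesystem::path& model_path,
                                               const ONNX_NAMESPACE::TensorProto& tensor_proto) {
  using namespace onnxruntime;
  OrtValue ort_value;
  ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, model_path, tensor_proto, ort_value));

  // The Tensor view stays valid for as long as ort_value is alive.
  const Tensor& tensor = ort_value.Get<Tensor>();
  (void)tensor;
  return common::Status::OK();
}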
+/// +/// source +/// model_path, can be empty if data is in memory +/// result +/// Status +Status TensorProtoWithExternalDataToTensorProto( + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, + ONNX_NAMESPACE::TensorProto& new_tensor_proto); + #endif inline bool HasType(const ONNX_NAMESPACE::AttributeProto& at_proto) { diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 25e666ecb2c65..62e73b24cca14 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -4,6 +4,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/function_impl.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" @@ -72,24 +73,11 @@ FunctionImpl::FunctionImpl(onnxruntime::Graph& graph, } for (const auto& input : meta_def->inputs) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input, initializer)) { - // meta_def->inputs could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!function_body_graph_.GetInitializedTensor(input, subgraph_initializer)) { - function_body_graph_.AddInitializedTensor(*initializer); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph, function_body_graph_, input); } for (const auto& constant_initializer : meta_def->constant_initializers) { - const ONNX_NAMESPACE::TensorProto* initializer = graph.GetConstantInitializer(constant_initializer, true); - ORT_ENFORCE(initializer != nullptr, "Initializer " + constant_initializer + " is not found or is not constant initializer."); - // meta_def->constant_initializers could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!function_body_graph_.GetInitializedTensor(constant_initializer, subgraph_initializer)) { - function_body_graph_.AddInitializedTensor(*initializer); - } + graph_utils::MakeConstantInitializerCopyIfNotExist(graph, function_body_graph_, constant_initializer, true); } // TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 8b41460ccce21..5eab61c7b97df 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1249,7 +1249,7 @@ Graph::Graph(const Model& owning_model, if (attrib.type() == AttributeProto_AttributeType_SPARSE_TENSOR) { const TensorProto& sparse_values = node.attribute(0).sparse_tensor().values(); if ((!(sparse_values.has_raw_data())) && tensor->has_raw_data()) { - onnxruntime::utils::ConvertRawDataInTensorProto(tensor); + onnxruntime::utils::ConvertRawDataInTensorProto(*tensor); } } } @@ -1434,20 +1434,18 @@ void Graph::InitializeStateFromModelFileGraphProto() { "Graph state to be loaded into must be empty."); // Name to NodeArg mapping of all graph initializers. - std::unordered_map graph_initializers; - - // Name to NodeArg mapping of all graph inputs. - std::unordered_map graph_inputs; - - // Name to NodeArg mapping of all graph node outputs. 
- std::unordered_map nodes_outputs; - + InlinedHashMap graph_initializers; + graph_initializers.reserve(graph_proto_->initializer_size()); for (auto& initializer : graph_proto_->initializer()) { auto& initializer_name = initializer.name(); auto initializer_arg = GetNodeArg(initializer_name); graph_initializers.insert({initializer_name, initializer_arg}); } + // Name to NodeArg mapping of all graph inputs. + InlinedHashMap graph_inputs; + graph_inputs.reserve(graph_proto_->input_size()); + // Set graph inputs. // contains inputs exactly specified in proto. // contains inputs without default value (specified as initializer). @@ -1462,6 +1460,9 @@ void Graph::InitializeStateFromModelFileGraphProto() { } } + // Name to NodeArg mapping of all graph node outputs. + InlinedHashMap nodes_outputs; + nodes_outputs.reserve(graph_proto_->node_size() * 2); // rough estimate for (const auto& node : Nodes()) { for (const auto* output_def : node.OutputDefs()) { nodes_outputs.insert({output_def->Name(), output_def}); @@ -3411,18 +3412,28 @@ bool Graph::ResolveContext::IsOuterScopeValue(const std::string& name) const { #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + void Graph::AddInitializedTensor(const TensorProto& tensor) { auto existing = name_to_initial_tensor_.find(tensor.name()); - if (existing != name_to_initial_tensor_.cend()) { + const bool exists = existing != name_to_initial_tensor_.cend(); + if (exists) { ORT_ENFORCE(existing->second == &tensor, "AddInitializedTensor already has tensor with name ", tensor.name(), " but different TensorProto."); return; } + // This overload is used when the tensor does not point to an OrtValue which + // would need to be updated, but it is okay if it is pointing to flatbuffers or some other place at the moment. + // However, if an ort_value present for the name, it must be replaced. + if (utils::HasExternalDataInMemory(tensor)) { + if (ortvalue_initializers_.count(tensor.name()) > 0) { + ORT_THROW("OrtValue needs to be inserted. Use the overload that takes both TensorProto and OrtValue with data"); + } + } const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; name_to_initial_tensor_.emplace(tensor.name(), tensor_added); - SetGraphResolveNeeded(); + if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. 
// the shape will be set to the correct value in TypeCheckInputsAndInitializers as we don't yet know whether there @@ -3431,6 +3442,45 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { t.mutable_tensor_type()->set_elem_type(tensor.data_type()); ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); } + + SetGraphResolveNeeded(); +} + +Status Graph::AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const OrtValue& ortvalue_initializer) { + ORT_RETURN_IF(name_to_initial_tensor_.count(tensor_proto.name()) > 0, "Attempt to replace the existing tensor"); + + const gsl::not_null tensor_added{graph_proto_->add_initializer()}; + *(tensor_added) = tensor_proto; + name_to_initial_tensor_.emplace(tensor_proto.name(), tensor_added); + + if (ortvalue_initializer.IsAllocated()) { + ORT_RETURN_IF_NOT(utils::HasExternalDataInMemory(tensor_proto), + "TensorProto is expected to refer to the ortvalue_initializer"); + const auto element_type = static_cast(utils::GetTensorElementType(tensor_proto)); + const auto& tensor = ortvalue_initializer.Get(); + ORT_RETURN_IF_NOT(tensor.GetElementType() == element_type, + "Element type mismatch between tensor proto and ortvalue_initializer"); + const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + ORT_RETURN_IF_NOT(proto_shape == tensor.Shape(), "Shape mismatch with ortvalue_initializer"); + + ortvalue_initializers_.insert_or_assign(tensor_proto.name(), ortvalue_initializer); + } else { + ORT_ENFORCE(ortvalue_initializers_.count(tensor_proto.name()) == 0, + "Stray leftover ort_value for a small initializer being inserted."); + } + + SetGraphResolveNeeded(); + if (!is_loaded_from_model_file_ && GetNodeArg(tensor_proto.name()) == nullptr) { + // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. + // the shape will be set to the correct value in TypeCheckInputsAndInitializers as we don't yet know whether there + // will be a matching graph input for this initializer (we prefer shape info from the graph input). + TypeProto t; + t.mutable_tensor_type()->set_elem_type(tensor_proto.data_type()); + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor_proto.name(), &t)); + } + + return Status::OK(); } void Graph::FindAllSubgraphs(std::vector& subgraphs) { @@ -3538,7 +3588,8 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { } #if !defined(ORT_MINIMAL_BUILD) -Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external) { +Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, + OrtValue ort_value, bool must_replace_external) { // name_to_initial_tensor_ maps from name to const TensorProto*, so we first // look up the const pointer by name, then find and modify the mutable // pointed-to TensorProto in graph_proto_. @@ -3557,9 +3608,16 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi return true; }; - ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), + // This check ensures that we are replacing the right initializer than the users wants to + // replace data that is on disk with a reference to data in memory. 
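The contract enforced in AddInitializedOrtValue above can be summarized from the caller's side: a TensorProto that refers to in-memory data must be paired with the OrtValue owning that data, while a small tensor that was inlined is paired with an unallocated OrtValue. A hedged sketch (AddWeight is a hypothetical helper; the graph/utils calls are those from this patch):

#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"

onnxruntime::common::Status AddWeight(onnxruntime::Graph& graph,
                                      onnxruntime::Tensor&& weights,
                                      const std::string& name) {
  using namespace onnxruntime;
  constexpr bool use_tensor_buffer = true;
  ONNX_NAMESPACE::TensorProto proto = utils::TensorToTensorProto(weights, name, use_tensor_buffer);

  OrtValue ort_value;  // left unallocated when the data was inlined into `proto`
  if (utils::HasExternalDataInMemory(proto)) {
    // `proto` points at the tensor's buffer, so the graph must co-own it via the OrtValue.
    Tensor::InitOrtValue(std::move(weights), ort_value);
  }
  return graph.AddInitializedOrtValue(proto, ort_value);
}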
+ ORT_RETURN_IF_NOT(!must_replace_external || utils::HasExternalData(old_initializer), "Trying to replace non-external initializer with external data"); + // New initializers data generally are within OrtValues + // Small initializers are still stored inside TensorProto + ORT_RETURN_IF_NOT(utils::HasExternalDataInMemory(new_initializer) || !ort_value.IsAllocated(), + "All TensorProtos are expected to point to an OrtValue"); + ORT_RETURN_IF_NOT(dims_eq(), "Replacement tensor's dimensions do not match."); ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(), "Replacement tensor's data type does not match."); @@ -3573,22 +3631,50 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi ORT_ENFORCE(existing_entry != mutable_initializers.pointer_end(), "graph_proto_ is not in sync with name_to_initial_tensor_"); + if (ort_value.IsAllocated()) { + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.insert_or_assign(initializer_name, std::move(ort_value))); + } else { + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.erase(initializer_name)); + } + **existing_entry = std::move(new_initializer); return Status::OK(); } -Status Graph::ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer) { - return ReplaceInitializedTensorImpl(std::move(new_initializer), false); +common::Status Graph::ReplaceInitializedTensor(const ONNX_NAMESPACE::TensorProto& new_initializer, + const OrtValue& ort_value) { + return ReplaceInitializedTensorImpl(new_initializer, ort_value, false); } #if !defined(DISABLE_EXTERNAL_INITIALIZERS) Status Graph::InjectExternalInitializedTensors(const InlinedHashMap& external_initializers) { - for (const auto& e : external_initializers) { - const auto& name = e.first; - const OrtValue& ort_value = e.second; - auto tensor_proto = utils::TensorToTensorProto(ort_value.Get(), name); - ORT_RETURN_IF_ERROR(ReplaceInitializedTensorImpl(std::move(tensor_proto), true)); + for (const auto& [name, value] : external_initializers) { + const auto& user_tensor = value.Get(); + + OrtValue ort_value; + TensorProto tensor_proto; + constexpr const bool use_tensor_buffer_true = true; + if (user_tensor.SizeInBytes() > utils::kSmallTensorExternalDataThreshold) { + if (user_tensor.OwnsBuffer()) { + // If the user tensor has its own memory, we avoid copying + tensor_proto = utils::TensorToTensorProto(user_tensor, name, use_tensor_buffer_true); + ORT_ENFORCE(utils::HasExternalDataInMemory(tensor_proto), "Expecting this tensor_proto to have a pointer"); + ort_value = value; + } else { + Tensor initializer{user_tensor.DataType(), user_tensor.Shape(), CPUAllocator::DefaultInstance()}; + utils::MakeCpuTensorCopy(user_tensor, initializer); + + tensor_proto = utils::TensorToTensorProto(initializer, name, use_tensor_buffer_true); + ORT_ENFORCE(utils::HasExternalDataInMemory(tensor_proto), "Expecting this tensor_proto to have a pointer"); + Tensor::InitOrtValue(std::move(initializer), ort_value); + } + } else { + constexpr const bool use_tensor_buffer_false = false; + tensor_proto = utils::TensorToTensorProto(user_tensor, name, use_tensor_buffer_false); + } + + ORT_RETURN_IF_ERROR(ReplaceInitializedTensorImpl(std::move(tensor_proto), std::move(ort_value), true)); LOGS(logger_, INFO) << "Replaced external initializer: " << name; } return Status::OK(); @@ -3598,14 +3684,14 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const InlinedHashMap>& external_initializer_files) { for (const auto& [tensor_name, tensor_proto] : name_to_initial_tensor_) { if 
(tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { - std::unique_ptr external_data_info; - ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info)); + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info)); const auto& external_file = external_data_info->GetRelPath(); onnxruntime::FileOffsetType file_offset = external_data_info->GetOffset(); const size_t external_data_length = external_data_info->GetLength(); SafeInt tensor_byte_size; - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size)); + ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size)); ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, "TensorProto: ", tensor_name, " external data size mismatch. Computed size: ", *&tensor_byte_size, ", external_data.length: ", external_data_length); @@ -3641,7 +3727,17 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); auto tensor = Tensor(type, tensor_shape, tensor_buffer, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name); + + constexpr const bool use_tensor_buffer_true = true; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_true); + // Implied that external data is in memory + const bool has_external_data_in_memory = utils::HasExternalData(new_tensor_proto); + + OrtValue ort_value; + if (has_external_data_in_memory) { + Tensor::InitOrtValue(std::move(tensor), ort_value); + } + ortvalue_initializers_.insert_or_assign(tensor_name, std::move(ort_value)); **existing_entry = std::move(new_tensor_proto); } } @@ -3662,14 +3758,24 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } -bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { - auto it = ortvalue_initializers_.find(name); - if (it == ortvalue_initializers_.end()) { - return false; +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope) const { + // We want to make sure that the ort_value is found on the same level as its tensor_proto + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(name, initializer)) { + auto it = ortvalue_initializers_.find(name); + if (it != ortvalue_initializers_.end()) { + value = it->second; + return true; + } } - value = it->second; - return true; + if (check_outer_scope && IsSubgraph()) { + if (IsOuterScopeValue(name)) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. 
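For completeness, a small hedged sketch of the extended lookup just above: with check_outer_scope enabled, a subgraph that does not hold the initializer locally may resolve the OrtValue from its parent graph, and the call returns false when only a plain TensorProto exists. TryGetInitializerValue is a hypothetical wrapper.

#include "core/graph/graph.h"

bool TryGetInitializerValue(const onnxruntime::Graph& graph,
                            const std::string& name,
                            OrtValue& value) {
  constexpr bool check_outer_scope = true;
  return graph.GetOrtValueInitializer(name, value, check_outer_scope);
}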
+ return parent_graph_->GetOrtValueInitializer(name, value, check_outer_scope); + } + } + return false; } void Graph::CleanAllInitializedTensors() noexcept { @@ -3821,12 +3927,6 @@ SaveInputsOutputsToOrtFormat(flatbuffers::FlatBufferBuilder& builder, const std: common::Status Graph::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& fbs_graph) const { - if constexpr (endian::native != endian::little) { - auto& tens = GetAllInitializedTensors(); - for (auto& [name, tensor_p] : tens) { - utils::ConvertRawDataInTensorProto(const_cast(tensor_p)); - } - } auto inputs = SaveInputsOutputsToOrtFormat(builder, graph_inputs_including_initializers_); auto outputs = SaveInputsOutputsToOrtFormat(builder, graph_outputs_); @@ -4111,7 +4211,7 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { output = initializer; // copy any in-memory external data into raw data - if (utils::HasExternalData(initializer)) { + if (utils::HasExternalDataInMemory(initializer)) { const std::filesystem::path ignored; std::basic_string location; onnxruntime::FileOffsetType file_offset; @@ -4119,14 +4219,12 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); - if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { - // file_offset is address - void* data = reinterpret_cast(file_offset); + // file_offset is address + void* data = reinterpret_cast(file_offset); - // set in raw data - output.clear_data_location(); - output.set_raw_data(data, tensor_byte_size); - } + // set in raw data + output.clear_data_location(); + output.set_raw_data(data, tensor_byte_size); } }; diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 199aa79cc1dde..6c27bacacf9c2 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -50,7 +50,16 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, string_data = builder.CreateVectorOfStrings(string_data_vec); } else { std::vector unpacked_tensor; - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, model_path, unpacked_tensor)); + // We can not convert this in place, because the session may be used + // after the model was saved in ort format. If the session is continued to be used, then + // we continue with initializers in memory with wrong endianess + if constexpr (endian::native != endian::little) { + auto be_copy{initializer}; + onnxruntime::utils::ConvertRawDataInTensorProto(be_copy); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(be_copy, model_path, unpacked_tensor)); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, model_path, unpacked_tensor)); + } if (external_writer && unpacked_tensor.size() >= kMinimumSizeForExternalData) { // write bytes to external buffer/file and record offset for the start of the data diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index eb0fb22346f37..80bb3f13814d1 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
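The GraphViewerToProto changes below use GetTensorProtoWithDataIfInMemory so that the serialized model never contains the non-standard in-memory external-data tag. A hedged sketch of that pattern, with AppendInitializer as a hypothetical helper:

#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include <memory>

onnxruntime::common::Status AppendInitializer(const ONNX_NAMESPACE::TensorProto& init,
                                              ONNX_NAMESPACE::GraphProto& graph_proto) {
  using namespace onnxruntime;
  std::unique_ptr<ONNX_NAMESPACE::TensorProto> with_data;
  ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(init, with_data));
  if (with_data) {
    *graph_proto.add_initializer() = std::move(*with_data);  // data was pulled inline
  } else {
    *graph_proto.add_initializer() = init;  // already self-contained, copy as-is
  }
  return common::Status::OK();
}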
+#include "core/common/inlined_containers.h" +#include "core/framework/tensorprotoutils.h" #include "core/graph/graph_proto_serializer.h" namespace onnxruntime { @@ -21,13 +23,13 @@ void GraphViewerToProto(const GraphViewer& graph_view, *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); } - std::unordered_set value_info_ = graph_view.GetValueInfo(); + const auto& value_infos = graph_view.GetValueInfo(); // Reserve memory for the vector to avoid reallocations - std::vector value_info_sorted; - value_info_sorted.reserve(value_info_.size()); + InlinedVector value_info_sorted; + value_info_sorted.reserve(value_infos.size()); + value_info_sorted.assign(value_infos.begin(), value_infos.end()); - value_info_sorted.assign(value_info_.begin(), value_info_.end()); auto sort_predicate = [](const NodeArg* v1, const NodeArg* v2) { return v1->Name() < v2->Name(); }; @@ -58,21 +60,39 @@ void GraphViewerToProto(const GraphViewer& graph_view, } if (include_initializer) { - std::unordered_set current_scope_initializer_set; - - auto& initializers = graph_view.GetAllInitializedTensors(); + const auto& initializers = graph_view.GetAllInitializedTensors(); // Sort initializers to maintain consistency in model proto created across inference requests - std::vector const_inits; - for (auto& it : initializers) { - const_inits.push_back(it.first); + InlinedVector const_inits; + const_inits.reserve(initializers.size()); + for (auto it = initializers.cbegin(), end = initializers.cend(); it != end; ++it) { + const_inits.push_back(it); } - std::sort(const_inits.begin(), const_inits.end()); + std::sort(const_inits.begin(), const_inits.end(), [](const auto& i1, const auto& i2) { + return i1->first < i2->first; + }); + + InlinedHashSet current_scope_initializer_set; + current_scope_initializer_set.reserve(const_inits.size()); + + auto get_initializer_with_data = [&](const ONNX_NAMESPACE::TensorProto& init, + ONNX_NAMESPACE::TensorProto& dest) -> Status { + std::unique_ptr full_init; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(init, full_init)); + if (full_init) { + dest = std::move(*full_init); + } else { + dest = init; + } + return Status::OK(); + }; - for (auto& it : const_inits) { + // Handle this scope initializers + for (const auto& it : const_inits) { + const auto& [name, init] = *it; + current_scope_initializer_set.insert(name); auto* p_initializer = graph_proto.add_initializer(); - *p_initializer = *(initializers.at(it)); - current_scope_initializer_set.insert(it); + ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); } // handle outer scope value which is a constant initializer @@ -80,13 +100,15 @@ void GraphViewerToProto(const GraphViewer& graph_view, for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const auto& node = graph_view.GetNode(node_idx); for (const auto& input : node->InputDefs()) { - if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { + if (current_scope_initializer_set.count(std::string_view{input->Name()}) > 0) { continue; } - if (graph_view.IsConstantInitializer(input->Name(), true)) { - auto* p_initializer = graph_proto.add_initializer(); - *p_initializer = *(graph_view.GetConstantInitializer(input->Name(), true)); + + const auto* outer_scope_init = graph_view.GetConstantInitializer(input->Name(), true); + if (outer_scope_init != nullptr) { current_scope_initializer_set.insert(input->Name()); + auto* p_initializer = graph_proto.add_initializer(); + 
ORT_THROW_IF_ERROR(get_initializer_with_data(*outer_scope_init, *p_initializer)); } } } diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index cc48df4444951..dcf627fc605f4 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" + +#include "core/framework/tensorprotoutils.h" #include "core/graph/graph.h" #include "core/common/logging/logging.h" @@ -249,6 +251,19 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s return iter == attrs.end() ? nullptr : &iter->second; } +static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + ONNX_NAMESPACE::TypeProto new_type; + auto* typeproto_tensor = new_type.mutable_tensor_type(); + typeproto_tensor->set_elem_type(new_initializer.data_type()); + + auto* shape = typeproto_tensor->mutable_shape(); + for (auto dim : new_initializer.dims()) { + shape->add_dim()->set_dim_value(dim); + } + + return graph.GetOrCreateNodeArg(new_initializer.name(), &new_type); +} + NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { // sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer const ONNX_NAMESPACE::TensorProto* existing = nullptr; @@ -256,17 +271,91 @@ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_ini "Initializer with same name exists. Name:", new_initializer.name()); graph.AddInitializedTensor(new_initializer); + return GetOrCreateNodeArg(graph, new_initializer); +} - ONNX_NAMESPACE::TypeProto new_type; - auto* typeproto_tensor = new_type.mutable_tensor_type(); - typeproto_tensor->set_elem_type(new_initializer.data_type()); +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + ORT_ENFORCE(!utils::HasExternalData(new_initializer), "Expecting an initializer that contains data inline"); - auto* shape = typeproto_tensor->mutable_shape(); - for (auto dim : new_initializer.dims()) { - shape->add_dim()->set_dim_value(dim); + Tensor tensor; + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), + new_initializer, tensor)); + auto tensor_proto_with_ptr = utils::TensorToTensorProto(tensor, new_initializer.name(), true); + return AddInitializerWithExternalData(graph, tensor_proto_with_ptr, std::move(tensor)); +} + +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + Tensor&& tensor) { + OrtValue ort_value; + if (utils::HasExternalDataInMemory(new_initializer)) { + Tensor::InitOrtValue(std::move(tensor), ort_value); } - return graph.GetOrCreateNodeArg(new_initializer.name(), &new_type); + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(new_initializer, ort_value)); + return GetOrCreateNodeArg(graph, new_initializer); +} + +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + OrtValue ort_value) { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(new_initializer, ort_value)); + return GetOrCreateNodeArg(graph, new_initializer); +} + +void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool copy_in_memory_data) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (src_graph.GetInitializedTensor(name, initializer)) { + // check if the initializer already exists 
in the destination graph + const ONNX_NAMESPACE::TensorProto* existing = nullptr; + if (!dst_graph.GetInitializedTensor(name, existing)) { + const bool data_in_memory = utils::HasExternalDataInMemory(*initializer); + if (data_in_memory) { + if (copy_in_memory_data) { + ONNX_NAMESPACE::TensorProto tensor_proto; + ORT_THROW_IF_ERROR(utils::TensorProtoWithExternalDataToTensorProto(*initializer, {}, tensor_proto)); + dst_graph.AddInitializedTensor(tensor_proto); + GetOrCreateNodeArg(dst_graph, tensor_proto); + } else { + OrtValue ort_value; + if (src_graph.GetOrtValueInitializer(name, ort_value)) { + // add the initializer to the destination graph + ORT_THROW_IF_ERROR(dst_graph.AddInitializedOrtValue(*initializer, ort_value)); + } else { + // Data may be in memory, but stored in flatbuffers etc. + dst_graph.AddInitializedTensor(*initializer); + } + GetOrCreateNodeArg(dst_graph, *initializer); + } + } else { + dst_graph.AddInitializedTensor(*initializer); + GetOrCreateNodeArg(dst_graph, *initializer); + } + } + } +} + +void MakeConstantInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool check_outer_scope) { + const auto* initializer = src_graph.GetConstantInitializer(name, check_outer_scope); + if (initializer != nullptr) { + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!dst_graph.GetInitializedTensor(name, subgraph_initializer)) { + OrtValue ort_value; + ORT_IGNORE_RETURN_VALUE(src_graph.GetOrtValueInitializer(name, ort_value, check_outer_scope)); + ORT_THROW_IF_ERROR(dst_graph.AddInitializedOrtValue(*initializer, ort_value)); + } + } +} + +Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(name, initializer) && utils::HasExternalDataInMemory(*initializer)) { + ONNX_NAMESPACE::TensorProto tensor_proto; + ORT_THROW_IF_ERROR(utils::TensorProtoWithExternalDataToTensorProto(*initializer, {}, tensor_proto)); + graph.RemoveInitializedTensor(name); + graph.AddInitializedTensor(tensor_proto); + } + return Status::OK(); } int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name) { diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 8710519cdc865..033488d734bd5 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -38,6 +38,71 @@ Checks that new_initializer does not already exist in 'graph' before adding it. */ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer); +/// +/// Adds a new initializer to 'graph' with new_initializer that points to the OrtValue buffer +/// +/// target graph +/// TensorProto with external data contained in ort_value +/// ort_value with data +/// +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + OrtValue ort_value); + +/** Add a new initializer to 'graph'. + * Checks that new_initializer does not already exist in 'graph' before adding it. + * @param new_initializer tensor proto that has external data pointing to data within the tensor. + * @param tensor with data + * @returns The NodeArg for the new initializer. + * @remarks No matching graph input is created, so the initializer will be constant. + */ +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, Tensor&& tensor); + +/** Add a new initializer to 'graph'. 
+ * The function unpacks data into a tensor and converts new_initializer to a TensorProto with external data in memory. + * The initializer is then added to the graph and the tensor is wrapped into an OrtValue and added to + * Graph::ortvalue_initializers_. + * + * @param graph The graph to which the initializer will be added. + * @param new_initializer tensor proto that actually has data in it + * @returns The NodeArg for the new initializer. + * @remarks No matching graph input is created, so the initializer will be constant. + */ +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer); + +/// +/// If the initializer with the given name does not exist in the destination graph, but exists in the +/// source graph, copy it to the destination graph. +/// +/// source graph +/// destination graph +/// initializer name +/// if external data is in memory, copy data inline. +/// default is false. This is to accommodate EPs that load initializers on their own and do not understand +/// our /*/_ORT_MEM_ADDR_/*/ external data reference +void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool copy_in_memory_data = false); + +/// +/// If the constant initializer with the given name does not exist in the destination graph, but exists in the +/// source graph, copy it to the destination graph along with its OrtValue if present. +/// +/// +/// +/// +/// checks outer scope if true +void MakeConstantInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool check_outer_scope); + +/// +/// If the initializer is present in the graph and has external data in memory, +/// convert it to inline data. This is necessary for EPs that cannot handle +/// external initializers that are in memory since our in-memory external data is not ONNX standard. +/// +/// Graph +/// initializer name +/// Status +Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name); + +/** Gets the index of an output arg with the specified output arg name. 
*/ int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name); diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index f92451cf7fe6d..616bc1257676f 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -69,9 +69,9 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, assert(nullptr != q_tensor); assert(nullptr != k_tensor); assert(nullptr != v_tensor); - Initializer q_initializer(*q_tensor, graph.ModelPath()); - Initializer k_initializer(*k_tensor, graph.ModelPath()); - Initializer v_initializer(*v_tensor, graph.ModelPath()); + Initializer q_initializer(graph, *q_tensor, graph.ModelPath()); + Initializer k_initializer(graph, *k_tensor, graph.ModelPath()); + Initializer v_initializer(graph, *v_tensor, graph.ModelPath()); auto data_type = q_tensor->data_type(); ONNX_NAMESPACE::TensorProto initializer; @@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow(element_count) * sizeof(MLFloat16)); } - return graph_utils::AddInitializer(graph, initializer); + return graph_utils::AddInitializerWithExternalData(graph, initializer); } static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index aa70b347d7b67..ecbb750d0bf19 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -310,7 +310,16 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir // Check that the mask shape is 1x1xWxW auto shape = mask.Shape(); - if (shape == nullptr || static_cast(shape->dim_size()) != 4 || !utils::HasDimValue(shape->dim(0)) || static_cast(1) != shape->dim(0).dim_value() || !utils::HasDimValue(shape->dim(1)) || static_cast(1) != shape->dim(1).dim_value() || !utils::HasDimValue(shape->dim(2)) || !utils::HasDimValue(shape->dim(3)) || shape->dim(2).dim_value() != shape->dim(3).dim_value()) { + if ( + shape == nullptr || + static_cast(shape->dim_size()) != 4 || + !utils::HasDimValue(shape->dim(0)) || + static_cast(1) != shape->dim(0).dim_value() || + !utils::HasDimValue(shape->dim(1)) || + static_cast(1) != shape->dim(1).dim_value() || + !utils::HasDimValue(shape->dim(2)) || + !utils::HasDimValue(shape->dim(3)) || + shape->dim(2).dim_value() != shape->dim(3).dim_value()) { DEBUG_LOG("unidir mask shape not expected"); return false; } @@ -320,28 +329,20 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir return false; } - if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now"); - return false; - } - if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { size_t bytes; if (!utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &bytes).IsOK()) { return false; } - auto data = std::make_unique(bytes); - uint8_t* p = data.get(); - if (!utils::UnpackTensor( - *tensor_proto, - tensor_proto->raw_data().size() ? 
tensor_proto->raw_data().data() : nullptr, - tensor_proto->raw_data().size(), - p, - bytes) - .IsOK()) { + + std::vector mask_data; + // This takes care of external data in case present + auto status = utils::UnpackInitializerData(*tensor_proto, graph.ModelPath(), mask_data); + if (!status.IsOK()) { + DEBUG_LOG(status.ErrorMessage()); return false; } - std::vector mask_data(p, p + bytes); + if (!ValidateUnidirMask(mask_data, shape->dim(2).dim_value(), is_unidirectional)) { DEBUG_LOG("Mask is neither unidirectional nor all ones"); return false; diff --git a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc index 86a7a4d6afbf8..a98d0ea6f978b 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc @@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph, "total_count: ", total_count, " values.size(): ", values.size()); utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t)); - return &graph_utils::AddInitializer(graph, const_tensor); + return &graph_utils::AddInitializerWithExternalData(graph, const_tensor); } NodeArg* InsertNodesForValidIndices(Graph& graph, diff --git a/onnxruntime/core/optimizer/concat_slice_elimination.cc b/onnxruntime/core/optimizer/concat_slice_elimination.cc index f7a2b3be4466c..b49bcc186e93d 100644 --- a/onnxruntime/core/optimizer/concat_slice_elimination.cc +++ b/onnxruntime/core/optimizer/concat_slice_elimination.cc @@ -86,7 +86,7 @@ static bool GetSliceInfo(const Graph& graph, auto get_initializer_data = [&graph](const ONNX_NAMESPACE::TensorProto* initializer) -> InlinedVector { - Initializer init(*initializer, graph.ModelPath()); + Initializer init(graph, *initializer, graph.ModelPath()); if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { int32_t* init_data = init.data(); return InlinedVector(init_data, init_data + init.size()); diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e36eef672c1ed..3d838d8aacfbb 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { ONNX_NAMESPACE::TensorShapeProto result_shape; result_shape.add_dim()->set_dim_value(clamped_slice_length); constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(shape_constant); + graph_utils::AddInitializerWithExternalData(graph, shape_constant); } return is_concrete_shape; // convert to constant if this is true @@ -118,7 +118,7 @@ static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Log } // This is a boolean initializer with a single element. - Initializer condition{*initializer}; + Initializer condition{graph, *initializer}; ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), "' is expected to have a single boolean element"); @@ -317,7 +317,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. 
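The mask validation above now goes through utils::UnpackInitializerData, which handles inline raw data, typed fields, and external data (including in-memory references) uniformly. A hedged sketch; ReadInitializerBytes is a hypothetical wrapper and the byte element type of the output vector is assumed:

#include "core/framework/tensorprotoutils.h"
#include <cstdint>
#include <vector>

onnxruntime::common::Status ReadInitializerBytes(const onnxruntime::Graph& graph,
                                                 const ONNX_NAMESPACE::TensorProto& initializer,
                                                 std::vector<uint8_t>& bytes) {
  // graph.ModelPath() lets file-based external data be located relative to the model.
  return onnxruntime::utils::UnpackInitializerData(initializer, graph.ModelPath(), bytes);
}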
auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; const Tensor& out_tensor = ort_value.Get(); - ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); + constexpr const bool use_tensor_buffer_true = true; + ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto( + out_tensor, + constant_arg_out->Name(), + use_tensor_buffer_true); ONNX_NAMESPACE::TensorShapeProto result_shape; for (auto& dim : out_tensor.Shape().GetDims()) { @@ -325,7 +329,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(out_tensorproto); + // The data is too small and has been inlined. + if (!utils::HasExternalData(out_tensorproto)) { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue())); + } else { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value)); + } } } } diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index adc1efae5ced4..c349adfccce53 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -62,8 +62,8 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie return Status::OK(); } - Initializer conv_B{*conv_B_tensor_proto, graph.ModelPath()}; - Initializer add_B{*add_B_tensor_proto, graph.ModelPath()}; + Initializer conv_B{graph, *conv_B_tensor_proto, graph.ModelPath()}; + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; if (conv_B.size() != add_B.size()) { return Status::OK(); @@ -79,12 +79,14 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); } else { // Create new tensor proto and update shape - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto); + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; + add_B.ToProto(new_conv_B_tensor_proto); int64_t dim = conv_W_tensor_proto->dims(0); new_conv_B_tensor_proto.clear_dims(); new_conv_B_tensor_proto.add_dims(dim); @@ -92,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::AddNodeInput(node, 2, new_add_B_node_arg); } diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index 392d03de037cf..8bf5420baddde 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -61,13 +61,13 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff return Status::OK(); } - Initializer bn_scale{*bn_scale_tensor_proto, graph.ModelPath()}; - Initializer 
bn_B{*bn_B_tensor_proto, graph.ModelPath()}; - Initializer bn_mean{*bn_mean_tensor_proto, graph.ModelPath()}; - Initializer bn_var{*bn_var_tensor_proto, graph.ModelPath()}; - Initializer conv_W{*conv_W_tensor_proto, graph.ModelPath()}; + Initializer bn_scale{graph, *bn_scale_tensor_proto, graph.ModelPath()}; + Initializer bn_B{graph, *bn_B_tensor_proto, graph.ModelPath()}; + Initializer bn_mean{graph, *bn_mean_tensor_proto, graph.ModelPath()}; + Initializer bn_var{graph, *bn_var_tensor_proto, graph.ModelPath()}; + Initializer conv_W{graph, *conv_W_tensor_proto, graph.ModelPath()}; - std::unique_ptr conv_B = nullptr; + std::optional conv_B; const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; if (conv_inputs.size() == 3) { conv_B_tensor_proto = graph_utils::GetConstantInitializer(graph, conv_inputs[2]->Name()); @@ -79,7 +79,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) { return Status::OK(); } - conv_B = std::make_unique(*conv_B_tensor_proto, graph.ModelPath()); + conv_B.emplace(graph, *conv_B_tensor_proto, graph.ModelPath()); } // Calculate new value of initializers of conv node @@ -98,7 +98,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff } // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto; conv_W.ToProto(new_conv_W_tensor_proto); ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; @@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff new_conv_W_tensor_proto.set_name(new_W_name); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg); - auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + auto& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); if (conv_inputs.size() == 3) { graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 6da6d089d5a71..dc50a150537f7 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -52,11 +52,11 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef } } - Initializer conv_W{*conv_W_tensor_proto, graph.ModelPath()}; - Initializer mul_B{*mul_B_tensor_proto, graph.ModelPath()}; + Initializer conv_W{graph, *conv_W_tensor_proto, graph.ModelPath()}; + Initializer mul_B{graph, *mul_B_tensor_proto, graph.ModelPath()}; const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; - std::unique_ptr conv_B = nullptr; + std::optional conv_B; const bool is_3d = conv_inputs.size() == 3; if (is_3d) { conv_B_tensor_proto = graph_utils::GetConstantInitializer(graph, conv_inputs[2]->Name()); @@ -68,7 +68,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef return Status::OK(); } - conv_B = std::make_unique(*conv_B_tensor_proto, graph.ModelPath()); + conv_B.emplace(graph, *conv_B_tensor_proto, graph.ModelPath()); } // Calculate new value of initializers of conv node @@ -83,24 +83,24 @@ 
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef } // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto; conv_W.ToProto(new_conv_W_tensor_proto); auto new_W_name = graph.GenerateNodeArgName("ConvMulFusion_W_" + conv_W_tensor_proto->name()); new_conv_W_tensor_proto.set_name(new_W_name); // Replace initializers of conv node - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 1, new_conv_W_node_arg); if (is_3d) { - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; conv_B->ToProto(new_conv_B_tensor_proto); auto new_B_name = graph.GenerateNodeArgName("ConvMulFusion_Mul_B_" + mul_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 2, new_conv_B_node_arg); } diff --git a/onnxruntime/core/optimizer/div_mul_fusion.cc b/onnxruntime/core/optimizer/div_mul_fusion.cc index 7184e931cb74e..e2cd66fe73f86 100644 --- a/onnxruntime/core/optimizer/div_mul_fusion.cc +++ b/onnxruntime/core/optimizer/div_mul_fusion.cc @@ -40,7 +40,7 @@ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const } int32_t data_type = initializer->data_type(); - Initializer div_A(*initializer, graph.ModelPath()); + Initializer div_A(graph, *initializer, graph.ModelPath()); if (div_A.size() > 1) { return false; } diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 22b9dca39dceb..1841dfa2791e0 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -46,13 +46,13 @@ static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, template static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) { const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); - Initializer input_init{*input_tensor, graph.ModelPath()}; - ONNX_NAMESPACE::TensorProto new_input_tensor(*input_tensor); + Initializer input_init{graph, *input_tensor, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_input_tensor; input_init.data()[0] = value; input_init.ToProto(new_input_tensor); auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); new_input_tensor.set_name(new_name); - NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); + NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor); graph_utils::ReplaceNodeInput(node, index, new_input); } @@ -79,10 +79,10 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons graph_utils::GetConstantInitializer(graph, node1_zp_name); const ONNX_NAMESPACE::TensorProto* node2_zp_tensor_proto = graph_utils::GetConstantInitializer(graph, node2_zp_name); - Initializer zero_point_init_1{*node1_zp_tensor_proto, graph.ModelPath()}; - Initializer 
zero_point_init_2{*node2_zp_tensor_proto, graph.ModelPath()}; - Initializer scale_init_1{*node1_scale_tensor_proto, graph.ModelPath()}; - Initializer scale_init_2{*node2_scale_tensor_proto, graph.ModelPath()}; + Initializer zero_point_init_1{graph, *node1_zp_tensor_proto, graph.ModelPath()}; + Initializer zero_point_init_2{graph, *node2_zp_tensor_proto, graph.ModelPath()}; + Initializer scale_init_1{graph, *node1_scale_tensor_proto, graph.ModelPath()}; + Initializer scale_init_2{graph, *node2_scale_tensor_proto, graph.ModelPath()}; if (zero_point_init_1.data_type() != zero_point_init_2.data_type() || scale_init_1.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || scale_init_2.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -181,7 +181,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) { } // The Q1 and DQ1 nodes must have equal zero-point and scale values (scalar/constant). - if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath())) { + if (!QDQ::IsQDQPairSupported(graph, *q1, *dq1, get_constant_initializer, graph.ModelPath())) { return false; } @@ -218,7 +218,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) { } // The Q2 and DQ2 nodes must have equal zero-point and scale values (scalar/constant). - if (!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + if (!QDQ::IsQDQPairSupported(graph, *q2, *dq2, get_constant_initializer, graph.ModelPath())) { return false; } diff --git a/onnxruntime/core/optimizer/dropout_elimination.cc b/onnxruntime/core/optimizer/dropout_elimination.cc index b82a944125667..d989c4dd80532 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.cc +++ b/onnxruntime/core/optimizer/dropout_elimination.cc @@ -41,7 +41,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, co return false; } int32_t data_type = initializer->data_type(); - Initializer ratio(*initializer, graph.ModelPath()); + Initializer ratio(graph, *initializer, graph.ModelPath()); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: if (*ratio.data() > 0.f) { diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 103e72072f713..ad25f95ac1186 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -450,7 +450,7 @@ static NodeArg* ExtractEmbedding(Graph& graph, assert(sequence_length > 0); assert(hidden_size > 0); - Initializer old_initializer{*tensor, graph.ModelPath()}; + Initializer old_initializer{graph, *tensor, graph.ModelPath()}; auto data_type = tensor->data_type(); ONNX_NAMESPACE::TensorProto initializer; @@ -474,7 +474,7 @@ static NodeArg* ExtractEmbedding(Graph& graph, utils::SetRawDataInTensorProto(initializer, data, gsl::narrow(element_count) * sizeof(MLFloat16)); } - NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer); + NodeArg& node_arg = graph_utils::AddInitializerWithExternalData(graph, initializer); modified = true; return &node_arg; } diff --git a/onnxruntime/core/optimizer/expand_elimination.cc b/onnxruntime/core/optimizer/expand_elimination.cc index 8aadeb5a1a273..86bf616ea05e2 100644 --- a/onnxruntime/core/optimizer/expand_elimination.cc +++ b/onnxruntime/core/optimizer/expand_elimination.cc @@ -36,12 +36,12 @@ bool ExpandElimination::SatisfyCondition(const Graph& graph, const Node& node, c return false; } - auto initializer = 
std::make_unique(*tensor_proto, graph.ModelPath()); - if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); + if (initializer.data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { return false; } - const int64_t* target_shapes = initializer->data(); + const int64_t* target_shapes = initializer.data(); // Check the dimensions starting at the trailing dimension. int i = input_shape->dim_size() - 1; diff --git a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc index 1516b07fc2049..388ab14dd51fe 100644 --- a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc +++ b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc @@ -116,27 +116,34 @@ static void FuseInitializerWithNode(Graph& graph, } // Get the src initialized tensor at input def index 0 - auto constant_initializer_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name()); - ONNX_NAMESPACE::TensorProto src_tensor(*constant_initializer_tensor); + const auto* constant_initializer_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name()); Initializer src_init{*constant_initializer_tensor, graph.ModelPath()}; - src_init.ToProto(src_tensor); // Convert to dst tensor - ONNX_NAMESPACE::TensorProto dst_tensor; + std::string new_arg_name = graph.GenerateNodeArgName(NewNodeArgName( + next_node.InputDefs()[next_node_arg_index]->Name())); + + OrtValue new_data; if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToFloat32(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name())), thread_pool); + new_data = src_init.ToFloat32(thread_pool); else if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToFP16(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name()))); + new_data = src_init.ToFP16(); else if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToBFloat16(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name()))); + new_data = src_init.ToBFloat16(); else return; // Remove the edge between the current node output def at index 0 and next node arg at relative arg index. 
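// ----------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch). The FuseInitializerWithNode hunk above now
// converts a constant initializer through the OrtValue-returning helpers (ToFloat32 / ToFP16 /
// ToBFloat16) and re-registers it so that large payloads stay in the OrtValue instead of being
// copied into the TensorProto. The same flow, condensed into one hypothetical helper
// (ConvertInputToFp16); only APIs that appear elsewhere in this diff are assumed.
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
static void ConvertInputToFp16(onnxruntime::Graph& graph, onnxruntime::Node& node, int input_index,
                               const ONNX_NAMESPACE::TensorProto& src_proto) {
  onnxruntime::Initializer src{graph, src_proto, graph.ModelPath()};
  OrtValue converted = src.ToFP16();  // returns the existing OrtValue if already fp16

  const std::string new_name = graph.GenerateNodeArgName(src.name() + "_fp16");
  ONNX_NAMESPACE::TensorProto proto = onnxruntime::utils::TensorToTensorProto(
      converted.Get<onnxruntime::Tensor>(), new_name, /*use_tensor_buffer*/ true);
  if (!onnxruntime::utils::HasExternalData(proto)) {
    converted = OrtValue();  // small tensor: the data was inlined into the proto
  }
  auto& new_arg = onnxruntime::graph_utils::AddInitializerWithExternalData(graph, proto,
                                                                           std::move(converted));
  onnxruntime::graph_utils::ReplaceNodeInput(node, input_index, new_arg);
}
// ----------------------------------------------------------------------------------------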
graph.RemoveEdge(node.Index(), next_node.Index(), 0, static_cast(next_node_arg_index)); - // Add the new converted Tensor in next node as initializer - graph_utils::ReplaceNodeInput(next_node, static_cast(next_node_arg_index), graph_utils::AddInitializer(graph, dst_tensor)); + // Add the new converted Tensor in next node as initializer potentially with external data + ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get(), new_arg_name, true); + if (!utils::HasExternalData(dst_tensor)) { + new_data = OrtValue(); // Data is inline + } + + auto& new_arg = graph_utils::AddInitializerWithExternalData(graph, dst_tensor, std::move(new_data)); + graph_utils::ReplaceNodeInput(next_node, static_cast(next_node_arg_index), new_arg); } Status FuseInitializersTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 9732ec2587b2a..3cd06350df95d 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -27,7 +27,7 @@ static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_ar if (!optimizer_utils::IsScalar(node_arg)) return false; const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; value = *(init_const.data()); rank = tensor_proto->dims_size(); return true; @@ -256,7 +256,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra axes_initializer_proto.add_dims(static_cast(1)); axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); axes_initializer_proto.add_int64_data(axis); - NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); + NodeArg* axes_arg = &graph_utils::AddInitializerWithExternalData(graph, axes_initializer_proto); Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", {split_output_arg, axes_arg}, {original_output_arg}); @@ -272,7 +272,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); split_initializer_proto.add_dims(static_cast(split_values.size())); split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); - NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + NodeArg* split_initializer_arg = &graph_utils::AddInitializerWithExternalData(graph, split_initializer_proto); const auto split_node_name = graph.GenerateNodeName(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion"); Node& split_node = graph.AddNode(split_node_name, "Split", "Split for Fused Gather nodes", {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); @@ -359,7 +359,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le unsqueeze_axes_initializer_proto.add_dims(static_cast(1)); unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); unsqueeze_axes_initializer_proto.add_int64_data(static_cast(0)); - NodeArg* unsqueeze_axes_arg = 
&graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto); + NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, unsqueeze_axes_initializer_proto); for (size_t i = 0; i < range_input_defs.size(); ++i) { Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", @@ -386,7 +386,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le } else { slice_axes_initializer_proto.add_int32_data(static_cast(axis)); } - NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto); + NodeArg* slice_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, slice_axes_initializer_proto); Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes", {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1], slice_axes_arg, unsqueeze_outputs[2]}, diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 24f4ad867d101..062cbce6387e6 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -204,8 +204,7 @@ InlinedVector> GenerateTransformers( const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -217,7 +216,7 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); switch (level) { case TransformerLevel::Level1: { @@ -348,8 +347,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors)); + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -471,8 +469,7 @@ InlinedVector> GenerateTransformersForMinimalB const IExecutionProvider& cpu_execution_provider, const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -497,8 +494,7 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors)); + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); @@ -521,7 +517,7 @@ InlinedVector> GenerateTransformersForMinimalB // currently the only level 3 optimizer is the NhwcTransformer which is fully supported at runtime if (!saving) { #ifndef DISABLE_CONTRIB_OPS - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = 
CPUAllocator::DefaultInstance(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), logger); diff --git a/onnxruntime/core/optimizer/identical_children_consolidation.cc b/onnxruntime/core/optimizer/identical_children_consolidation.cc index bbc8073268f08..4f31db922d078 100644 --- a/onnxruntime/core/optimizer/identical_children_consolidation.cc +++ b/onnxruntime/core/optimizer/identical_children_consolidation.cc @@ -69,7 +69,7 @@ std::string IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph, if (optimizer_utils::IsScalar(*input_def)) { const auto* data = graph_utils::GetConstantInitializer(graph, name); identity << constant_prefix; - Initializer value{*data, graph.ModelPath()}; + Initializer value{graph, *data, graph.ModelPath()}; switch (static_cast(data->data_type())) { case TensorProto::DataType::TensorProto_DataType_INT8: identity << *value.data(); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 81eb50286728f..6fbb4177ce90a 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -15,40 +15,82 @@ namespace onnxruntime { +static inline Tensor* GetTensor(OrtValue& ort_value) { + return ort_value.GetMutable(); +} + Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, std::string_view name, - gsl::span dims) - : name_(name), - data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, - std::make_shared()) { - if (!data_.IsDataTypeString()) { - memset(data_.MutableDataRaw(), 0, data_.SizeInBytes()); + gsl::span dims) : name_(name) { + auto tensor = Tensor(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, + CPUAllocator::DefaultInstance()); + + if (!tensor.IsDataTypeString()) { + memset(tensor.MutableDataRaw(), 0, tensor.SizeInBytes()); } + + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); } Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path) { - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); + ORT_ENFORCE(utils::HasName(tensor_proto), "Initializer must have a name"); + name_ = tensor_proto.name(); + #if !defined(__wasm__) // using full filepath is required by utils::TensorProtoToTensor(). One exception is WebAssembly platform, where // external data is not loaded from real file system. - if (utils::HasExternalData(tensor_proto)) { + if (utils::HasExternalData(tensor_proto) && !utils::HasExternalDataInMemory(tensor_proto)) { ORT_ENFORCE(!model_path.empty(), "model_path must not be empty. Ensure that a path is provided when the model is created or loaded."); } #endif - auto proto_data_type = tensor_proto.data_type(); - if (utils::HasName(tensor_proto)) { - name_ = tensor_proto.name(); + Tensor tensor; + // This creates copy of the data so clients can mutate + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), model_path, tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); +} + +Initializer::Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, bool check_outer_scope) { + ORT_ENFORCE(utils::HasName(tensor_proto), "Initializer must have a name"); + name_ = tensor_proto.name(); + + // Check if the data is in memory. 
This does not mean, though, that the data is in the ort_value + if (utils::HasExternalDataInMemory(tensor_proto)) { + OrtValue ort_value; + if (graph.GetOrtValueInitializer(name_, ort_value, check_outer_scope)) { + const auto& src_tensor = ort_value.Get(); + // We need to make a copy of the data to ensure that the original data is not mutated + // This is generally inline with TensorProtoToTensor() behavior which copies data from + // TensorProto to Tensor. + Tensor initializer{src_tensor.DataType(), src_tensor.Shape(), CPUAllocator::DefaultInstance()}; + utils::MakeCpuTensorCopy(src_tensor, initializer); + Tensor::InitOrtValue(std::move(initializer), ort_value_); + data_ = GetTensor(ort_value_); + return; + } +#if !defined(__wasm__) + ORT_ENFORCE(!model_path.empty(), + "model_path must not be empty. Ensure that a path is provided when the model is created or loaded."); +#endif } - auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + Tensor tensor; + // Creates a copy of the data from tensor_proto + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), model_path, tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); +} + +Initializer::~Initializer() = default; - // This must be pre-allocated - Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, - std::make_shared()); - ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path, tensor_proto, w)); - data_ = std::move(w); +void Initializer::ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const { + constexpr const bool use_tensor_buffer_true = true; + tensor_proto = utils::TensorToTensorProto(*data_, name_, use_tensor_buffer_true); + ort_value = ort_value_; } #if !defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -90,6 +132,26 @@ struct TensorToProtoFP16 { } }; +template +struct TensorToFP16 { + void operator()(const Tensor& data, Tensor& dst) const { + ToFp16 to_fp16; + auto span = data.DataAsSpan(); + auto* dst_data = dst.MutableData(); + for (const auto& v : span) { + *dst_data++ = MLFloat16::FromBits(to_fp16(v)); + } + } +}; + +template <> +struct TensorToFP16 { + void operator()(const Tensor& data, Tensor& dst) const { + const auto count = narrow(data.Shape().Size()); + MlasConvertFloatToHalfBuffer(data.Data(), dst.MutableData(), count); + } +}; + template struct ToBFloat16; @@ -127,6 +189,18 @@ struct TensorToProtoBFloat16 { } }; +template +struct TensorToBFloat16 { + void operator()(const Tensor& data, Tensor& dst) const { + ToBFloat16 to_bfloat16; + auto span = data.DataAsSpan(); + auto* dst_data = dst.MutableData(); + for (const auto& v : span) { + *dst_data++ = BFloat16::FromBits(to_bfloat16(v)); + } + } +}; + template struct ToFloat32; @@ -159,27 +233,24 @@ struct ToFloat32 { }; template -struct TensorToProtoFloat32 { - void operator()(const Tensor& data, ONNX_NAMESPACE::TensorProto& proto, onnxruntime::concurrency::ThreadPool* /*thread_pool*/) const { - auto span = data.DataAsSpan(); +struct TensorToFloat32 { + void operator()(const Tensor& src, Tensor& dst, onnxruntime::concurrency::ThreadPool* /*thread_pool*/) const { + auto src_span = src.DataAsSpan(); + auto* dst_data = dst.MutableData(); ToFloat32 to_float32; - for (const auto& v : span) { - proto.add_float_data(to_float32(v)); + for (const auto& v : src_span) { + *dst_data++ = to_float32(v); } } }; template <> -struct TensorToProtoFloat32 { +struct TensorToFloat32 { void operator()(const Tensor& data, 
- ONNX_NAMESPACE::TensorProto& proto, + Tensor& dst, onnxruntime::concurrency::ThreadPool* thread_pool) const { - auto source = reinterpret_cast(data.DataRaw()); - auto count = size_t(data.SizeInBytes() / sizeof(MLFloat16)); - auto destination_mem = std::make_unique(count); - auto destination = destination_mem.get(); - MlasConvertHalfToFloatBufferInParallel(source, destination, count, thread_pool); - utils::SetRawDataInTensorProto(proto, destination, count * sizeof(float)); + const auto count = narrow(data.Shape().Size()); + MlasConvertHalfToFloatBufferInParallel(data.Data(), dst.MutableData(), count, thread_pool); } }; @@ -199,26 +270,54 @@ inline void SetNameDims(const std::string& name, ONNX_NAMESPACE::TensorProto Initializer::ToFP16(const std::string& name) const { ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto); + SetNameDims(name, data_->Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, tensor_proto); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, tensor_proto); return tensor_proto; } ONNX_NAMESPACE::TensorProto Initializer::ToBFloat16(const std::string& name) const { ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto); + SetNameDims(name, data_->Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, tensor_proto); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, tensor_proto); return tensor_proto; } -ONNX_NAMESPACE::TensorProto Initializer::ToFloat32(const std::string& name, onnxruntime::concurrency::ThreadPool* thread_pool) const { - ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto, thread_pool); - return tensor_proto; +OrtValue onnxruntime::Initializer::ToFP16() const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable()); + return result; +} + +OrtValue Initializer::ToBFloat16() const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable()); + return result; +} + +OrtValue Initializer::ToFloat32(onnxruntime::concurrency::ThreadPool* thread_pool) const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable(), thread_pool); + return result; } namespace { @@ 
-314,46 +413,46 @@ struct ElementWiseDiv : OpElementWise::typ } // namespace Initializer& Initializer::add(float value) { - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, value); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, value); return *this; } Initializer& Initializer::add(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::sub(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::mul(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::div(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::sqrt() { - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_); return *this; } @@ -395,13 +494,13 @@ struct ScaleByAxis { void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); - const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); + const size_t block_size = narrow(data_->Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; ORT_ENFORCE(scalers.size() == 1 || (column_major ? 
scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index 17d1ada29d778..96c2ca41f5539 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -11,10 +11,11 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/framework/allocator.h" -#include "core/optimizer/graph_transformer.h" +#include "core/framework/ort_value.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" +#include "core/optimizer/graph_transformer.h" #include "core/util/math.h" namespace onnxruntime { @@ -29,50 +30,82 @@ class Initializer final { Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path = {}); - ~Initializer() = default; + Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, bool check_outer_scope = false); + ~Initializer(); + + /// + /// This function creates a new tensor_proto with a complete copy of the data + /// + /// output void ToProto(ONNX_NAMESPACE::TensorProto& tensor_proto) const { - tensor_proto = utils::TensorToTensorProto(data_, name_); + tensor_proto = utils::TensorToTensorProto(*data_, name_); } + + /// + /// This function creates a pair of TensorProto and OrtValue. Unless the data + /// is short, tensor_proto will be a reference to the data in OrtValue. + /// Useful when adding a new initializer to the graph with external data in memory. 
+ /// + /// + /// + void ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const; + #if !defined(ORT_EXTENDED_MINIMAL_BUILD) + // XXX: Below two used only in training, convert to OrtValue result ONNX_NAMESPACE::TensorProto ToFP16(const std::string& name) const; - ONNX_NAMESPACE::TensorProto ToBFloat16(const std::string& name) const; - ONNX_NAMESPACE::TensorProto ToFloat32(const std::string& name, onnxruntime::concurrency::ThreadPool* thread_pool = nullptr) const; + OrtValue ToFP16() const; + OrtValue ToBFloat16() const; + OrtValue ToFloat32(onnxruntime::concurrency::ThreadPool* thread_pool = nullptr) const; + #endif // ORT_EXTENDED_MINIMAL_BUILD int data_type() const { - return data_.GetElementType(); + return data_->GetElementType(); } - std::string_view name() const { + const std::string& name() const { return name_; } template T* data() { - return data_.MutableData(); + return data_->MutableData(); } template const T* data() const { - return data_.Data(); + return data_->Data(); + } + + const void* data_raw() const { + return data_->DataRaw(); + } + + void* mutable_data_raw() { + return data_->MutableDataRaw(); } template auto DataAsSpan() const { - return data_.DataAsSpan(); + return data_->DataAsSpan(); } gsl::span DataAsByteSpan() const { - return gsl::make_span(reinterpret_cast(data_.DataRaw()), data_.SizeInBytes()); + return gsl::make_span(reinterpret_cast(data_->DataRaw()), data_->SizeInBytes()); + } + + gsl::span MutableDataAsByteSpan() { + return gsl::make_span(reinterpret_cast(data_->MutableDataRaw()), data_->SizeInBytes()); } gsl::span dims() const { - return data_.Shape().GetDims(); + return data_->Shape().GetDims(); } - size_t size() const { return narrow(data_.Shape().Size()); } + size_t size() const { return narrow(data_->Shape().Size()); } #if !defined(ORT_EXTENDED_MINIMAL_BUILD) Initializer& add(float value); @@ -91,7 +124,8 @@ class Initializer final { #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; - Tensor data_; + OrtValue ort_value_; + Tensor* data_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 3f19fb46e5ade..1e88ed44b1a8a 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -70,7 +70,7 @@ static std::vector GetAxesFromReduceMeanNode(Node& reduce_mean_node, co const auto* axes = reduce_mean_node.InputDefs()[1]; const auto* axes_const = graph.GetConstantInitializer(axes->Name(), true); if (axes_const != nullptr) { - Initializer initializer{*axes_const, graph.ModelPath()}; + Initializer initializer{graph, *axes_const, graph.ModelPath()}; auto span_axes = initializer.DataAsSpan(); axes_values.insert(axes_values.end(), span_axes.begin(), span_axes.end()); } @@ -480,7 +480,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name()); if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - Initializer initializer{*tensor_proto, graph.ModelPath()}; + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; layer_norm_node.AddAttribute("epsilon", initializer.data()[0]); } else { layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON); @@ -727,7 +727,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr 
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add_node.MutableInputDefs()[1]->Name()); if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - Initializer initializer{*tensor_proto, graph.ModelPath()}; + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; layer_norm_node.AddAttribute("epsilon", initializer.data()[0]); } else { layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON); diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index d02efe9890f1c..a6c422e59aeef 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -188,7 +188,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, shape_initializer_proto.add_dims(static_cast(shape.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape.size() * sizeof(int64_t)); - NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); ONNX_NAMESPACE::TypeProto new_arg_type; const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type()); diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 6b76dc626fba0..725cb3fc33f04 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -193,11 +193,11 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& * temp = scale / sqrt(var + epsilon) * output = (temp * Input) - ((temp * mean) + bias) */ - Initializer scale(*scale_tensor, graph.ModelPath()); - Initializer bias(*bias_tensor, graph.ModelPath()); - Initializer mean(*mean_tensor, graph.ModelPath()); - Initializer var(*var_tensor, graph.ModelPath()); - Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + Initializer scale(graph, *scale_tensor, graph.ModelPath()); + Initializer bias(graph, *bias_tensor, graph.ModelPath()); + Initializer mean(graph, *mean_tensor, graph.ModelPath()); + Initializer var(graph, *var_tensor, graph.ModelPath()); + Initializer matmul_b(graph, *matmul_b_tensor, graph.ModelPath()); var.add(epsilon); var.sqrt(); @@ -208,18 +208,18 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& bias.sub(mean); // create B tensorProto for new Gemm node from initializer. - ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor; matmul_b.ToProto(new_gemm_b_tensor); const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); new_gemm_b_tensor.set_name(new_gemm_b_name); - NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_b_tensor); // create bias tensorProto for new Gemm node from initializer. 
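// ----------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch). The MatMul+BatchNormalization hunk above folds
// the BN parameters into the Gemm weights with the Initializer arithmetic helpers and then
// serializes into a fresh TensorProto via ToProto(), rather than copy-constructing from the
// source proto (which could otherwise carry over stale external-data fields). A condensed
// sketch under the same assumptions; FoldScaleIntoWeight is a hypothetical name.
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
static onnxruntime::NodeArg& FoldScaleIntoWeight(onnxruntime::Graph& graph,
                                                 const ONNX_NAMESPACE::TensorProto& weight_proto,
                                                 const ONNX_NAMESPACE::TensorProto& scale_proto) {
  onnxruntime::Initializer weight{graph, weight_proto, graph.ModelPath()};
  onnxruntime::Initializer scale{graph, scale_proto, graph.ModelPath()};
  weight.scale_by_axis(scale, /*axis*/ 1, /*column_major*/ true);

  ONNX_NAMESPACE::TensorProto fused;  // deliberately start from an empty proto
  weight.ToProto(fused);              // full copy of the (modified) data
  fused.set_name(graph.GenerateNodeArgName(weight_proto.name() + "_fused"));
  return onnxruntime::graph_utils::AddInitializerWithExternalData(graph, fused);
}
// ----------------------------------------------------------------------------------------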
- ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor; bias.ToProto(new_gemm_bias_tensor); const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); new_gemm_bias_tensor.set_name(new_gemm_bias_name); - NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_bias_tensor); Node& gemm_node = graph.AddNode( graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 46f306b92bed5..335209dbfadaf 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -408,7 +408,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { // Reuse the existing NodeArg. nchwc_conv_W_arg = filters_it->second; } else { - Initializer conv_W{*conv_W_tensor_proto, graph_.ModelPath()}; + Initializer conv_W{graph_, *conv_W_tensor_proto, graph_.ModelPath()}; const auto conv_W_dims = conv_W.dims(); int64_t reordered_filter_size = nchwc_output_channels * filter_input_channels; @@ -437,7 +437,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]); } - nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); + nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); filters_map->emplace(input_defs[1], nchwc_conv_W_arg); } @@ -449,7 +449,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { // Reuse the existing NodeArg. nchwc_conv_B_arg = biases_it->second; } else { - Initializer conv_B{*conv_B_tensor_proto, graph_.ModelPath()}; + Initializer conv_B{graph_, *conv_B_tensor_proto, graph_.ModelPath()}; InlinedVector aligned_bias(gsl::narrow(nchwc_output_channels)); ORT_ENFORCE(output_channels <= nchwc_output_channels, "Buffer overflow"); @@ -464,7 +464,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_B_tensor_proto.add_dims(nchwc_output_channels); - nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); + nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); aligned_biases_.emplace(input_defs[2], nchwc_conv_B_arg); } } @@ -580,7 +580,7 @@ Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg, } shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims); - shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto); + shape_arg = &graph_utils::AddInitializerWithExternalData(graph_, shape_tensor_proto); } Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"), @@ -863,10 +863,10 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { return; } - Initializer bn_scale{*bn_scale_tensor_proto, graph_.ModelPath()}; - Initializer bn_B{*bn_B_tensor_proto, graph_.ModelPath()}; - Initializer bn_mean{*bn_mean_tensor_proto, graph_.ModelPath()}; - Initializer bn_var{*bn_var_tensor_proto, graph_.ModelPath()}; + Initializer bn_scale{graph_, *bn_scale_tensor_proto, graph_.ModelPath()}; + Initializer bn_B{graph_, *bn_B_tensor_proto, graph_.ModelPath()}; + Initializer bn_mean{graph_, *bn_mean_tensor_proto, graph_.ModelPath()}; + Initializer bn_var{graph_, *bn_var_tensor_proto, graph_.ModelPath()}; // Calculate the scale and bias for the replacement convolution. 
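// ----------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch). Throughout this diff the Initializer
// constructors gain a Graph argument (as in the bn_scale/bn_B/bn_mean/bn_var lines above):
// when a TensorProto only carries an in-memory external-data reference, the actual bytes live
// in an OrtValue owned by the Graph, and the Graph is needed to find them without touching the
// file system. A minimal sketch of that lookup using accessors shown earlier in this diff;
// InitializerBytesAreInMemory is a hypothetical name.
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"
static bool InitializerBytesAreInMemory(const onnxruntime::Graph& graph,
                                        const ONNX_NAMESPACE::TensorProto& tensor_proto) {
  if (!onnxruntime::utils::HasExternalDataInMemory(tensor_proto)) {
    return false;  // data is inline in the proto or in an external file on disk
  }
  OrtValue holder;
  // check_outer_scope=false mirrors the default used by the new Initializer constructor
  return graph.GetOrtValueInitializer(tensor_proto.name(), holder, /*check_outer_scope*/ false);
}
// ----------------------------------------------------------------------------------------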
bn_var.add(epsilon); @@ -892,7 +892,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { nchwc_conv_W_tensor_proto.add_dims(1); nchwc_conv_W_tensor_proto.add_dims(1); - auto* nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); + auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); std::copy_n(bn_B.data(), channels, padded_buffer.data()); @@ -903,7 +903,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { gsl::narrow(nchwc_channels) * sizeof(float)); nchwc_conv_B_tensor_proto.add_dims(nchwc_channels); - auto* nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); + auto* nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); // Create the replacement node. std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_bn_nchwc"); @@ -1045,7 +1045,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) { return; } - Initializer sizes{*sizes_tensor_proto, graph_.ModelPath()}; + Initializer sizes{graph_, *sizes_tensor_proto, graph_.ModelPath()}; auto* sizes_data = sizes.data(); // The sizes data can only be used if the input shape is static and the @@ -1075,7 +1075,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) { return; } - Initializer scales{*scales_tensor_proto, graph_.ModelPath()}; + Initializer scales{graph_, *scales_tensor_proto, graph_.ModelPath()}; auto* scales_data = scales.data(); // Cast the scales to integers and verify that the scales are positive and diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index bba39b698a27a..6dafd9cd97799 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -68,7 +68,7 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con op_type == "Mul" || op_type == "Div") { int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); + Initializer add_init(graph, *initializer, graph.ModelPath()); float value = 0.0f; switch (data_type) { diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index b2e8e491c361c..8c26d7a9ce209 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -7,7 +7,6 @@ #include "core/common/logging/logging.h" #include "core/common/logging/macros.h" #include "core/common/status.h" -#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/data_types.h" #include "core/framework/fuse_nodes_funcs.h" @@ -37,7 +36,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func), logger_(logger) { - allocator_ptr_ = std::make_shared(); + allocator_ptr_ = CPUAllocator::DefaultInstance(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); @@ -86,7 +85,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func), logger_(logger) { - allocator_ptr_ = std::make_shared(); + allocator_ptr_ = CPUAllocator::DefaultInstance(); 
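// ----------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch). Several call sites in this diff (including the
// line just above) stop constructing ad-hoc CPUAllocator objects and share one instance via
// CPUAllocator::DefaultInstance(), e.g. when materializing optimizer scratch tensors. A minimal
// usage sketch; MakeZeroedFloatTensor is a hypothetical name.
#include <cstring>
#include <vector>
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
#include "core/framework/tensor.h"
static OrtValue MakeZeroedFloatTensor(const std::vector<int64_t>& dims) {
  onnxruntime::Tensor t(onnxruntime::DataTypeImpl::GetType<float>(), onnxruntime::TensorShape(dims),
                        onnxruntime::CPUAllocator::DefaultInstance());
  std::memset(t.MutableDataRaw(), 0, t.SizeInBytes());
  OrtValue value;
  onnxruntime::Tensor::InitOrtValue(std::move(t), value);  // value now owns the buffer
  return value;
}
// ----------------------------------------------------------------------------------------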
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index 24a23312feba9..feb51514c8b2d 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -13,7 +13,6 @@ #include "core/framework/execution_frame.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/ort_value.h" -#include "core/framework/callback.h" namespace onnxruntime { class DataTransferManager; diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index bfd32a384335d..d0b6d42fd46c9 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -117,7 +117,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log // constant_value should be zero because Conv and MaxPool allow only 0 as padding value. if (node.InputDefs().size() > 2) { const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); - Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()}; + Initializer pad_constant_value{graph, *pad_constant_value_proto, graph.ModelPath()}; if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(), [](const uint8_t byte) { return byte != 0; })) { return false; } @@ -152,7 +152,7 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef if (pad_node.SinceVersion() >= 11) { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, pad_node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; pads_values.assign(pads.DataAsSpan().begin(), pads.DataAsSpan().end()); } else { pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end()); diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 5538aa54801cc..42cd31b5bd7b4 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -96,10 +96,10 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log } bool should_convert = false; - Initializer w_temp(*weight_tensor_proto, graph.ModelPath()); + Initializer w_temp(graph, *weight_tensor_proto, graph.ModelPath()); { int8_t* p = w_temp.data(); - for (size_t i = 0; i < w_temp.size(); i++) { + for (size_t i = 0, lim = w_temp.size(); i < lim; i++) { if (*p < -64 || *p > 64) { should_convert = true; } @@ -108,10 +108,10 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log } } - Initializer r_temp(*r_tensor_proto, graph.ModelPath()); + Initializer r_temp(graph, *r_tensor_proto, graph.ModelPath()); { int8_t* p = r_temp.data(); - for (size_t i = 0; i < r_temp.size(); i++) { + for (size_t i = 0, lim = r_temp.size(); i < lim; i++) { if (*p < -64 || *p > 64) { should_convert = true; } @@ -130,22 +130,22 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8"); 
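// ----------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch). The new Initializer::ToProtoWithOrtValue
// (declared in the initializer.h hunk above) emits a TensorProto that references the OrtValue's
// buffer unless the tensor is small enough to be inlined, which is the pair expected by
// Graph::AddInitializedOrtValue. A minimal sketch mirroring the constant-folding pattern shown
// earlier in this diff; AddComputedInitializer is a hypothetical name.
#include "core/optimizer/initializer.h"
static onnxruntime::common::Status AddComputedInitializer(onnxruntime::Graph& graph,
                                                          const onnxruntime::Initializer& computed) {
  ONNX_NAMESPACE::TensorProto proto;
  OrtValue data;
  computed.ToProtoWithOrtValue(proto, data);
  if (!onnxruntime::utils::HasExternalData(proto)) {
    data = OrtValue();  // small tensor: the bytes were inlined into the proto
  }
  return graph.AddInitializedOrtValue(proto, data);
}
// ----------------------------------------------------------------------------------------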
weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims()); utils::SetRawDataInTensorProto(weights_proto_u8, w_temp.data(), static_cast(w_temp.size())); - input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); + input_defs[w_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; QDQ::Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[w_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); + input_defs[w_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); ONNX_NAMESPACE::TensorProto r_proto_u8; r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8"); r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims()); utils::SetRawDataInTensorProto(r_proto_u8, r_temp.data(), static_cast(r_temp.size())); - input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8); + input_defs[r_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_proto_u8); ONNX_NAMESPACE::TensorProto r_zp_proto_u8; QDQ::Int8TensorProto2Uint8(r_zp_tensor_proto, r_zp_proto_u8, graph, true); - input_defs[r_zp_idx] = &graph_utils::AddInitializer(graph, r_zp_proto_u8); + input_defs[r_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index 72ca1cb74f1fd..a1859b9d7071b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -30,7 +30,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return false; } - Initializer s_initializer(*s_tensor_proto, graph.ModelPath()); + Initializer s_initializer(graph, *s_tensor_proto, graph.ModelPath()); if (s_initializer.dims().size() != 0 || s_initializer.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { return false; @@ -45,7 +45,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return false; } - Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath()); + Initializer zp_initializer(graph, *zp_tensor_proto, graph.ModelPath()); if (zp_initializer.dims().size() != 0) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc index 507bc71709b2f..691cf1183eb0e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc @@ -51,7 +51,8 @@ bool CleanUpNodeSequence(NodeSequence node_sequence_type, Graph& graph, NodeInde const auto output_edges_count = second_node_ptr->GetOutputEdgesCount(); if (!match_second(*second_node_ptr) || - !QDQ::IsQDQPairSupported(first_node, *second_node_ptr, get_constant_initializer, graph.ModelPath(), false) || + !QDQ::IsQDQPairSupported(graph, first_node, *second_node_ptr, get_constant_initializer, + graph.ModelPath(), false) || (produces_graph_output && output_edges_count != 0) || (!produces_graph_output && output_edges_count != 1)) { return false; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc index f2033dcbc1b03..98c818b0c761b 100644 --- 
a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc @@ -41,8 +41,8 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { // TODO(fuchen): need to augment this when we support per row quantization using ONNX_TENSOR_ELEM_TYPE = ONNX_NAMESPACE::TensorProto::DataType; - Initializer q_zero_point(*q_zp_tensor_proto, graph.ModelPath()); - Initializer dq_zero_point(*dq_zp_tensor_proto, graph.ModelPath()); + Initializer q_zero_point(graph, *q_zp_tensor_proto, graph.ModelPath()); + Initializer dq_zero_point(graph, *dq_zp_tensor_proto, graph.ModelPath()); if (q_zero_point.size() != 1 || dq_zero_point.size() != 1 || q_zero_point.data_type() != ONNX_TENSOR_ELEM_TYPE::TensorProto_DataType_INT8 || @@ -61,7 +61,7 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { zp_tensor_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); zp_tensor_proto_u8.set_name(graph.GenerateNodeArgName("qdq_s8_to_u8_zp_conversion")); utils::SetRawDataInTensorProto(zp_tensor_proto_u8, &q_zp_value, sizeof(uint8_t)); - NodeArg* zp_u8_arg = &graph_utils::AddInitializer(graph, zp_tensor_proto_u8); + NodeArg* zp_u8_arg = &graph_utils::AddInitializerWithExternalData(graph, zp_tensor_proto_u8); auto q_output_node_arg_name = graph.GenerateNodeArgName("qdq_s8_to_u8_quant"); NodeArg* q_output_arg = &graph.GetOrCreateNodeArg(q_output_node_arg_name, nullptr); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index fe5874d067b95..3ecdbf0ede6b3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -15,6 +15,7 @@ namespace onnxruntime::QDQ { bool IsQDQPairSupported( + const Graph& graph, const Node& q_node, const Node& dq_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path, @@ -56,10 +57,10 @@ bool IsQDQPairSupported( } // check Q/DQ have same scale and zero point - Initializer q_zp(*q_zp_tensor_proto, model_path); - Initializer q_scale(*q_scale_tensor_proto, model_path); - Initializer dq_zp(*dq_zp_tensor_proto, model_path); - Initializer dq_scale(*dq_scale_tensor_proto, model_path); + Initializer q_zp(graph, *q_zp_tensor_proto, model_path); + Initializer q_scale(graph, *q_scale_tensor_proto, model_path); + Initializer dq_zp(graph, *dq_zp_tensor_proto, model_path); + Initializer dq_scale(graph, *dq_scale_tensor_proto, model_path); if (q_zp.data_type() != dq_zp.data_type() || q_scale.data_type() != dq_scale.data_type() || @@ -84,6 +85,7 @@ bool IsQDQPairSupported( } bool IsDQQConversion( + const Graph& graph, const Node& dq_node, const Node& q_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path) { @@ -118,10 +120,10 @@ bool IsDQQConversion( } // check Q/DQ have same scale type and different zero point type - Initializer q_zp(*q_zp_tensor_proto, model_path); - Initializer q_scale(*q_scale_tensor_proto, model_path); - Initializer dq_zp(*dq_zp_tensor_proto, model_path); - Initializer dq_scale(*dq_scale_tensor_proto, model_path); + Initializer q_zp(graph, *q_zp_tensor_proto, model_path); + Initializer q_scale(graph, *q_scale_tensor_proto, model_path); + Initializer dq_zp(graph, *dq_zp_tensor_proto, model_path); + Initializer dq_scale(graph, *dq_scale_tensor_proto, model_path); return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); } @@ 
-167,6 +169,7 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( } bool IsQOrDQScalePositiveConstantScalar( + const Graph& graph, const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path) { auto q_or_dq_input_defs = q_or_dq_node.InputDefs(); @@ -183,7 +186,7 @@ bool IsQOrDQScalePositiveConstantScalar( return false; } - Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path); + Initializer q_or_dq_scale(graph, *q_or_dq_scale_tensor_proto, model_path); switch (q_or_dq_scale.data_type()) { case ONNX_NAMESPACE::TensorProto::FLOAT: @@ -250,7 +253,7 @@ bool GetQScalarScaleZp(const Graph& graph, const Node& q_node, float& scale, int } // Support scalar float scale only for now. Need to extend to other float types if needed. - Initializer scale_initializer(*scale_tensor_proto, graph.ModelPath()); + Initializer scale_initializer(graph, *scale_tensor_proto, graph.ModelPath()); if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) { return false; } @@ -275,7 +278,7 @@ bool GetQScalarScaleZp(const Graph& graph, const Node& q_node, float& scale, int return false; } - Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath()); + Initializer zp_initializer(graph, *zp_tensor_proto, graph.ModelPath()); if (zp_initializer.dims().size() != 0) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 25bd557b799c6..0648a3fc1f188 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -36,6 +36,7 @@ using GetConstantInitializerFn = std::function()[0] != -128) || diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc index f094f3c199f2a..616144c0ccde0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc @@ -43,12 +43,12 @@ bool ConvertS8WeightToU8(Graph& graph, Node& op_node, // The weights fits into S7, overflow is not a problem, no need to convert to U8 return false; } - input_defs[weights_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); + input_defs[weights_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); // Convert weight zero point to uint8 ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[weight_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); + input_defs[weight_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h index 1c1341fe5a127..a96f088c48306 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h @@ -47,7 +47,7 @@ inline bool Int8TensorProto2Uint8( // principle. A better solution is to provide an efficient const iterator for // TensorProto. This require coordination with onnx side. 
- Initializer temp(*src, graph.ModelPath()); + Initializer temp(graph, *src, graph.ModelPath()); int8_t* p = temp.data(); bool should_convert = false; for (size_t i = 0; i < temp.size(); i++) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 8f99b7409d4fe..dce69e2913582 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -197,7 +197,8 @@ void SetOptionalZeroPoint::UpdateNodes(Graph& graph, const NodesToOptimize& sele const ONNX_NAMESPACE::TensorProto* dummy_zp_tensor_proto; if (!graph.GetInitializedTensor(zp_tensor_proto.name(), dummy_zp_tensor_proto)) { - graph.AddInitializedTensor(zp_tensor_proto); + // Zero points are small, no need for external data + graph_utils::AddInitializer(graph, zp_tensor_proto); } auto& node_arg = graph.GetOrCreateNodeArg(zp_tensor_proto.name(), nullptr); @@ -280,8 +281,7 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) + concurrency::ThreadPool* intra_op_thread_pool) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -291,8 +291,7 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool}, - p_buffered_tensors_{p_buffered_tensors} { + intra_op_thread_pool_{intra_op_thread_pool} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -317,7 +316,6 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { - ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -325,11 +323,16 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const auto& attrs = dq_node->GetAttributes(); const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), + "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node->Name()); + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), + "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node->Name()); const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; - graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto); - graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto); if (zp_arg) { + // The zero point input is optional. A NodeArg can still exist for a missing optional input + // when its name is an empty string; in that case the call below will not return a pointer to a proto.
graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); } @@ -343,37 +346,38 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // external file, a raw buffer, or a repeated field depending on the data // type. UnpackTensor() already contains some of these logic and is closest // to what we need. But it does not handle external data. - Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); - Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); + Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); - std::optional zp_src_ptr; - auto cpu_allocator = std::make_shared(); + + std::optional zp_src; + auto cpu_allocator = CPUAllocator::DefaultInstance(); auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); - auto weight_dst_ptr = std::make_unique(uint8_type, - TensorShape{N, quant_num, blob_bytes}, - cpu_allocator); + auto weight_dst = Tensor(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + cpu_allocator); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); auto scale_size = (TensorShape{N, quant_num}).Size(); - auto scale_dst_ptr = std::make_unique(scale_type, - TensorShape{scale_size}, - cpu_allocator); + auto scale_dst = Tensor(scale_type, + TensorShape{scale_size}, + cpu_allocator); std::string zp_dst_name; - std::unique_ptr zp_dst_ptr; + std::optional zp_dst; auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { - zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); - zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{zp_size}, - cpu_allocator); + zp_dst = Tensor(uint8_type, + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); - zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); + zp_dst = Tensor(uint8_type, + TensorShape{zp_size}, + cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -381,10 +385,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -394,10 +398,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? 
zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -409,10 +413,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -423,10 +427,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -435,43 +439,24 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); + auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); std::optional zp_T_tp; - if (zp_dst_ptr) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst_ptr, zp_dst_name, true)); + if (zp_dst) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, weight_T_tp, std::move(weight_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (weight_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - // If tensor is too small, tensor proto directly copies data from tensor. The tensor allocated - // here can be directly destructed. - // Only keep the tensor in p_buffered_tensors_ when the tensor proto is using external data location - // and pointing the location to tensor's buffer. 
- ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, - "Failed to add buffered tensor ", - weight_dst_name); - } - input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, scale_T_tp, std::move(scale_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (scale_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, - "Failed to add buffered tensor ", - scale_dst_name); - } if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, zp_T_tp.value(), std::move(*zp_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (zp_T_tp->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name, std::move(zp_dst_ptr)).second, - "Failed to add buffered tensor ", - zp_dst_name); - } } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index d25077ca4b491..02a8353707599 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -86,8 +86,7 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors); + concurrency::ThreadPool* intra_op_thread_pool); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -106,7 +105,6 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; - std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index ae89af1f256d1..93eb33628105c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -282,8 +282,7 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. 
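(Illustration only, not part of the patch: the comment above spells out the DQ -> MatMul pattern this rule targets, including the constraint that the DQ block_size must be at least 16 and a power of two. A minimal standalone check of that constraint could look like the sketch below; the helper name IsSupportedDQBlockSize is hypothetical.)

// Sketch: validates the block_size constraint described in the comment above,
// i.e. block_size >= 16 and block_size is a power of two.
inline bool IsSupportedDQBlockSize(int64_t block_size) {
  return block_size >= 16 && (block_size & (block_size - 1)) == 0;
}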
@@ -291,8 +290,7 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors); + intra_op_thread_pool); #if !defined(ORT_MINIMAL_BUILD) std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider}; @@ -353,8 +351,7 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { SelectorActionRegistry CreateSelectorActionRegistry( bool is_int8_allowed, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -368,8 +365,7 @@ SelectorActionRegistry CreateSelectorActionRegistry( WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors); + intra_op_thread_pool); return qdq_selector_action_registry; } @@ -380,12 +376,11 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, p_buffered_tensors), + intra_op_thread_pool), apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 627ddd35b9919..dce1cd44fd3ea 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -29,8 +29,7 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr, - std::unordered_map>* p_buffered_tensors = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 255714054cdaa..dbcf1af6c2080 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -173,12 +173,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node if (!allow_nonpositive_scale_) { // IsQDQPairSupported will check that the scale is the same between q_node and dq_node. 
- if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) { + if (!IsQOrDQScalePositiveConstantScalar(graph_viewer.GetGraph(), q_node, get_const_initializer, + graph_viewer.ModelPath())) { return false; } } - return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); + return IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, @@ -345,7 +346,7 @@ bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } if (req_equal_quant_params_ && - !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + !IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { return false; } } @@ -761,7 +762,7 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n return graph_viewer.GetConstantInitializer(initializer_name, true); }; - return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); + return IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } bool CumSumNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 58e90ea3c71c2..aa6f9c5409de7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -117,28 +117,28 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph NodeArg* weight_scale_arg = nullptr; if (!dq_1) { - auto initializer = std::make_unique(*weight_proto, graph.ModelPath()); - const float* weight_data = initializer->data(); + Initializer initializer(graph, *weight_proto, graph.ModelPath()); + const float* weight_data = initializer.data(); // Quantize float32 weight to int8_t (per-tensor, symmetric). // int8_t quantization of input[1] works with input[0] of all types. float scale; int8_t zp; - GetQuantizationParameter(weight_data, static_cast(initializer->size()), scale, zp, nullptr); + GetQuantizationParameter(weight_data, static_cast(initializer.size()), scale, zp, nullptr); // Weight scale initializer. ONNX_NAMESPACE::TensorProto weight_scale_proto; weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale")); weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); weight_scale_proto.mutable_float_data()->Add(scale); - weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto); + weight_scale_arg = &graph_utils::AddInitializerWithExternalData(graph, weight_scale_proto); // Weight zero point initializer. ONNX_NAMESPACE::TensorProto weight_zp_proto; weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp")); weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); weight_zp_proto.mutable_int32_data()->Add(static_cast(zp)); - NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto); + NodeArg& weight_zp_arg = graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto); // Q from float32 to int8. 
ONNX_NAMESPACE::TypeProto weight_q_type_proto; diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.cc b/onnxruntime/core/optimizer/quick_gelu_fusion.cc index b09ef1c460b8e..54236c9a27980 100644 --- a/onnxruntime/core/optimizer/quick_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/quick_gelu_fusion.cc @@ -37,7 +37,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (!optimizer_utils::IsScalar(input_arg)) continue; const TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); if (!tensor_proto) continue; - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == TensorProto_DataType_FLOAT) { alpha = *(init_const.data()); diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index ae12c7bdfd4ac..efd7022ab764b 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -56,7 +56,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff data_type = initializer->data_type(); // construct an initializer to gracefully handle typed or raw data in the TensorProto - Initializer i(*initializer, graph.ModelPath()); + Initializer i(graph, *initializer, graph.ModelPath()); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: if (*i.data() < 0.f) { @@ -97,12 +97,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff mutable_next_node->AddAttribute("min", 0.f); } else { // Add the initialized tensor to the graph - graph.AddInitializedTensor(replacement_min); - - // Create a corresponding NodeArg for the initialized tensor - ONNX_NAMESPACE::TypeProto t; - t.mutable_tensor_type()->set_elem_type(replacement_min.data_type()); - NodeArg* replacement_min_nodearg = &graph.GetOrCreateNodeArg(replacement_min.name(), &t); + auto* replacement_min_nodearg = &graph_utils::AddInitializerWithExternalData(graph, replacement_min); // Replace the input def at the appropriate index of the Clip node auto& mutable_input_defs = mutable_next_node->MutableInputDefs(); diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 324905f953eec..36213609f6b61 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -438,7 +438,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto); + auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); // Safely remove concat parent nodes which have only one output for (int i = 0; i < concat_input_count; ++i) { @@ -492,7 +492,7 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - NodeArg* shape_arg = 
&graph_utils::AddInitializer(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg}, {contiguous_reshapes.back().get().MutableOutputDefs()[0]}); diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index d3ec2dd459fd3..c4066097e43f1 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -61,7 +61,7 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, cons auto get_initializer_data = [&graph](const ONNX_NAMESPACE::TensorProto* initializer) -> InlinedVector { - Initializer init(*initializer, graph.ModelPath()); + Initializer init(graph, *initializer, graph.ModelPath()); if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { int32_t* init_data = init.data(); return InlinedVector(init_data, init_data + init.size()); diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index 5c09e5225ab9c..74121508132dc 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -46,7 +46,7 @@ NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[T proto.add_dims(shape[i]); } utils::SetRawDataInTensorProto(proto, begin, element_count * sizeof(TDataType)); - return &graph_utils::AddInitializer(graph, proto); + return &graph_utils::AddInitializerWithExternalData(graph, proto); } template diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 2aa3cf30813b6..a320de2ee7a13 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -5,7 +5,9 @@ #include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" +#include "core/graph/graph_utils.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -371,26 +373,31 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker const InitializedTensorSet& initializers_consumed, const logging::Logger& logger) { std::map replacements; - for (const auto& pair : initializers_consumed) { - const auto& name = pair.first; + for (const auto& [name, tensor_proto] : initializers_consumed) { const onnxruntime::NodeArg* provider_def = FindNodeArg(provider_input_defs_, name); const onnxruntime::NodeArg* non_provider_def = FindNodeArg(non_provider_input_defs_, name); if (provider_def != nullptr && non_provider_def != nullptr) { std::string new_def_name = graph_.GenerateNodeArgName(name); auto& new_def = graph_.GetOrCreateNodeArg(new_def_name, provider_def->TypeAsProto()); - // We make a copy of the initializer that is to be consumed by the provider Node so that - // session state initializer can copy it over to the provider device during its operation - // TODO: The copy being made is possibly redundant if this occurs in a subgraph - // When multiple subgraphs consume the same initializer as an implicit input, - // multiple copies of the initializer will be made into the provider device - // This should not directly affect 
runtime performance as the copies occur during initialization - but overuse of the provider device's memory is definitely inefficient - In future, we need to "statefully" make the copy only once and use it in all subgraphs referencing the initializer - const TensorProto* tensor_proto = pair.second; TensorProto new_tensor_proto = *tensor_proto; *(new_tensor_proto.mutable_name()) = new_def_name; - graph_.AddInitializedTensor(new_tensor_proto); + + // Query any OrtValue that already exists for the original initializer. + // We check the outer scope (check_outer_scope_true below) because the initializer may be + // defined in a parent graph. + // This way the same OrtValue is re-used in subgraphs and no copies are made for large initializers. + constexpr const bool check_outer_scope_true = true; + OrtValue ort_value; + // The initializer can be in memory backed by an OrtValue, or it can be memory-mapped from a flatbuffer. + if (utils::HasExternalDataInMemory(new_tensor_proto) && + graph_.GetOrtValueInitializer(name, ort_value, check_outer_scope_true)) { + // Re-use the same ort_value and proto that points to the same buffer + ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializerWithExternalData(graph_, new_tensor_proto, + std::move(ort_value))); + } else { + ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto)); + } replacements.insert(std::make_pair(provider_def, &new_def)); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index f87df746234fa..48ea54434b805 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -14,6 +14,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/optimizer/transpose_optimization/ort_transpose_optimization.h" @@ -558,8 +559,8 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector new_tensor_shape_dims; - std::vector permutations; + TensorShapeVector new_tensor_shape_dims; + InlinedVector permutations; permutations.reserve(perm.size()); new_tensor_shape_dims.reserve(perm.size()); for (int64_t p : perm) { @@ -568,12 +569,12 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector& shape) { @@ -607,14 +610,19 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectordata_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { const float* val = init_const.data(); @@ -110,7 +110,7 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { const int64_t* val = init_const.data(); @@ -171,7 +171,7 @@ bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, I return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type ==
ONNX_NAMESPACE::TensorProto_DataType_INT64) { const int64_t* val = init_const.data(); @@ -333,7 +333,7 @@ bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, flo bool is_constant = true; const ONNX_NAMESPACE::TensorProto* initializer = graph.GetConstantInitializer(input->Name(), true); if (initializer) { - Initializer i(*initializer, graph.ModelPath()); + Initializer i(graph, *initializer, graph.ModelPath()); switch (initializer->data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: value = *i.data(); @@ -421,7 +421,7 @@ bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntim return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const T* val = init_const.data(); value = *val; diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 7dbc3fe82db47..e100d3626f76b 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -26,7 +26,6 @@ limitations under the License. #include "core/common/common.h" #include "core/common/path_string.h" -#include "core/framework/callback.h" #include "core/platform/env_time.h" #include "core/platform/telemetry.h" #include "core/session/onnxruntime_c_api.h" @@ -179,7 +178,7 @@ class Env { virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, gsl::span buffer) const = 0; - using MappedMemoryPtr = std::unique_ptr; + using MappedMemoryPtr = std::unique_ptr>; /** * Maps the content of the file into memory. diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 94aadf3df4d7e..0e43d054d5c5e 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -62,15 +62,8 @@ namespace { constexpr int OneMillion = 1000000; -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - int ret = munmap(p->addr, p->len); +static void UnmapFile(void* addr, size_t len) noexcept { + int ret = munmap(addr, len); if (ret != 0) { auto [err_no, err_msg] = GetErrnoInfo(); LOGS_DEFAULT(ERROR) << "munmap failed. error code: " << err_no << " error msg: " << err_msg; @@ -451,7 +444,9 @@ class PosixEnv : public Env { mapped_memory = MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, - OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}}; + [mapped_base, mapped_length](void*) { + UnmapFile(mapped_base, mapped_length); + }}; return Status::OK(); } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 9fdd323b365d6..36c6b54a1fce0 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -45,15 +45,8 @@ EXTERN_C IMAGE_DOS_HEADER __ImageBase; namespace onnxruntime { -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - bool ret = UnmapViewOfFile(p->addr); +static void UnmapFile(void* addr) noexcept { + bool ret = UnmapViewOfFile(addr); if (!ret) { const auto error_code = GetLastError(); LOGS_DEFAULT(ERROR) << "unmap view of file failed. 
error code: " << error_code @@ -467,9 +460,12 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); GSL_SUPPRESS(r.11) + mapped_memory = MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, - OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}}; + [mapped_base](void*) { + UnmapFile(mapped_base); + }}; return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index a4609eb2a0584..fb3d3c80ec372 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -59,8 +59,8 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co NodeAttrHelper helper(node); if (input_defs.size() > 1 && input_defs[1]->Exists()) { - auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); - Initializer axes_initializer(axes_tensor); + const auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer axes_initializer(model_builder.GetGraphViewer().GetGraph(), axes_tensor); int64_t* data = axes_initializer.data(); int64_t size = axes_initializer.size(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index b35d6971623ed..e3781ed7d388b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -44,7 +44,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& data_name = input_defs[0]->Name(); const auto& new_shape_name = input_defs[1]->Name(); - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name)); + Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), *model_builder.GetConstantInitializer(new_shape_name)); TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan()); // ReshapeHelper applies the ONNX rules to create the concrete output shape @@ -75,7 +75,8 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, +bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, + const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& new_shape_name = input_defs[1]->Name(); @@ -87,7 +88,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP return false; } - Initializer unpacked_tensor(*new_shape_tensor); + Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *new_shape_tensor); auto new_shape = unpacked_tensor.DataAsSpan(); if (new_shape.empty()) { LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 837573003e515..9b1545035104c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -77,7 +77,8 @@ bool GetValidatedResizeScales(const GraphViewer& 
graph_viewer, return false; } - Initializer unpacked_tensor(*scales_tensor); + const auto& graph = graph_viewer.GetGraph(); + Initializer unpacked_tensor(graph, *scales_tensor, graph.ModelPath()); auto scales_data = unpacked_tensor.DataAsSpan(); scales.assign(scales_data.begin(), scales_data.end()); @@ -108,7 +109,7 @@ bool GetValidatedResizeSizes(const GraphViewer& graph_viewer, return false; } - Initializer unpacked_tensor(*sizes_tensor); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *sizes_tensor, graph_viewer.ModelPath()); auto sizes_data = unpacked_tensor.DataAsSpan(); sizes.assign(sizes_data.begin(), sizes_data.end()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index bf72fbbf1ace4..1a0f4e4de2e09 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -59,7 +59,7 @@ Status PrepareSliceComputeMetadata(const Node& slice_node, const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name()); ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer."); - Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath()); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *tensor_proto, graph_viewer.ModelPath()); const auto data_type = unpacked_tensor.data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { auto tensor_data = unpacked_tensor.DataAsSpan(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 717d344982473..4ee9b54cebd16 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -63,7 +63,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (input_defs.size() > 1) { // if "split" is explicitly provided as an input - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + const auto& const_init = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer unpacked_tensor(const_init); auto split_span = unpacked_tensor.DataAsSpan(); AddOperationInput(*split_op, "split_sizes", model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); @@ -102,7 +103,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (input_defs.size() > 1) { // if "split" is explicitly provided as an input // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + const auto& const_init = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), const_init); auto split_span = unpacked_tensor.DataAsSpan(); for (const auto& split_size : split_span) { coreml_splitnd->add_splitsizes(split_size); @@ -164,7 +166,8 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar return false; } - Initializer unpacked_tensor(*splits_tensor); + Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *splits_tensor, + input_params.graph_viewer.ModelPath()); auto splits_span = unpacked_tensor.DataAsSpan(); int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), 
int64_t{0}); if (sum_of_splits != split_dims_at_axis) { diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index d4fea3c5a75c7..9043593e5fc9e 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -395,7 +395,7 @@ static Status DoTransposeInt4(const gsl::span& permutations, const "Expected to transpose int4 tensor"); // Convert to Tensor, transpose, and then repack back to Tensor. - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); Tensor input_unpacked; Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index 6adcfec852690..313e9ea4b9948 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -441,6 +441,8 @@ struct SliceIterator : public SliceIteratorBase { }; inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) { + ORT_ENFORCE(src->SizeInBytes() == tgt->SizeInBytes(), "Destination size does not match source."); + void* target = tgt->MutableDataRaw(); const void* source = src->DataRaw(); diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc index 79ec47b8f8443..b3886234fa238 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc @@ -62,8 +62,8 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE return Status::OK(); } - Initializer BatchNormalization_B{*BatchNormalization_B_tensor_proto, graph.ModelPath()}; - Initializer add_B{*add_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_B{graph, *BatchNormalization_B_tensor_proto, graph.ModelPath()}; + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; if (BatchNormalization_B.size() != add_B.size()) { return Status::OK(); @@ -73,11 +73,12 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE // Create new initializers of BatchNormalization ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto; - BatchNormalization_B.ToProto(new_BatchNormalization_B_tensor_proto); + OrtValue ort_value; + BatchNormalization_B.ToProtoWithOrtValue(new_BatchNormalization_B_tensor_proto, ort_value); // Replace initializers of BatchNormalization node graph.RemoveInitializedTensor(BatchNormalization_inputs[2]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_B_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_B_tensor_proto, ort_value)); // Remove Add node. 
auto* add_node_to_remove = graph.GetNode(add_node.Index()); diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc index 02f16b4d3d467..21c85c7f67d30 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc @@ -13,7 +13,8 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const { +Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, + const onnxruntime::logging::Logger&) const { auto& BatchNormalization_node = node; const auto& mul_node = *BatchNormalization_node.OutputNodesBegin(); const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs(); @@ -54,8 +55,8 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE } } - Initializer BatchNormalization_Scale{*BatchNormalization_Scale_tensor_proto, graph.ModelPath()}; - Initializer mul_B{*mul_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_Scale{graph, *BatchNormalization_Scale_tensor_proto, graph.ModelPath()}; + Initializer mul_B{graph, *mul_B_tensor_proto, graph.ModelPath()}; const ONNX_NAMESPACE::TensorProto* BatchNormalization_B_tensor_proto = nullptr; if (!graph.GetInitializedTensor(BatchNormalization_inputs[2]->Name(), BatchNormalization_B_tensor_proto)) @@ -67,7 +68,7 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE BatchNormalization_B_tensor_proto->dims_size() != 1) { return Status::OK(); } - Initializer BatchNormalization_B{*BatchNormalization_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_B{graph, *BatchNormalization_B_tensor_proto, graph.ModelPath()}; // Calculate new value of initializers of BatchNormalization node BatchNormalization_Scale.scale_by_axis(mul_B, 1); @@ -79,17 +80,20 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE } // Create new initializers of BatchNormalization - ONNX_NAMESPACE::TensorProto new_BatchNormalization_Scale_tensor_proto(*BatchNormalization_Scale_tensor_proto); - BatchNormalization_Scale.ToProto(new_BatchNormalization_Scale_tensor_proto); + ONNX_NAMESPACE::TensorProto new_BatchNormalization_Scale_tensor_proto; + OrtValue ort_value_scale; + BatchNormalization_Scale.ToProtoWithOrtValue(new_BatchNormalization_Scale_tensor_proto, ort_value_scale); // Replace initializers of BatchNormalization node graph.RemoveInitializedTensor(BatchNormalization_inputs[1]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_Scale_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_Scale_tensor_proto, ort_value_scale)); + + ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto; + OrtValue ort_value_B_scale; + BatchNormalization_B.ToProtoWithOrtValue(new_BatchNormalization_B_tensor_proto, ort_value_B_scale); - ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto(*BatchNormalization_B_tensor_proto); - BatchNormalization_B.ToProto(new_BatchNormalization_B_tensor_proto); graph.RemoveInitializedTensor(BatchNormalization_inputs[2]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_B_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_B_tensor_proto, 
ort_value_B_scale)); // Remove Mul node. auto* mul_node_to_remove = graph.GetNode(mul_node.Index()); diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 01f44e91fd49c..bb5d942ecb14a 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -672,30 +672,35 @@ bool DnnlErfNodeCapability::Supported(const Node* node, const GraphViewer& graph return true; } -bool DnnlErfNodeCapability::IsInitilizedWithExpectedValue(const GraphViewer& graph_viewer, const NodeArg* node_arg, float expected_value) const { - // TypeAsProto()->tensor_type().elem_type() - if ((ORT_DataType)node_arg->TypeAsProto()->tensor_type().elem_type() == type_float32) { - const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - graph_viewer.GetInitializedTensor(node_arg->Name(), tensor_proto); - const float* val = reinterpret_cast(tensor_proto->raw_data().data()); - - // Check for NaN and Inf - if (std::isnan(val[0]) || std::isinf(val[0])) { - if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { - return true; - } - return false; - } +bool DnnlErfNodeCapability::IsInitilizedWithExpectedValue(const GraphViewer& graph_viewer, const NodeArg* node_arg, + float expected_value) const { + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!graph_viewer.GetInitializedTensor(node_arg->Name(), tensor_proto)) { + return false; + } - const float atol = 1e-8f; - const float rtol = 1e-5f; - float diff = std::abs(val[0] - expected_value); - if (diff > (atol + rtol * std::abs(expected_value))) { - return false; + onnxruntime::Initializer erf_weight{graph_viewer.GetGraph(), *tensor_proto, graph_viewer.ModelPath()}; + if (erf_weight.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return false; + } + + const float* val = erf_weight.data(); + + // Check for NaN and Inf + if (std::isnan(val[0]) || std::isinf(val[0])) { + if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { + return true; } - return true; + return false; } - return false; + + const float atol = 1e-8f; + const float rtol = 1e-5f; + float diff = std::abs(val[0] - expected_value); + if (diff > (atol + rtol * std::abs(expected_value))) { + return false; + } + return true; } const Node* DnnlErfNodeCapability::FirstParentByType(const Node& node, const std::string& parent_type) const { diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc index f6497e381d0f7..fdebe51865f4b 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc @@ -181,13 +181,11 @@ bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(const onnxruntime::Grap return false; } - if (!tensor_proto->has_raw_data()) { - return false; - } - const auto data_type = input_arg.Type(); if (data_type == dnnl::memory::data_type::f32) { - const float* val = reinterpret_cast(tensor_proto->raw_data().data()); + onnxruntime::Initializer initializer(onnx_subgraph_viewer.GetGraph(), + *tensor_proto, onnx_subgraph_viewer.ModelPath()); + const float* val = initializer.data(); if (std::isnan(val[0]) || std::isinf(val[0])) { if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { 
return true; @@ -775,9 +773,8 @@ void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const o // check if b_zp is all zeros, assume data is s8 since only s8 weight is supported in onednn bool all_zero = true; - std::vector unpacked_tensor; - unpacked_tensor.resize(num_elements, 1); - ORT_THROW_IF_ERROR(onnxruntime::utils::UnpackTensor(*tensor_proto, tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, reinterpret_cast(unpacked_tensor.data()), num_elements)); + std::vector unpacked_tensor; + ORT_THROW_IF_ERROR(onnxruntime::utils::UnpackInitializerData(*tensor_proto, unpacked_tensor)); for (const auto& val : unpacked_tensor) { if (val != 0) { all_zero = false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 5108f90fc763a..c37b068d988a4 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -203,7 +203,7 @@ common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); // The scale should be one or more floats scale = unpacked_tensor.DataAsSpan()[0]; } @@ -215,7 +215,7 @@ common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); }; - Initializer unpacked_tensor(*zp, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *zp, model_path); // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI zero_point = static_cast(unpacked_tensor.DataAsByteSpan()[0]); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc index 8127de0a0f05f..83727f7c9d960 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc @@ -80,7 +80,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No const auto* pads_initializer = model_builder.GetConstantInitializer(pads); ORT_RETURN_IF_NOT(pads_initializer, "pads must be a constant"); - Initializer pads_initializer_raw_data(*pads_initializer); + Initializer pads_initializer_raw_data(model_builder.GetGraphViewer().GetGraph(), *pads_initializer); // assume pads_initializer has int64 data, per ONNX spec std::vector converted_pads_data{}; converted_pads_data.reserve(2 * data_rank); @@ -102,7 +102,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No const auto& constant_value = inputs[2].node_arg.Name(); const auto* constant_value_initializer = model_builder.GetConstantInitializer(constant_value); ORT_RETURN_IF_NOT(constant_value_initializer, "constant_value must be a constant"); - Initializer pad_value_raw_data_init(*constant_value_initializer); + Initializer pad_value_raw_data_init(model_builder.GetGraphViewer().GetGraph(), *constant_value_initializer); pad_value = pad_value_raw_data_init.DataAsSpan()[0]; } @@ -158,7 +158,7 @@ bool PadOpBuilder::IsOpSupportedImpl(const 
GraphViewer& graph_viewer, const Node return false; } - Initializer unpacked_tensor(*pads_initializer); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index af5aeba6c8236..c4f1e5f402491 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -249,7 +249,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N return false; } - const Initializer unpacked_tensor(*scales); + const Initializer unpacked_tensor(graph_viewer.GetGraph(), *scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; const float scale_n = scales_data[0]; @@ -287,7 +287,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N return false; } - Initializer unpacked_tensor(*sizes); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *sizes); auto sizes_data = unpacked_tensor.DataAsSpan(); input_is_nchw = sizes_data[1] == input_shape[1]; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index 7509fd15f1c5e..aa715068f432c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -104,7 +104,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No return false; } - Initializer unpacked_tensor(*splits); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *splits); auto splits_span = unpacked_tensor.DataAsSpan(); uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt(0)); if (sum_of_splits != split_dims_at_axis) { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index bc528c89f2be3..04f6349b250d9 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1594,6 +1594,11 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // Add node and node args // If node output is also parent graph output, the output will be added to the // subgraph's output list + // + // Initializers that refer to a memory location in OrtValue + // can not be handled by TRT (unlike those that are on disk). + // This prevents us from sharing the data and we have to make a copy here. 
+ constexpr const bool load_initializers_inline_true = true; std::vector subgraph_output_names; for (const auto& index : group.first) { const auto& node = graph.GetNode(node_index[index]); @@ -1601,24 +1606,15 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto input : node->ImplicitInputDefs()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } + for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 6c5e1a1f0a8d3..2b8603b1555da 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -465,7 +465,7 @@ class NvExecutionProvider : public IExecutionProvider { * and save those information in subgraph context data structure. It's useful for building a valid graph and * make Graph::Resolve() happy especially when dealing with nested control-flow op graph. */ - void BuildSubGraphContext(const Graph& build_graph) const; + void BuildSubGraphContext(Graph& build_graph) const; /** * Set outer scope values for subgraphs and add thoes values as top-level graph's inputs if needed. 
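// Editor's sketch (not part of the patch): the hunk above replaces the manual
// GetInitializedTensor/AddInitializedTensor copy loop with the bridged
// graph_utils::MakeInitializerCopyIfNotExist helper. The snippet below is a
// minimal illustration of how an execution provider might use it while cloning
// nodes into a locally built sub-graph; `graph` (a GraphViewer), `graph_build`
// and `node` are assumed to exist in the caller, and passing `true` requests
// that OrtValue-backed (in-memory) initializer data be loaded inline so the
// copy does not alias data owned by the source graph.
for (const auto* input : node->InputDefs()) {
  graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
  // Copies the initializer only if `graph_build` does not already have one
  // with this name; the last argument asks for inline (self-contained) data.
  graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build,
                                             input->Name(), /*load_inline*/ true);
}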
diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 860cfb5713903..24e8892622175 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -14,6 +14,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { namespace openvino_ep { @@ -643,16 +644,12 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const std::string& initializer_name) { // Get the initializer from source graph - const auto& src_initializers = src_graph.GetAllInitializedTensors(); - auto init_iter = src_initializers.find(initializer_name); - - if (init_iter == src_initializers.end()) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!src_graph.GetInitializedTensor(initializer_name, tensor_proto)) { // Initializer not found return; } - const auto* tensor_proto = init_iter->second; - // Create TypeProto for the initializer auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); auto* tensor_type = type_proto->mutable_tensor_type(); @@ -789,17 +786,21 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, } // Copy initializers to dst graph. + const auto& initializers = src_graph.GetAllInitializedTensors(); - std::unordered_set current_scope_initializer_set; - - auto& initializers = src_graph.GetAllInitializedTensors(); + InlinedHashSet current_scope_initializer_set; + current_scope_initializer_set.reserve(initializers.size()); // Sort initializers to maintain consistency in model proto created across inference requests - std::vector const_inits; - for (auto& it : initializers) { - const_inits.push_back(it.first); + + InlinedVector all_inits; + all_inits.reserve(initializers.size()); + for (auto it = initializers.cbegin(), end = initializers.cend(); it != end; ++it) { + all_inits.push_back(it); } - std::sort(const_inits.begin(), const_inits.end()); + std::sort(all_inits.begin(), all_inits.end(), [](const auto& i1, const auto& i2) { + return i1->first < i2->first; + }); // initialize map for creating metadata for initilizers with external weights auto& metadata = shared_weights.metadata; @@ -832,41 +833,53 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, metadata.emplace(key, std::move(value)); }; - // Handle constant initializers - for (auto& it : const_inits) { - const auto& initializer_tensor = *initializers.at(it); + // Handle initializers + for (const auto& it : all_inits) { + const auto& [name, init] = *it; + const auto& initializer_tensor = *init; + + std::unique_ptr init_with_data; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(initializer_tensor, init_with_data)); // Check if the initializer has external data - if (initializer_tensor.has_data_location() && - initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && + if (!init_with_data && + utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); // Add initializer with external data as input - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, it); - + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); } else { // Add as an initialized tensor if it does 
not have external data - if (initializers_to_keep.count(it)) - dst_graph.AddInitializedTensor(*(initializers.at(it))); + if (initializers_to_keep.count(name) > 0) { + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); + } + } } - current_scope_initializer_set.insert(it); + current_scope_initializer_set.insert(name); } - // Handle outer-scope constant initializers + // Handle outer-scope initializers for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { const auto& node = src_graph.GetNode(node_idx); for (const auto& input : node->InputDefs()) { - if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { + if (current_scope_initializer_set.count(input->Name()) > 0) { continue; } if (src_graph.IsConstantInitializer(input->Name(), true)) { const auto& initializer_tensor = *src_graph.GetConstantInitializer(input->Name(), true); + + std::unique_ptr init_with_data; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(initializer_tensor, init_with_data)); + // Check if the initializer has external data - if (initializer_tensor.has_data_location() && - initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && + if (!init_with_data && + utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); @@ -876,7 +889,11 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, } else { // Add as an initialized tensor if it does not have external data if (initializers_to_keep.count(input->Name())) { - dst_graph.AddInitializedTensor(*(src_graph.GetConstantInitializer(input->Name(), true))); + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); + } } } diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index e9343e2b2e06a..312733cb2ba0f 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -102,8 +102,8 @@ RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto graph_outputs = graph_viewer.GetOutputs(); // Add initializer to graph_viewer const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - graph_build.AddInitializedTensor(*(tensor.second)); + for (const auto& [name, _] : init_tensors) { + graph_utils::MakeInitializerCopyIfNotExist(graph_viewer.GetGraph(), graph_build, name); } ORT_ENFORCE(graph_build.Resolve().IsOK()); diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 4281b5e53c5fd..0e0f559d2e0f1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -27,7 +27,7 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg namespace { bool GetClipMinMaxImpl(std::function get_const_initializer, - const Node& node, float& min, float& max, const logging::Logger& logger) { + const Graph& graph, const Node& node, float& min, float& max, const logging::Logger& logger) { const auto& node_name = node.Name(); int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) { @@ -50,7 +50,7 @@ bool GetClipMinMaxImpl(std::function()[0]; @@ -97,7 +97,7 @@ bool 
GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min [&graph_viewer](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { return graph_viewer.GetConstantInitializer(name); }, - node, min, max, logger); + graph_viewer.GetGraph(), node, min, max, logger); } NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 58d4461c7c32a..4d3ae4f4a7e07 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -418,8 +418,32 @@ inline std::unique_ptr MakeComputeCapability(const GraphViewe return g_host->Utils__MakeComputeCapability(graph_viewer, group, generate_metadef_name, execution_provider_name, drop_constant_initializers); } + +inline Status GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) { + return g_host->Utils__GetTensorProtoWithDataIfInMemory(tensor_proto, result); +} + +inline bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { + return g_host->Utils__HasExternalDataInMemory(ten_proto); +} + } // namespace utils +namespace graph_utils { +inline NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + return g_host->GraphUtils__AddInitializerWithExternalData(graph, new_initializer); +} +inline void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool load_inline = false) { + g_host->GraphUtils__MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_inline); +} + +inline Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) { + return g_host->GraphUtils__ConvertInMemoryDataToInline(graph, name); +} +} // namespace graph_utils + namespace QDQ { inline std::pair>, std::unordered_map> GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { @@ -436,6 +460,18 @@ inline Env& GetDefaultEnv() { return g_host->Env__Default(); } +template +inline const T* Initializer::data() const { + constexpr const int data_type = static_cast(utils::GetONNXTensorElementDataType()); + return reinterpret_cast(g_host->Initializer__data(*this_ptr_, data_type)); +} + +template +inline T* Initializer::data() { + constexpr const int data_type = static_cast(utils::GetONNXTensorElementDataType()); + return reinterpret_cast(g_host->Initializer__mutable_data(*this_ptr_, data_type)); +} + } // namespace onnxruntime #define CREATE_MESSAGE(logger, severity, category, datatype) \ diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f20760fcc86fd..c056e454d0fc9 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -37,9 +37,11 @@ namespace onnxruntime { struct ProviderHost; struct ProviderHostCPU; +class ExternalDataInfo; class PhiloxGenerator; using ProviderType = const std::string&; class RandomGenerator; +class Initializer; class IOnnxRuntimeOpSchemaCollection; struct ModelSavingOptions; @@ -977,6 +979,12 @@ struct ProviderHost { const std::function& generate_metadef_name, const std::string& execution_provider_name, bool drop_constant_initializers) = 0; + + virtual Status Utils__GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& 
tensor_proto, std::unique_ptr& result) = 0; + + virtual bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) = 0; + // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1004,6 +1012,8 @@ struct ProviderHost { virtual const std::unordered_map& Graph__DomainToVersionMap(const Graph* p) const noexcept = 0; virtual Status Graph__Resolve(Graph* p) = 0; virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; + // We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion + virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, NodeAttributes&& attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0; @@ -1099,6 +1109,37 @@ struct ProviderHost { virtual std::unique_ptr ConstGraphNodes__cend(const ConstGraphNodes* p) = 0; virtual bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept = 0; + // graph_util + virtual NodeArg& GraphUtils__AddInitializerWithExternalData(Graph& graph, + const ONNX_NAMESPACE::TensorProto& new_initializer) = 0; + virtual void GraphUtils__MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool load_inline) = 0; + + virtual Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) = 0; + + // Initializer + virtual Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) = 0; + virtual Initializer* Initializer__constructor(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, + bool check_outer_scope = false) = 0; + + virtual void Initializer__destructor(Initializer*) = 0; + virtual void Initializer__ToProto(const Initializer&, + ONNX_NAMESPACE::TensorProto& tensor_proto) = 0; + virtual void Initializer__ToProtoWithOrtValue(const Initializer&, + ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) = 0; + virtual int Initializer__data_type(const Initializer&) = 0; + virtual const std::string& Initializer__name(const Initializer&) = 0; + virtual gsl::span Initializer__dims(const Initializer&) = 0; + virtual size_t Initializer__size(const Initializer&) = 0; + // data() template helper + virtual void* Initializer__mutable_data(Initializer&, int data_type) = 0; + virtual const void* Initializer__data(const Initializer&, int data_type) = 0; + virtual void* Initializer__mutable_data_raw(Initializer&) = 0; + virtual const void* Initializer__data_raw(const Initializer&) = 0; + // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 5fadd0b0966e8..6ee9ff5c73c4f 100644 --- 
a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1038,6 +1038,9 @@ struct Graph final { const std::unordered_map& DomainToVersionMap() const noexcept { return g_host->Graph__DomainToVersionMap(this); } Status Resolve() { return g_host->Graph__Resolve(this); } void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor) { return g_host->Graph__AddInitializedTensor(this, tensor); } + Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) { + return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value); + } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); } Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); } @@ -1177,6 +1180,69 @@ struct ConstGraphNodes final { PROVIDER_DISALLOW_ALL(ConstGraphNodes) }; +class Initializer { + public: + Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) { + this_ptr_ = g_host->Initializer__constructor(data_type, name, dims); + } + + Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, + bool check_outer_scope = false) { + this_ptr_ = g_host->Initializer__constructor(graph, tensor_proto, model_path, check_outer_scope); + } + + ~Initializer() { + g_host->Initializer__destructor(this_ptr_); + } + + PROVIDER_DISALLOW_ALL(Initializer); + + void ToProto(ONNX_NAMESPACE::TensorProto& tensor_proto) const { + g_host->Initializer__ToProto(*this_ptr_, tensor_proto); + } + + void ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const { + g_host->Initializer__ToProtoWithOrtValue(*this_ptr_, tensor_proto, ort_value); + } + + int data_type() const { + return g_host->Initializer__data_type(*this_ptr_); + } + + const std::string& name() const { + return g_host->Initializer__name(*this_ptr_); + } + + gsl::span dims() const { + return g_host->Initializer__dims(*this_ptr_); + } + + size_t size() const { + return g_host->Initializer__size(*this_ptr_); + } + + // See definition for the below templates in provider_api.h + template + const T* data() const; + + template + T* data(); + + const void* data_raw() const { + return g_host->Initializer__data_raw(*this_ptr_); + } + + void* mutable_data_raw() { + return g_host->Initializer__mutable_data_raw(*this_ptr_); + } + + private: + Initializer* this_ptr_; +}; + struct OpKernelContext final { template const T& RequiredInput(int index) const; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fc8281ce51a1b..ece7583cfd135 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -46,7 +46,8 @@ using namespace 
ONNX_NAMESPACE; using namespace ::onnxruntime::logging; namespace { // Check if cycle exists in the graph after partitioning -bool FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { +bool FindCycleHelper(size_t i, gsl::span> adjacency_map, gsl::span visited, gsl::span st, + InlinedVector& cycles) { if (!visited[i]) { visited[i] = true; st[i] = true; @@ -263,7 +264,6 @@ struct ShutdownProtobuf { } g_protobuf; namespace onnxruntime { - namespace cuda { template <> void Impl_Cast( @@ -2204,28 +2204,22 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // subgraph's output list std::vector subgraph_output_names; for (const auto& index : group.first) { + // Initializers that refer to a memory location in OrtValue + // can not be handled by TRT (unlike those that are on disk). + // This prevents us from sharing the data and we have to make a copy here. + constexpr const bool load_initializers_inline_true = true; const auto& node = graph.GetNode(node_index[index]); std::vector inputs, outputs; for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto input : node->ImplicitInputDefs()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); @@ -2471,7 +2465,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Create adjacency list size_t graph_size = node_to_index_map.size(); - std::list* adjacency_map = new std::list[graph_size]; + std::vector> adjacency_map(graph_size); for (const auto& node : node_to_outputs_map) { for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) { const auto& loc = input_to_nodes_map.find(*iter); @@ -2486,14 +2480,14 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } // Check cycle in the graph - bool* visited = new bool[graph_size]; - bool* st = new bool[graph_size]; + InlinedVector visited(graph_size); + InlinedVector st(graph_size); for (size_t i = 0; i < graph_size; ++i) { visited[i] = false; st[i] = false; } - std::vector cycles; + InlinedVector cycles; bool has_cycle = false; for (size_t i = 0; i < graph_size; ++i) { if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) { @@ -2514,10 +2508,6 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } } } - - delete[] adjacency_map; - delete[] visited; - delete[] st; } return cycle_detected; } diff --git 
a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index b00c800999f3b..7e02cf7590f66 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -540,7 +540,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { * and save those information in subgraph context data structure. It's useful for building a valid graph and * make Graph::Resolve() happy especially when dealing with nested control-flow op graph. */ - void BuildSubGraphContext(const Graph& build_graph) const; + void BuildSubGraphContext(Graph& build_graph) const; /** * Set outer scope values for subgraphs and add thoes values as top-level graph's inputs if needed. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index b99cb4f52ed59..c123a7d8d4590 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -71,7 +71,7 @@ bool TensorrtExecutionProvider::IsLocalValue(const Graph& graph, * and save those information in subgraph context data structure. It's useful for building a valid graph and * make Graph::Resolve() happy especially when dealing with nested control-flow op graph. */ -void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { +void TensorrtExecutionProvider::BuildSubGraphContext(Graph& graph) const { // Iterate all the nodes and recurse into inner most subgraph first for (int i = 0; i < graph.MaxNodeIndex(); ++i) { auto node = graph.GetNode(i); @@ -79,9 +79,9 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { continue; } - auto subgraph_map = node->GetAttributeNameToSubgraphMap(); + auto& subgraph_map = node->GetAttributeNameToMutableSubgraphMap(); for (auto& entry : subgraph_map) { - const Graph* subgraph = entry.second; + Graph* subgraph = entry.second; BuildSubGraphContext(*subgraph); } } @@ -121,6 +121,7 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { } // This input arg is not the output of another node so must come from either a graph input or an initializer. 
context->inputs_and_initializers[input->Name()] = input; + ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(graph, input->Name())); } } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h index 19cbe4e7f3e48..f4bf2be17ee56 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h @@ -72,7 +72,7 @@ class PadOpBuilder : public BaseOpBuilder { return false; } - Initializer unpacked_tensor(*pads_initializer); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h index e08416bda70d4..b58d272d011b1 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h @@ -63,7 +63,7 @@ class SplitOpBuilder : public BaseOpBuilder { LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; return false; } - Initializer unpacked_tensor(*splits); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *splits); auto split_sizes_ = unpacked_tensor.DataAsSpan(); splits_list.assign(split_sizes_.begin(), split_sizes_.end()); split_provided = true; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc index 5d2f701ceac20..5034cb5a3525c 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc @@ -425,7 +425,7 @@ void GetQuantizationScaleAndZeroPoint( if (!s) { LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); scale = unpacked_tensor.DataAsSpan()[0]; // per channel quantized handling @@ -442,7 +442,7 @@ void GetQuantizationScaleAndZeroPoint( if (!s) { LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); bool is_i8_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT8; // some qdq conv bias is int32 quantized bool is_int32_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT32; diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 2adf8339b4b66..4f9243e592009 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -362,7 +362,7 @@ TensorQuantType GetTensorQuantType(const NodeUnit& node_unit, int32_t io_index, } else if (scales_dim == tensor_shape[0]) { // default 0 for zero-point if zero_dim == 0 if (zero_tensor != nullptr) { - Initializer zp_val(*zero_tensor, node_unit.ModelPath()); + Initializer zp_val(graph_viewer.GetGraph(), *zero_tensor, node_unit.ModelPath()); auto zero_points = zp_val.DataAsSpan(); for (size_t i = 0; i < zp_val.size(); i++) { if (zero_points[i] != 0) { diff --git a/onnxruntime/core/providers/xnnpack/math/softmax.cc b/onnxruntime/core/providers/xnnpack/math/softmax.cc index 6786c29e1f056..c0246c2f0da34 100644 --- 
a/onnxruntime/core/providers/xnnpack/math/softmax.cc +++ b/onnxruntime/core/providers/xnnpack/math/softmax.cc @@ -31,13 +31,13 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph // idealy, QlinearSoftmax or QDQSoftmax will keep this output scale and zp, but we have to handle some // qdq models converted from other framework auto [scale_tensor, zero_tensor] = GetQuantizationZeroPointAndScale(graph, node_unit.Outputs()[0]); - Initializer q_scale(*scale_tensor, node_unit.ModelPath()); + Initializer q_scale(graph.GetGraph(), *scale_tensor, node_unit.ModelPath()); if (fabs(q_scale.DataAsSpan()[0] - 1.0f / 256.0f) > 0.0001f) { break; } if (zero_tensor) { - Initializer q_zp(*zero_tensor, node_unit.ModelPath()); + Initializer q_zp(graph.GetGraph(), *zero_tensor, node_unit.ModelPath()); if (q_zp.DataAsSpan()[0] != 0) { break; } diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 82cf1fc9bb87d..0bb1194643743 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -65,7 +65,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, // check the scale for the second dim is 1 or the size of the second dim matches the input shape. // if not, it is not the C dim as a Resize will not change the number of channels. if (scale_tensor) { - const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); + const Initializer scale_val(graph_viewer.GetGraph(), *scale_tensor, node_unit.ModelPath()); const auto scales = scale_val.DataAsSpan(); if (scales[1] != 1.0F) { break; @@ -90,7 +90,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, } if (size_tensor) { - const Initializer size_val(*size_tensor, node_unit.ModelPath()); + const Initializer size_val(graph_viewer.GetGraph(), *size_tensor, node_unit.ModelPath()); if (size_val.DataAsSpan()[1] != x_shape->dim(1).dim_value()) { break; } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 7dfed4cbe787d..c94f2ea3b4d0b 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -583,21 +583,19 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); - auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); + auto tensor = onnxruntime::Tensor{type, tensor_shape, std::move(alloc_ptr)}; // Deserialize TensorProto into pre-allocated, empty Tensor. // TODO: here the TensorProto loses model path information, so it cannot be an external tensor. status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), std::filesystem::path(), - tensor_proto, *tensorp); + tensor_proto, tensor); if (!status.IsOK()) { return onnxruntime::ToOrtStatus(status); } // Initialize OrtValue from Tensor. 
- auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); auto value = std::make_unique(); - value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - + onnxruntime::Tensor::InitOrtValue(std::move(tensor), *value); *out = value.release(); return nullptr; }); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c34769f43ae1d..e59127b2c2c8c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1300,7 +1300,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool transform_layout_fn = [this](Graph& graph_to_transform, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); ORT_RETURN_IF_ERROR_SESSIONID_( layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn)); @@ -1716,7 +1716,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, [](Graph& graph_to_transform, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); return layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); }; @@ -1749,8 +1749,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1758,7 +1757,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, - optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -1841,7 +1840,7 @@ common::Status InferenceSession::Initialize() { #ifdef DISABLE_EXTERNAL_INITIALIZERS const InitializedTensorSet& initializers = graph.GetAllInitializedTensors(); for (const auto& it : initializers) { - if (utils::HasExternalData(*it.second)) { + if (utils::HasExternalData(*it.second) && !utils::HasExternalDataInMemory(*it.second)) { return common::Status(common::ONNXRUNTIME, common::FAIL, "Initializer tensors with external data is not allowed."); } @@ -2193,8 +2192,7 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors())); + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || 
defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3397,8 +3395,7 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, - GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors()); + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3409,8 +3406,7 @@ common::Status InferenceSession::AddPredefinedTransformers( return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, logger, optimizers_to_disable_, - GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors()); + GetIntraOpThreadPoolToUse()); } }(); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 197fb4320e6bf..93fd17ae1c5c9 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -47,6 +47,7 @@ #include "core/session/provider_bridge_ort.h" #include "core/util/math.h" #include "onnx/shape_inference/implementation.h" +#include "core/optimizer/initializer.h" #ifdef ENABLE_TRAINING #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -1238,6 +1239,15 @@ struct ProviderHostImpl : ProviderHost { execution_provider_name, drop_constant_initializers); } + Status Utils__GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) override { + return onnxruntime::utils::GetTensorProtoWithDataIfInMemory(tensor_proto, result); + } + + bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) override { + return onnxruntime::utils::HasExternalDataInMemory(ten_proto); + } + // Model (wrapped) std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1273,6 +1283,8 @@ struct ProviderHostImpl : ProviderHost { Status Graph__Resolve(Graph* p) override { return p->Resolve(); } void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); } + Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, + const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); } Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) override { return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain); } @@ -1423,6 +1435,75 @@ struct ProviderHostImpl : ProviderHost { } bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept override { return p->empty(); } + NodeArg& GraphUtils__AddInitializerWithExternalData(Graph& graph, + const ONNX_NAMESPACE::TensorProto& new_initializer) override { + return graph_utils::AddInitializerWithExternalData(graph, new_initializer); + } + + void GraphUtils__MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool load_in_memory) override { + graph_utils::MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_in_memory); + } + + // Initializer (wrapped) + Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) 
override { + return new Initializer(data_type, name, dims); + } + + Initializer* Initializer__constructor(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, + bool check_outer_scope) override { + return new Initializer(graph, tensor_proto, model_path, check_outer_scope); + } + void Initializer__destructor(Initializer* p) override { delete p; } + void Initializer__ToProto(const Initializer& initializer, + ONNX_NAMESPACE::TensorProto& tensor_proto) override { + initializer.ToProto(tensor_proto); + } + void Initializer__ToProtoWithOrtValue(const Initializer& initializer, + ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) override { + initializer.ToProtoWithOrtValue(tensor_proto, ort_value); + } + int Initializer__data_type(const Initializer& initializer) override { + return initializer.data_type(); + } + const std::string& Initializer__name(const Initializer& initializer) override { + return initializer.name(); + } + gsl::span Initializer__dims(const Initializer& initializer) override { + return initializer.dims(); + } + size_t Initializer__size(const Initializer& initializer) override { + return initializer.size(); + } + + void* Initializer__mutable_data(Initializer& initializer, int data_type) override { + if (data_type != initializer.data_type()) { + throw std::invalid_argument("Initializer mutable data type mismatch"); + } + return initializer.mutable_data_raw(); + } + + const void* Initializer__data(const Initializer& initializer, int data_type) override { + if (data_type != initializer.data_type()) { + throw std::invalid_argument("Initializer data type mismatch"); + } + return initializer.data_raw(); + } + + void* Initializer__mutable_data_raw(Initializer& initializer) override { + return initializer.mutable_data_raw(); + } + const void* Initializer__data_raw(const Initializer& initializer) override { + return initializer.data_raw(); + } + + Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) override { + return graph_utils::ConvertInMemoryDataToInline(graph, name); + } + // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc index 7986082da06f7..ad8c9c13bc9ca 100644 --- a/onnxruntime/lora/adapter_format_utils.cc +++ b/onnxruntime/lora/adapter_format_utils.cc @@ -127,7 +127,7 @@ struct ReadDataForBigEndian { // If BE, we a allocate memory within the tensor and copy there swapping bytes [[maybe_unused]] static Status CreateOrtValueForBePlatforms(const Parameter& param, const MLDataType elem_type, gsl::span shape, OrtValue& result) { - static const AllocatorPtr cpu_allocator = std::make_shared(); + static const AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); auto src_span = ReinterpretAsSpan( gsl::make_span(param.raw_data()->data(), param.raw_data()->size())); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 1515879f61419..5ac1394caf65a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -321,7 +321,7 @@ inline const PySessionOptions& GetDefaultCPUSessionOptions() { } inline AllocatorPtr& GetAllocator() { - static AllocatorPtr alloc = std::make_shared(); + static AllocatorPtr alloc = CPUAllocator::DefaultInstance(); return alloc; } diff --git 
a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc index 467c5e773589a..2ca04235329ef 100644 --- a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc +++ b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc @@ -116,7 +116,7 @@ ONNX_NAMESPACE::TensorProto CreateInitializer(const std::string& name, } if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto(&tp); + utils::ConvertRawDataInTensorProto(tp); } return tp; @@ -262,7 +262,7 @@ TEST(FlatbufferUtilsTest, ExternalWriteReadWithLoadInitializers) { ONNX_NAMESPACE::TensorProto initializer; ASSERT_STATUS_OK(LoadInitializerOrtFormat(*fbs_tensor, initializer, options, reader)); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto(&initializer); + utils::ConvertRawDataInTensorProto(initializer); } loaded_initializers.emplace_back(std::move(initializer)); // also check that the loaded flatbuffer tensors have accurately written to the external_data_offset field diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index eaebac177ca91..c957f54e51a9c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -19,6 +19,7 @@ using json = nlohmann::json; #include "core/framework/allocation_planner.h" #include "core/session/inference_session.h" #include "core/graph/model.h" +#include "core/graph/graph_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/util/thread_utils.h" @@ -1022,7 +1023,7 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) { tensor.add_float_data(1.0f); tensor.set_data_type(TensorProto_DataType_FLOAT); tensor.set_name("init_data"); - main_graph.AddInitializedTensor(tensor); + graph_utils::AddInitializerWithExternalData(main_graph, tensor); // Main graph's inputs/outputs main_graph.SetInputs({&abs_data_in, &if_in}); @@ -1129,7 +1130,7 @@ TEST_F(PlannerTest, LocationPlanningForInitializersUsedOnDifferentDevicesInMainG tensor.add_int64_data(1); tensor.set_data_type(TensorProto_DataType_INT64); tensor.set_name("init_data"); - main_graph.AddInitializedTensor(tensor); + graph_utils::AddInitializerWithExternalData(main_graph, tensor); // Main graph's inputs/outputs main_graph.SetInputs({&abs_data_in, &if_in}); @@ -1554,7 +1555,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 64 * 3 * 7 * 7; ++i) conv_0_weight_tensor.add_float_data(0.234f); conv_0_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_0_weight_tensor.set_name("conv_0_weight"); - main_graph.AddInitializedTensor(conv_0_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_0_weight_tensor); ONNX_NAMESPACE::TensorProto conv_1_weight_tensor; conv_1_weight_tensor.add_dims(64L); @@ -1564,7 +1565,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { conv_1_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); for (int i = 0; i < 64 * 64; ++i) conv_1_weight_tensor.add_float_data(1.017f); conv_1_weight_tensor.set_name("conv_1_weight"); - main_graph.AddInitializedTensor(conv_1_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_1_weight_tensor); ONNX_NAMESPACE::TensorProto conv_2_weight_tensor; conv_2_weight_tensor.add_dims(64L); @@ -1574,7 +1575,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 64 * 64 * 3 * 3; ++i) conv_2_weight_tensor.add_float_data(2.317f); 
conv_2_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_2_weight_tensor.set_name("conv_2_weight"); - main_graph.AddInitializedTensor(conv_2_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_2_weight_tensor); ONNX_NAMESPACE::TensorProto conv_3_weight_tensor; conv_3_weight_tensor.add_dims(256L); @@ -1584,7 +1585,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 256 * 64; ++i) conv_3_weight_tensor.add_float_data(1.256f); conv_3_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_3_weight_tensor.set_name("conv_3_weight"); - main_graph.AddInitializedTensor(conv_3_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_3_weight_tensor); ONNX_NAMESPACE::TensorProto conv_4_weight_tensor; conv_4_weight_tensor.add_dims(256L); @@ -1594,7 +1595,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 256 * 64; ++i) conv_4_weight_tensor.add_float_data(1.913f); conv_4_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_4_weight_tensor.set_name("conv_4_weight"); - main_graph.AddInitializedTensor(conv_4_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_4_weight_tensor); auto& conv_0_weight = main_graph.GetOrCreateNodeArg("conv_0_weight", &conv_0_weight_type); auto& conv_1_weight = main_graph.GetOrCreateNodeArg("conv_1_weight", &conv_1_weight_type); @@ -1607,35 +1608,35 @@ TEST_F(PlannerTest, ParaPlanCreation) { conv_0_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_0_bias_tensor.set_name("conv_0_bias"); for (int i = 0; i < 64; ++i) conv_0_bias_tensor.add_float_data(1.123f); - main_graph.AddInitializedTensor(conv_0_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_0_bias_tensor); ONNX_NAMESPACE::TensorProto conv_1_bias_tensor; conv_1_bias_tensor.add_dims(64L); for (int i = 0; i < 64; ++i) conv_1_bias_tensor.add_float_data(2.234f); conv_1_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_1_bias_tensor.set_name("conv_1_bias"); - main_graph.AddInitializedTensor(conv_1_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_1_bias_tensor); ONNX_NAMESPACE::TensorProto conv_2_bias_tensor; conv_2_bias_tensor.add_dims(64L); for (int i = 0; i < 64; ++i) conv_2_bias_tensor.add_float_data(0.121f); conv_2_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_2_bias_tensor.set_name("conv_2_bias"); - main_graph.AddInitializedTensor(conv_2_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_2_bias_tensor); ONNX_NAMESPACE::TensorProto conv_3_bias_tensor; conv_3_bias_tensor.add_dims(256L); for (int i = 0; i < 256; ++i) conv_3_bias_tensor.add_float_data(1.201f); conv_3_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_3_bias_tensor.set_name("conv_3_bias"); - main_graph.AddInitializedTensor(conv_3_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_3_bias_tensor); ONNX_NAMESPACE::TensorProto conv_4_bias_tensor; conv_4_bias_tensor.add_dims(256L); for (int i = 0; i < 256; ++i) conv_4_bias_tensor.add_float_data(0.897f); conv_4_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_4_bias_tensor.set_name("conv_4_bias"); - main_graph.AddInitializedTensor(conv_4_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_4_bias_tensor); auto& conv_0_bias = main_graph.GetOrCreateNodeArg("conv_0_bias", &conv_0_bias_type); auto& conv_1_bias = main_graph.GetOrCreateNodeArg("conv_1_bias", &conv_1_bias_type); diff --git 
a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index e28327941dda4..b86f3efeefafd 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -16,6 +16,7 @@ #include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" +#include "core/graph/graph_utils.h" #include "core/graph/model.h" #include "core/graph/op.h" #include "core/providers/cpu/math/element_wise_ops.h" @@ -66,10 +67,7 @@ static common::Status LoadInferenceSessionFromModel(FenceCudaTestInferenceSessio tensor_proto.set_data_type(PROTO_DATATYPE); \ for (auto v : value) tensor_proto.PROTO_ADD_DATA(v); \ tensor_proto.set_name(name); \ - graph.AddInitializedTensor(tensor_proto); \ - TypeProto type_proto; \ - type_proto.mutable_tensor_type()->set_elem_type(PROTO_DATATYPE); \ - return graph.GetOrCreateNodeArg(name, &type_proto); \ + return graph_utils::AddInitializerWithExternalData(graph, tensor_proto); \ } CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data) diff --git a/onnxruntime/test/framework/endian_test.cc b/onnxruntime/test/framework/endian_test.cc index 7b8f56bd97073..694967c70d136 100644 --- a/onnxruntime/test/framework/endian_test.cc +++ b/onnxruntime/test/framework/endian_test.cc @@ -1,10 +1,10 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - #include "core/framework/endian.h" #include "core/framework/endian_utils.h" +#include "core/graph/onnx_protobuf.h" // For TensorProto +#include "core/framework/tensorprotoutils.h" // For ConvertRawDataInTensorProto #include +#include // For std::byte #include "gtest/gtest.h" @@ -47,6 +47,327 @@ TEST(EndianTest, SwapByteOrderCopy) { } } +// Test fixture for SwapByteOrderInplace tests +class SwapByteOrderInplaceTest : public ::testing::Test {}; + +TEST_F(SwapByteOrderInplaceTest, ElementSize1) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + std::vector expected_data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(1, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize2_SingleElement) { + std::vector data = {std::byte{0x01}, std::byte{0x02}}; + std::vector expected_data = {std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize2_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, std::byte{0x05}, std::byte{0x06}}; + std::vector expected_data = { + std::byte{0x02}, std::byte{0x01}, std::byte{0x04}, std::byte{0x03}, std::byte{0x06}, std::byte{0x05}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize4_SingleElement) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + std::vector expected_data = { + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, 
ElementSize4_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}}; + std::vector expected_data = { + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize8_SingleElement) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}}; + std::vector expected_data = { + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}, + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(8, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize8_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}, + std::byte{0x11}, std::byte{0x12}, std::byte{0x13}, std::byte{0x14}, + std::byte{0x15}, std::byte{0x16}, std::byte{0x17}, std::byte{0x18}}; + std::vector expected_data = { + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}, + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x18}, std::byte{0x17}, std::byte{0x16}, std::byte{0x15}, + std::byte{0x14}, std::byte{0x13}, std::byte{0x12}, std::byte{0x11}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(8, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, EmptyBuffer) { + std::vector data = {}; + std::vector expected_data = {}; + gsl::span data_span = gsl::make_span(data); + + // Should not crash or throw for valid element sizes, e.g., 2 or 4 + // The ORT_ENFORCE checks will pass as 0 % element_size == 0 + // The loop for swapping will not execute. 
+ utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize3_OddElementSize) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}, std::byte{0x05}, std::byte{0x06}}; + std::vector expected_data = { + std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x06}, std::byte{0x05}, std::byte{0x04}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(3, data_span); + EXPECT_EQ(data, expected_data); +} + +// Test fixture for ConvertRawDataInTensorProto tests +class ConvertRawDataInTensorProtoTest : public ::testing::Test { + protected: + // Helper function to set up a TensorProto with float data + void SetupFloatTensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (float value : values) { + tensor.add_float_data(value); + } + } + + // Helper function to set up a TensorProto with int32 data + void SetupInt32Tensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + for (int32_t value : values) { + tensor.add_int32_data(value); + } + } + + // Helper function to set up a TensorProto with int16 data (stored in int32 container) + void SetupInt16Tensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT16); + for (int16_t value : values) { + tensor.add_int32_data(value); + } + } + + // Helper function to set up a TensorProto with raw data + template + void SetupRawDataTensor(ONNX_NAMESPACE::TensorProto& tensor, ONNX_NAMESPACE::TensorProto_DataType data_type, + const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(data_type); + tensor.set_raw_data(values.data(), values.size() * sizeof(T)); + } + + // Helper to compare float data before and after conversion + void CompareFloatData(const ONNX_NAMESPACE::TensorProto& tensor, const std::vector& expected_values) { + ASSERT_EQ(tensor.float_data_size(), static_cast(expected_values.size())); + for (int i = 0; i < tensor.float_data_size(); i++) { + // We swap bytes so the actual value might change if we're converting endianness + // But a double swap should restore the original value + if constexpr (endian::native == endian::little) { + EXPECT_EQ(tensor.float_data(i), expected_values[i]); + } else { + // Just verify the value is different after one swap on big-endian + // We can't predict the exact value without manual byte swapping + if (expected_values[i] != 0) { // Skip zero values as they're invariant to byte swapping + EXPECT_NE(tensor.float_data(i), expected_values[i]); + } + } + } + } + + // Helper to compare int32 data before and after conversion + void CompareInt32Data(const ONNX_NAMESPACE::TensorProto& tensor, const std::vector& expected_values) { + ASSERT_EQ(tensor.int32_data_size(), static_cast(expected_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + // Same logic as float comparison + if constexpr (endian::native == endian::little) { + EXPECT_EQ(tensor.int32_data(i), expected_values[i]); + } else { + if (expected_values[i] != 0) { + EXPECT_NE(tensor.int32_data(i), expected_values[i]); + } + } + } + } +}; + +TEST_F(ConvertRawDataInTensorProtoTest, FloatData) { + 
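  // The ConvertRawDataInTensorProto tests below lean on an involution property: applying the
  // byte-order conversion twice must reproduce the original payload. A minimal sketch of that
  // check (illustrative only, mirroring the assertions further down):
  //   onnxruntime::utils::ConvertRawDataInTensorProto(tensor);
  //   onnxruntime::utils::ConvertRawDataInTensorProto(tensor);
  //   EXPECT_EQ(tensor.float_data(0), original_values[0]);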
ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1.0f, 2.0f, 3.0f, 4.0f}; + SetupFloatTensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.float_data_size(); i++) { + original_values.push_back(tensor.float_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + CompareFloatData(tensor, original_values); +} + +TEST_F(ConvertRawDataInTensorProtoTest, Int32Data) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1, 2, 3, 4}; + SetupInt32Tensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + CompareInt32Data(tensor, original_values); +} + +TEST_F(ConvertRawDataInTensorProtoTest, Int16Data) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1, 2, 3, 4}; + SetupInt16Tensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // When we swap bytes on int16 values stored in int32 containers, the test should pass + // on both little-endian and big-endian systems + ASSERT_EQ(tensor.int32_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + EXPECT_EQ(tensor.int32_data(i), original_values[i]); + } +} + +TEST_F(ConvertRawDataInTensorProtoTest, RawFloatData) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1.0f, 2.0f, 3.0f, 4.0f}; + SetupRawDataTensor(tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, values); + + // Save original raw data + std::string original_raw_data = tensor.raw_data(); + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original bytes + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + EXPECT_EQ(tensor.raw_data(), original_raw_data); +} + +TEST_F(ConvertRawDataInTensorProtoTest, UInt8NoConversion) { + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + std::vector values = {1, 2, 3, 4}; + for (auto val : values) { + tensor.add_int32_data(val); + } + + // Save original data + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert - for 1-byte elements, no conversion should happen + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Verify no change occurred + ASSERT_EQ(tensor.int32_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + 
EXPECT_EQ(tensor.int32_data(i), original_values[i]); + } +} + +TEST_F(ConvertRawDataInTensorProtoTest, DoubleConversionAndRestore) { + // Test with double values + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + std::vector values = {1.1, 2.2, 3.3, 4.4}; + for (auto val : values) { + tensor.add_double_data(val); + } + + // Save original data + std::vector original_values; + for (int i = 0; i < tensor.double_data_size(); i++) { + original_values.push_back(tensor.double_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert again - this should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Verify restored values + ASSERT_EQ(tensor.double_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.double_data_size(); i++) { + EXPECT_EQ(tensor.double_data(i), original_values[i]); + } +} + } // namespace test } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 76399743c97f8..6ad21fa9f5cf5 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -275,7 +275,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { graph, session_state.GetMutableFuncMgr(), [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); }, @@ -319,7 +319,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { GTEST_SKIP() << "CPU allocator does not support arena usage."; } - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); // Part 1: Feature turned ON (i.e.) 
allocate from non-arena memory { std::basic_ostringstream oss; diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 7bd6b47f52b7d..43de3a945526c 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -706,7 +706,7 @@ struct InsertIndices { std::vector indices(indices_data.cbegin(), indices_data.cend()); indices_tp.mutable_raw_data()->assign(reinterpret_cast(indices.data()), indices.size() * sizeof(T)); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&indices_tp); + utils::ConvertRawDataInTensorProto(indices_tp); } } } diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index 73bf351b6c556..f423f9a542387 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -80,7 +80,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_allocator) { // save it to a buffer ASSERT_TRUE(p.SerializeToString(&s)); // deserialize it - AllocatorPtr tmp_allocator = std::make_shared(); + AllocatorPtr tmp_allocator = CPUAllocator::DefaultInstance(); OrtValue value; ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, tmp_allocator, value)); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index f6b7bdb1a001c..e2b54950e7b24 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1698,7 +1698,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { ONNX_NAMESPACE::TensorProto bad_name = original; bad_name.set_name("invalid"); - status = graph.ReplaceInitializedTensor(std::move(bad_name)); + status = graph.ReplaceInitializedTensor(bad_name, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1706,7 +1706,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { ONNX_NAMESPACE::TensorProto bad_type = original; bad_type.set_data_type(TensorProto_DataType_FLOAT16); - status = graph.ReplaceInitializedTensor(std::move(bad_type)); + status = graph.ReplaceInitializedTensor(bad_type, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1716,7 +1716,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { bad_dims.add_dims(2); bad_dims.add_dims(1); - status = graph.ReplaceInitializedTensor(std::move(bad_dims)); + status = graph.ReplaceInitializedTensor(bad_dims, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1726,26 +1726,39 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { valid_replacement.add_int32_data(3); valid_replacement.add_int32_data(4); - status = graph.ReplaceInitializedTensor(valid_replacement); + status = graph.ReplaceInitializedTensor(valid_replacement, OrtValue()); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto tensor_data_matches = [](const ONNX_NAMESPACE::TensorProto& a, const ONNX_NAMESPACE::TensorProto& b) { - if (a.int32_data_size() != b.int32_data_size()) return false; - for (int i = 0; i < a.int32_data_size(); ++i) { - if (a.int32_data(i) != b.int32_data(i)) return false; + auto tensor_data_matches = [](const Graph& graph, const ONNX_NAMESPACE::TensorProto& a, + const ONNX_NAMESPACE::TensorProto& b) -> bool { + // For simplicity. We do not want to deal with external and raw data combinations. 
+ Tensor tensor_a; + EXPECT_TRUE(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), a, tensor_a).IsOK()); + Tensor tensor_b; + EXPECT_TRUE(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), b, tensor_b).IsOK()); + + EXPECT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_a.GetElementType()); + + if (tensor_a.GetElementType() != tensor_b.GetElementType()) { + return false; + } + if (tensor_a.Shape() != tensor_b.Shape()) { + return false; } - return true; + const auto span_a = tensor_a.DataAsSpan(); + const auto span_b = tensor_b.DataAsSpan(); + return std::equal(span_a.begin(), span_a.end(), span_b.begin()); }; // check retrieved tensor const ONNX_NAMESPACE::TensorProto* result; ASSERT_TRUE(graph.GetInitializedTensor(initializer_name, result)); - ASSERT_TRUE(tensor_data_matches(*result, valid_replacement)); + ASSERT_TRUE(tensor_data_matches(graph, *result, valid_replacement)); // check GraphProto content const ONNX_NAMESPACE::GraphProto graph_proto = graph.ToGraphProto(); ASSERT_EQ(graph_proto.initializer_size(), 1); - ASSERT_TRUE(tensor_data_matches(graph_proto.initializer(0), valid_replacement)); + ASSERT_TRUE(tensor_data_matches(graph, graph_proto.initializer(0), valid_replacement)); } } @@ -1822,13 +1835,13 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { const TensorProto* with_data = nullptr; ASSERT_TRUE(graph.GetInitializedTensor(initializer_name, with_data)); - // No longer has external data if (with_data) { - ASSERT_FALSE(utils::HasExternalData(*with_data)); + // This proto still has external data, but now it points to the OrtValue. + ASSERT_TRUE(utils::HasExternalData(*with_data)); const auto& original_tensor = ort_value.Get(); - Tensor replaced_tensor(original_tensor.DataType(), data_shape, std::make_shared()); - ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path, *with_data, - replaced_tensor)); + Tensor replaced_tensor; + ASSERT_STATUS_OK(utils::CreateTensorFromTensorProto(Env::Default(), tensor_data_dir_path, *with_data, + replaced_tensor)); ASSERT_EQ(original_tensor.GetElementType(), replaced_tensor.GetElementType()); const auto original_span = original_tensor.DataAsSpan(); const auto replaced_span = replaced_tensor.DataAsSpan(); @@ -2124,6 +2137,187 @@ TEST_F(GraphTest, SubgraphOutputIsOuterScopeValue) { ::testing::ContainsRegex("Subgraph output \\(.*\\) is an outer scope value being returned directly.")); } +static void CreateIntializerWithDataInMemory(const std::string& name, const AllocatorPtr& allocator, int64_t size, + TensorProto& tensor_proto, OrtValue& ort_value) { + TensorShape shape({size}); + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, allocator, ort_value); + float v = 0; + auto* data = ort_value.GetMutable()->MutableData(); + for (int64_t i = 0; i < size; ++i) { + *data++ = v++; + } + + tensor_proto = utils::TensorToTensorProto(ort_value.Get(), name, true); +} + +TEST(GraphGetOrtValueInitializerTest, ReturnsOrtValueForExistingInitializer) { + Model model("TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {}, + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + // Create a simple TensorProto initializer + const std::string name = "init1"; + auto allocator = CPUAllocator::DefaultInstance(); + constexpr const int64_t kTensorSize = 256; + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + 
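  // Optional sanity check (a sketch, not part of the original test): for a 256-element float
  // tensor the helper above is assumed to produce a proto whose payload stays in the OrtValue
  // buffer, so the following should hold before the proto is registered with the graph:
  //   ASSERT_TRUE(utils::HasExternalDataInMemory(tensor_proto));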
+  ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value));
+
+  // Test retrieval
+  OrtValue retrieved;
+  EXPECT_TRUE(graph.GetOrtValueInitializer(name, retrieved, false));
+  const Tensor& t = retrieved.Get<Tensor>();
+  EXPECT_EQ(t.Shape().Size(), kTensorSize);
+}
+
+TEST(GraphGetOrtValueInitializerTest, ReturnsFalseForNonExistentInitializer) {
+  Model model("TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {},
+              DefaultLoggingManager().DefaultLogger());
+  Graph& graph = model.MainGraph();
+
+  OrtValue retrieved;
+  EXPECT_FALSE(graph.GetOrtValueInitializer("does_not_exist", retrieved, false));
+}
+
+namespace {
+// Casting only, do not add members
+class NodeWrapper : public Node {
+ public:
+  Node::Definitions& MutableDefinitions() {
+    return Node::MutableDefinitions();
+  }
+};
+}  // namespace
+
+TEST(GraphGetOrtValueInitializerTest, ReturnsOrtValueFromOuterScope) {
+  // Create parent graph with initializer
+  Model parent_model("ParentModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {},
+                     DefaultLoggingManager().DefaultLogger());
+  Graph& parent_graph = parent_model.MainGraph();
+
+  const std::string outer_init_name = "outer_init";
+  auto allocator = CPUAllocator::DefaultInstance();
+  constexpr const int64_t kTensorSize = 256;
+  TensorProto tensor_proto;
+  OrtValue ort_value;
+  CreateIntializerWithDataInMemory(outer_init_name, allocator, kTensorSize, tensor_proto, ort_value);
+
+  ASSERT_STATUS_OK(parent_graph.AddInitializedOrtValue(tensor_proto, ort_value));
+
+  // Create a node in parent graph that will be the parent node for the subgraph
+  TypeProto tensor_type;
+  tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  auto& input_arg = parent_graph.GetOrCreateNodeArg("node_input", &tensor_type);
+  auto& output_arg = parent_graph.GetOrCreateNodeArg("node_output", &tensor_type);
+  NodeArg* inputs[] = {&input_arg};
+  NodeArg* outputs[] = {&output_arg};
+
+  // Create parent node with a subgraph attribute
+  auto& parent_node = parent_graph.AddNode("parent_node", "If", "parent node with subgraph", inputs, outputs);
+  // Add the initializer name to the parent node's implicit input defs
+  NodeArg* outer_init_nodearg = parent_graph.GetNodeArg(outer_init_name);
+  ASSERT_NE(outer_init_nodearg, nullptr);
+  {
+    // Test hack to tweak an internal structure.
+ auto& node_wrapper = static_cast(parent_node); + node_wrapper.MutableDefinitions().implicit_input_defs.push_back(outer_init_nodearg); + } + + // Create subgraph + GraphProto subgraph_proto; + subgraph_proto.set_name("Subgraph"); + Graph subgraph(parent_model, &subgraph_proto, parent_graph.DomainToVersionMap(), parent_model.IrVersion(), + nullptr, &parent_graph, &parent_node, DefaultLoggingManager().DefaultLogger(), false); + + // Test retrieval from outer scope + OrtValue retrieved; + EXPECT_TRUE(subgraph.GetOrtValueInitializer("outer_init", retrieved, true)); + const Tensor& t = retrieved.Get(); + EXPECT_EQ(t.Shape().Size(), kTensorSize); +} + +TEST_F(GraphTest, AddInitializedOrtValueWithExternalData) { + Model model("TestAddInitializedOrtValue", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string external_data_init = "external_data_init"; + auto allocator = CPUAllocator::DefaultInstance(); + constexpr const int64_t kTensorSize = 256; + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(external_data_init, allocator, kTensorSize, tensor_proto, ort_value); + + // Test adding the initialized OrtValue with external data reference + ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // Verify the initializer was added correctly + OrtValue retrieved_value; + ASSERT_TRUE(graph.GetOrtValueInitializer(external_data_init, retrieved_value, false)); + + // Verify the tensor data + const Tensor& retrieved_tensor = retrieved_value.Get(); + ASSERT_EQ(retrieved_tensor.Shape().Size(), kTensorSize); + ASSERT_EQ(retrieved_tensor.DataType(), DataTypeImpl::GetType()); + + // Verify the TensorProto was also added and has external data location + const TensorProto* retrieved_proto = nullptr; + ASSERT_TRUE(graph.GetInitializedTensor(external_data_init, retrieved_proto)); + ASSERT_NE(retrieved_proto, nullptr); + ASSERT_EQ(retrieved_proto->name(), external_data_init); + ASSERT_TRUE(utils::HasExternalDataInMemory(tensor_proto)); +} + +TEST_F(GraphTest, AddInitializedOrtValueMismatch) { + Model model("TestAddInitializedOrtValue_Mismatch", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string name = "init"; + constexpr const int64_t kTensorSize = 256; + auto allocator = CPUAllocator::DefaultInstance(); + TensorProto tensor_proto; + OrtValue ort_value; + TensorShape shape({kTensorSize}); + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + OrtValue ort_value_diff; + // Now try to create a value that has a different data type + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, allocator, ort_value_diff); + Status status = graph.AddInitializedOrtValue(tensor_proto, ort_value_diff); + ASSERT_FALSE(status.IsOK()); + + // Create OrtValue with different shape [2] + TensorShape diff_shape({2}); + Tensor::InitOrtValue(DataTypeImpl::GetType(), diff_shape, allocator, ort_value_diff); + + // Fails on shape mismatch + status = graph.AddInitializedOrtValue(tensor_proto, ort_value_diff); + ASSERT_FALSE(status.IsOK()); +} + +TEST_F(GraphTest, AddInitializedOrtValueDuplicate) { + Model model("TestAddInitializedOrtValue_Duplicate", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string name = "init"; + constexpr const int64_t kTensorSize = 256; + auto allocator = CPUAllocator::DefaultInstance(); + 
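  // CPUAllocator::DefaultInstance() is used here, as elsewhere in this PR, as a drop-in
  // replacement for std::make_shared<CPUAllocator>(). A minimal allocation sketch under that
  // assumption, reusing calls that already appear in the tests above:
  //   AllocatorPtr alloc = CPUAllocator::DefaultInstance();
  //   OrtValue value;
  //   Tensor::InitOrtValue(DataTypeImpl::GetType<float>(), TensorShape({4}), alloc, value);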
TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + // Add the first initializer successfully + ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // try again + Status status = graph.AddInitializedOrtValue(tensor_proto, ort_value); + ASSERT_FALSE(status.IsOK()); +} + #ifdef ENABLE_TRAINING TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_Recompute) { diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc index df36135bd3017..65a0b2b93b4d2 100644 --- a/onnxruntime/test/onnx/microbenchmark/activation.cc +++ b/onnxruntime/test/onnx/microbenchmark/activation.cc @@ -24,7 +24,7 @@ extern OrtEnv* env; class Allocs : public IExecutionProvider { private: - std::shared_ptr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); public: Allocs() : IExecutionProvider("fake") {}; diff --git a/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc b/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc index ed1dc808871ec..ec81830156381 100644 --- a/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc +++ b/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc @@ -12,38 +12,35 @@ using namespace onnxruntime; template void SetRandom(Tensor& input) { - int64_t size = input.Shape().Size(); std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution distr(0, 1); - T* data = input.MutableData(); - for (int64_t i = 0; i != size; ++i) { - data[i] = distr(gen); - } + auto span = input.MutableDataAsSpan(); + std::generate(span.begin(), span.end(), [&]() { return distr(gen); }); } static void BM_BatchNormOldEigen(benchmark::State& state) { - std::shared_ptr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); const int64_t batch_size = state.range(0); const TensorShape shape = {batch_size, 64, 75, 75}; using T = float; - Tensor* X = new Tensor(DataTypeImpl::GetType(), shape, alloc); - SetRandom(*X); - const TensorShape& x_shape = X->Shape(); - Tensor* Y = new Tensor(DataTypeImpl::GetType(), shape, alloc); - Tensor* scale = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*scale); - Tensor* mean = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*mean); + Tensor X(DataTypeImpl::GetType(), shape, alloc); + SetRandom(X); + const TensorShape& x_shape = X.Shape(); + Tensor Y(DataTypeImpl::GetType(), shape, alloc); + Tensor scale(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(scale); + Tensor mean(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(mean); - Tensor* B = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*B); + Tensor B(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(B); - Tensor* var = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*var); + Tensor var(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(var); bool is_spatial_ = true; double epsilon_ = 1e-5; @@ -60,26 +57,26 @@ static void BM_BatchNormOldEigen(benchmark::State& state) { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; for (auto _ : state) { - ConstEigenVectorArrayMap scale_arr(scale->Data(), is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap bias_arr(B->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap scale_arr(scale.Data(), is_spatial_ ? 
C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap bias_arr(B.Data(), is_spatial_ ? C : sample_size_incl_all_channels); // Regardless of training or testing, we will apply the estimated mean // and standard deviation to the input. For testing, they are // specified directly by the input, and for training, they are computed // by the op. Eigen::Array inv_std(is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap var_arr(var->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap var_arr(var.Data(), is_spatial_ ? C : sample_size_incl_all_channels); inv_std = (var_arr + epsilon_).sqrt().inverse(); - ConstEigenVectorArrayMap mean_arr(mean->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap mean_arr(mean.Data(), is_spatial_ ? C : sample_size_incl_all_channels); // We can fuse the output computation as follows: // ((x - est_mean) * (inv_var) * scale + bias // to // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) Eigen::Array new_scale = inv_std * scale_arr; Eigen::Array new_bias = bias_arr - mean_arr * new_scale; - EigenArrayMap Y_arr(Y->MutableData(), is_spatial_ ? sample_size : sample_size_incl_all_channels, + EigenArrayMap Y_arr(Y.MutableData(), is_spatial_ ? sample_size : sample_size_incl_all_channels, is_spatial_ ? N * C : N); - ConstEigenArrayMap X_arr(X->Data(), is_spatial_ ? sample_size : sample_size_incl_all_channels, + ConstEigenArrayMap X_arr(X.Data(), is_spatial_ ? sample_size : sample_size_incl_all_channels, is_spatial_ ? N * C : N); if (is_spatial_) { // spatial == 1 for (size_t nc = 0; nc < N * C; ++nc) { diff --git a/onnxruntime/test/onnx/microbenchmark/main.cc b/onnxruntime/test/onnx/microbenchmark/main.cc index 70faa6f11989d..24d02caa96aa1 100644 --- a/onnxruntime/test/onnx/microbenchmark/main.cc +++ b/onnxruntime/test/onnx/microbenchmark/main.cc @@ -27,7 +27,7 @@ OrtEnv* env = nullptr; using namespace onnxruntime; static void BM_CPUAllocator(benchmark::State& state) { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); const size_t len = state.range(0); for (auto _ : state) { void* p = cpu_allocator->Alloc(len); diff --git a/onnxruntime/test/onnx/microbenchmark/resize.cc b/onnxruntime/test/onnx/microbenchmark/resize.cc index 020680c12b8f5..2ccf588dfef54 100644 --- a/onnxruntime/test/onnx/microbenchmark/resize.cc +++ b/onnxruntime/test/onnx/microbenchmark/resize.cc @@ -30,7 +30,7 @@ static void BM_NhwcUpsampleBilinear(benchmark::State& state) { const T* const XdataBase = GenerateArrayWithRandomValue(XdataBaseSize, std::numeric_limits::min(), std::numeric_limits::max()); const size_t YdataBaseSize = batch_size * num_channels * output_height * output_width; T* const YdataBase = (T*)aligned_alloc(sizeof(T) * YdataBaseSize, 64); - AllocatorPtr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); const GetOriginalCoordinateFunc& get_original_coordinate = [](float x_resized, float x_scale, float, float, float, float) { return x_resized / x_scale; diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index 92c4b5bc88fe7..0b4ec1bab192a 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -3,21 +3,22 @@ #include "tensorprotoutils.h" -#include #include #include +#include #include -#include "mem_buffer.h" +#include "callback.h" +#include "core/common/make_string.h" #include 
"core/common/safeint.h" #include "core/common/status.h" -#include "core/common/make_string.h" +#include "core/framework/allocator.h" #include "core/framework/data_types.h" #include "core/framework/endian.h" -#include "core/framework/allocator.h" -#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/endian_utils.h" #include "core/graph/onnx_protobuf.h" -#include "callback.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "mem_buffer.h" struct OrtStatus { OrtErrorCode code; @@ -69,21 +70,13 @@ static void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_length ORT_CXX_API_THROW(MakeString("UnpackTensor: the pre-allocated size does not match the raw data size, expected ", expected_size_in_bytes, ", got ", raw_data_length), OrtErrorCode::ORT_FAIL); - memcpy(p_data, raw_data, raw_data_length); - if constexpr (endian::native != endian::little) { - /* Convert Endianness */ - char* bytes = reinterpret_cast(p_data); - size_t element_size = sizeof(T); - size_t num_elements = raw_data_length / element_size; - - for (size_t i = 0; i < num_elements; ++i) { - char* start_byte = bytes + i * element_size; - char* end_byte = start_byte + element_size - 1; - /* keep swapping */ - for (size_t count = 0; count < element_size / 2; ++count) { - std::swap(*start_byte++, *end_byte--); - } - } + + /* Convert Endianness */ + if constexpr (endian::native != endian::little && sizeof(T) > 1) { + utils::SwapByteOrderCopy(sizeof(T), gsl::make_span(reinterpret_cast(raw_data), raw_data_length), + gsl::make_span(reinterpret_cast(p_data), raw_data_length)); + } else { + memcpy(p_data, raw_data, raw_data_length); } } diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index c1c98cbf0ff5b..5bccf5ab1ac0d 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -73,7 +73,7 @@ struct NonTensorTypeConverter { // Create and populate Tensor TensorShape shape({1}); - std::shared_ptr allocator = std::make_shared(); + std::shared_ptr allocator = CPUAllocator::DefaultInstance(); std::unique_ptr tp(new Tensor(DataTypeImpl::GetType(), shape, allocator)); *tp->MutableData() = input.Get().str_; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 35d50cbec678f..a6a5004a2e2e2 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1431,7 +1431,7 @@ TEST_F(GraphTransformationTests, FusePadWithConv) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -1484,7 +1484,7 @@ TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -1532,7 +1532,7 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPool) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { 
const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -3804,11 +3804,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 4U); + EXPECT_EQ(initializer.size(), 4U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], 12); @@ -3840,11 +3840,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionOneConstTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], 768); @@ -3875,11 +3875,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalNodeIsOutput) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -3911,11 +3911,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 5U); + EXPECT_EQ(initializer.size(), 5U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 128); EXPECT_EQ(val[2], 0); @@ -3970,11 +3970,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerSubgrap const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), 
ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 1); EXPECT_EQ(val[1], 200); EXPECT_EQ(val[2], -1); @@ -4003,11 +4003,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerApplies const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 1); EXPECT_EQ(val[1], 200); EXPECT_EQ(val[2], 0); @@ -4073,11 +4073,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphMultipleOutputs) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4107,11 +4107,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4141,11 +4141,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4211,11 +4211,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphWithDiv) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), 
ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4247,11 +4247,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphWithMul) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4281,11 +4281,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionDistilBertTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 4U); + EXPECT_EQ(initializer.size(), 4U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], -1); EXPECT_EQ(val[2], 2); @@ -4476,8 +4476,8 @@ static void ValidateAttention(Graph& graph) { ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer->size(), 192U); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); + EXPECT_EQ(initializer.size(), 192U); // Validate two rows (2x24 items) for sanity check. 
std::vector expected_value = { @@ -4531,7 +4531,7 @@ static void ValidateAttention(Graph& graph) { -0.0101165771484375, -0.00490570068359375}; - const float* data = initializer->data(); + const float* data = initializer.data(); for (size_t i = 0; i < expected_value.size(); i++) { EXPECT_EQ(data[i], static_cast(expected_value[i])); } @@ -4540,8 +4540,8 @@ static void ValidateAttention(Graph& graph) { ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer2 = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer2->size(), 24U); + Initializer initializer2(graph, *tensor_proto, graph.ModelPath()); + EXPECT_EQ(initializer2.size(), 24U); std::vector expected_value2 = { -0.23681640625, @@ -4569,7 +4569,7 @@ static void ValidateAttention(Graph& graph) { 0.0535888671875, 0.0091094970703125}; - const float* data2 = initializer2->data(); + const float* data2 = initializer2.data(); for (size_t i = 0; i < expected_value2.size(); i++) { EXPECT_EQ(data2[i], static_cast(expected_value2[i])); } @@ -7011,7 +7011,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatOrHalfTypedInitialize if (entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 1U); float float_const_value; if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { @@ -7134,7 +7134,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatOrHalfTypedInitiali if (entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 8U); for (int i = 0; i < 8; ++i) { float float_const_value; @@ -7240,7 +7240,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ for (const auto& entry : initialized_tensor_set) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; if (entry.first.compare(mul_initializer->Name()) == 0) { TEST_RETURN_IF_NOT(float_const.size() == 1U); TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); @@ -7369,7 +7369,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitial for (const auto& entry : initialized_tensor_set) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 8U); if (entry.first.compare(mul_initializer->Name()) == 0) { TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); @@ -7507,13 +7507,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareIntMaxOrFloatInfinityIniti for (const auto& entry : initialized_tensor_set) { if 
(entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer int64_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer int64_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(int64_const.size() == 1U); int64_t int64_const_value = *(int64_const.data()); TEST_RETURN_IF_NOT(int64_const_value == std::numeric_limits::max()); } else if (entry.first.compare(sub_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 1U); float float_const_value = *(float_const.data()); TEST_RETURN_IF_NOT(float_const_value == std::numeric_limits::infinity()); @@ -7606,13 +7606,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { for (const auto& entry : initialized_tensor_set) { if (entry.first.compare("y_scale") == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer int64_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer int64_const{graph, *tensor_proto, graph.ModelPath()}; ASSERT_TRUE(int64_const.size() == 1U); float float_const_value = *(int64_const.data()); ASSERT_TRUE(float_const_value == 1); } else { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer uint8_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer uint8_const{graph, *tensor_proto, graph.ModelPath()}; ASSERT_TRUE(uint8_const.size() == 1U); uint8_t uint8_const_value = *(uint8_const.data()); ASSERT_TRUE(uint8_const_value == static_cast(1)); @@ -7688,7 +7688,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } @@ -7828,7 +7828,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) { const NodeArg& input_arg = *(node.InputDefs()[1]); const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(1 == static_cast(*(init_const.data()))); } @@ -7881,7 +7881,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) { const NodeArg& input_arg = *(node.InputDefs()[1]); const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } @@ -8051,7 +8051,7 @@ 
TEST_F(GraphTransformationTests, GatherToSliceFusion) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32); TEST_RETURN_IF_NOT(2 == *(init_const.data())); } @@ -8090,7 +8090,7 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 7b700922f4306..627a68f38b585 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -882,19 +882,18 @@ static void EmbedLayerNormFusionFormat5(const std::basic_string& file EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); // Validate the position embedding input. + double expected_value[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; for (const Node& node : graph.Nodes()) { if (node.OpType() == "EmbedLayerNormalization") { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[3]->Name()); ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer->size(), 12U); + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; + EXPECT_EQ(initializer.size(), std::size(expected_value)); - std::vector expected_value = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; - - const float* data = initializer->data(); - for (size_t i = 0; i < expected_value.size(); i++) { + const float* data = initializer.data(); + for (size_t i = 0; i < std::size(expected_value); i++) { EXPECT_EQ(data[i], static_cast(expected_value[i])); } } diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index a2c9881ab5169..2449f7c962e83 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -137,7 +137,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) { std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_y); @@ -169,7 +169,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = 
CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -198,7 +198,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -338,7 +338,7 @@ TEST(CoreMLExecutionProviderTest, TestModelCache) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 1af7bdea68b67..a96c8d05ee64f 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -52,7 +52,7 @@ TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) { std::vector dims_mul_y = {3, 2, 2}; std::vector values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; @@ -80,7 +80,7 @@ TEST(NnapiExecutionProviderTest, SigmoidSupportedInputRankTest) { std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -107,7 +107,7 @@ TEST(NnapiExecutionProviderTest, DynamicGraphInputTest) { std::vector dims_mul_x = {1, 1, 4, 4}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_mul_x, values_mul_x, &ml_value_x); @@ -138,7 +138,7 @@ TEST(NnapiExecutionProviderTest, InternalUint8SupportTest) { std::vector dims_x = {1, 1, 1, 3}; std::vector values_x = {0.0f, 256.0f, 512.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_x, values_x, &ml_value_x); NameMLValMap feeds; @@ -195,7 +195,7 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { std::vector dims_mul_x = {1, 1, 3, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; @@ -522,7 +522,7 @@ TEST(NnapiExecutionProviderTest, SharedInitializersDoNotGetSkipped) { constexpr auto* model_file_name = ORT_TSTR("testdata/clip_div_shared_initializer.onnx"); #if defined(__ANDROID__) - 
AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); std::vector x_dims{3, 2}; std::vector x_values(3.0f, 3 * 2); diff --git a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc index 63a672615c27b..fffd081a692c0 100644 --- a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc +++ b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc @@ -60,7 +60,7 @@ TEST(RknpuExecutionProviderTest, FunctionTest) { std::vector dims_mul_x = {1, 1, 3, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); OrtValue ml_value_x; CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index ddbcfd4931835..553059932db90 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -662,13 +662,10 @@ TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) { params.trt_engine_cache_enable = 1; params.trt_op_types_to_exclude = "MaxPool"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - auto status = session_object.Load(model_name); - ASSERT_TRUE(status.IsOK()); - status = session_object.Initialize(); - ASSERT_TRUE(status.IsOK()); - status = session_object.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider))); + ASSERT_STATUS_OK(session_object.Load(model_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); std::vector engine_files; engine_files = GetCachesByType("./", ".engine"); diff --git a/orttraining/orttraining/core/graph/graph_augmenter.cc b/orttraining/orttraining/core/graph/graph_augmenter.cc index 19b200efcf6bb..1fde22b32451b 100644 --- a/orttraining/orttraining/core/graph/graph_augmenter.cc +++ b/orttraining/orttraining/core/graph/graph_augmenter.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/graph/graph_utils.h" #include "orttraining/core/graph/graph_augmenter.h" #include "core/common/logging/logging.h" diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index a4143e7c817fd..888664dda8806 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -529,7 +529,7 @@ Status TransformGraphForMixedPrecision(Graph& graph, // Add new FP16/BFloat16 initializers to the graph for (const auto& kv : mixed_precision_initializers) { const ONNX_NAMESPACE::TensorProto* tensor_proto = kv.second; - Initializer initializer(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); ONNX_NAMESPACE::TensorProto weight_tensor_proto = mixed_precision_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ? 
initializer.ToFP16(kv.first) : initializer.ToBFloat16(kv.first); graph.AddInitializedTensor(weight_tensor_proto); } diff --git a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc index 3f4034a9db222..8c9072614a4b0 100644 --- a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc @@ -125,8 +125,8 @@ static std::vector AddPartitionsForParameter( ORT_ENFORCE(dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); // Find the initializer partition to read out. - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* initializer_data = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* initializer_data = initializer.data(); // Create new initializer tensor proto. ONNX_NAMESPACE::TensorProto initializer_partition; diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index ff220fcb067b8..90be9e24d3dd4 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -121,7 +121,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); - auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + auto& axes_input = graph_utils::AddInitializerWithExternalData(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 25e16304789b6..55286379fd273 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -171,8 +171,8 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node LOGS_DEFAULT(WARNING) << "Checkpointing is not currently supported for graphs requiring partitioning of weight with stride > 1"; } - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* a_weight = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* a_weight = initializer.data(); std::string new_initializer_name = original_name + "_column_rank_" + std::to_string(horizontal_parallel_rank_); @@ -306,8 +306,8 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg << horizontal_parallel_size_ << ", not supported currently."; return false; } - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* a_weight = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* a_weight = initializer.data(); std::string new_initializer_name = original_name + "_row_rank_" + std::to_string(horizontal_parallel_rank_); @@ -453,15 +453,15 @@ Status MegatronTransformer::TransformGPT2MLP(Graph& graph, bool& modified, return 
skip_status; } - NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition); + NodeArg& a_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg); updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()}); - NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition); + NodeArg& a_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg); updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()}); - NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition); + NodeArg& b_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, b_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg); updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()}); @@ -600,15 +600,15 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, return skip_status; } - NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition); + NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_weight_initializer_partition); graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg); updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()}); - NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition); + NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_bias_initializer_partition); graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg); updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()}); - NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition); + NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wo_weight_initializer_partition); graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg); updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()}); @@ -787,15 +787,15 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, // The number of the values should be more than 2, and the 3rd value should be divisible by parallel size, // i.e., the attention head number should be divisible by parallel size. - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - if (init_const->size() != 3 && init_const->size() != 4) { + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + if (init_const.size() != 3 && init_const.size() != 4) { is_reshape_valid = false; break; } - const int64_t* val = init_const->data(); + const int64_t* val = init_const.data(); if (val[2] % horizontal_parallel_size_ != 0) { - LOGS_DEFAULT(WARNING) << (init_const->size() == 3 ? "Hidden size " : "Number of attention heads ") << val[2] + LOGS_DEFAULT(WARNING) << (init_const.size() == 3 ? 
"Hidden size " : "Number of attention heads ") << val[2] << " is not divisible by horizontal_parallel_size_ " << horizontal_parallel_size_ << ", not supported currently."; is_reshape_valid = false; @@ -814,15 +814,15 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, [](Node* node_ptr) { return node_ptr != nullptr; }); // Replace by the partition weights. - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg); updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()}); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg); updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()}); - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -836,9 +836,9 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, const ONNX_NAMESPACE::TensorProto* tensor; graph.GetInitializedTensor(shape_arg->Name(), tensor); auto data_type = tensor->data_type(); - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - const int64_t* val = init_const->data(); - int64_t size = init_const->size(); + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + const int64_t* val = init_const.data(); + int64_t size = init_const.size(); ONNX_NAMESPACE::TensorProto tensor_partition; tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); tensor_partition.set_data_type(data_type); @@ -849,7 +849,7 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[2] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } @@ -1068,12 +1068,12 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, } // The number of the values should be more than idx, and the idx'th value should be divisible by parallel size, // i.e., the attention head number should be divisible by parallel size. 
- auto init_const = std::make_unique(*tensor, graph.ModelPath()); - if (init_const->size() <= idx) { + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + if (init_const.size() <= idx) { is_reshape_valid = false; break; } - const int64_t* val = init_const->data(); + const int64_t* val = init_const.data(); if (val[idx] % horizontal_parallel_size_ != 0) { LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx] << " is not divisible by horizontal_parallel_size_ " @@ -1130,7 +1130,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, size_t i = 0; for (auto trans_ptr : weight_transpose_node_ptrs) { auto weight_name = trans_ptr->MutableInputDefs()[0]->Name(); - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg); graph.RemoveInitializedTensor(weight_name); updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()}); @@ -1139,14 +1139,14 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, i = 0; for (auto add_ptr : bias_add_node_ptrs) { auto bias_name = add_ptr->MutableInputDefs()[1]->Name(); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg); graph.RemoveInitializedTensor(bias_name); updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()}); i++; } - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg); graph.RemoveInitializedTensor(dense_weight_arg->Name()); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -1162,11 +1162,12 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, int64_t idx = x.second; auto shape_arg = node_ptr->MutableInputDefs()[1]; const ONNX_NAMESPACE::TensorProto* tensor; - graph.GetInitializedTensor(shape_arg->Name(), tensor); + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(shape_arg->Name(), tensor), + "Expecting initializer present: ", shape_arg->Name()); auto data_type = tensor->data_type(); - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - const int64_t* val = init_const->data(); - int64_t size = init_const->size(); + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + const int64_t* val = init_const.data(); + int64_t size = init_const.size(); ONNX_NAMESPACE::TensorProto tensor_partition; tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); tensor_partition.set_data_type(data_type); @@ -1177,7 +1178,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[idx] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + 
NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } diff --git a/orttraining/orttraining/core/optimizer/qdq_fusion.cc b/orttraining/orttraining/core/optimizer/qdq_fusion.cc index fc9a6d213f794..4a5bdc1f8fcd2 100644 --- a/orttraining/orttraining/core/optimizer/qdq_fusion.cc +++ b/orttraining/orttraining/core/optimizer/qdq_fusion.cc @@ -21,22 +21,31 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { ONNX_NAMESPACE::TensorProto zero_point_tensor_float; if (quant_node_input_defs.size() >= 3) { // The quantize node has the zero point input - auto zero_point_tensor_int = graph.GetInitializer(quant_node_input_defs[2]->Name(), true); - ORT_ENFORCE(zero_point_tensor_int != nullptr, "Expected: zero point initializer with name ", + constexpr const bool check_outer_scope_true = true; + const auto* zero_point_tensor_proto = graph.GetInitializer(quant_node_input_defs[2]->Name(), check_outer_scope_true); + ORT_ENFORCE(zero_point_tensor_proto != nullptr, "Expected: zero point initializer with name ", quant_node_input_defs[2]->Name(), " to be present in the graph. Actual: not found."); - zero_point_type = zero_point_tensor_int->data_type(); - zero_point_tensor_float.set_name(graph.GenerateNodeArgName(zero_point_tensor_int->name())); + Initializer zero_point_tensor_int(graph, *zero_point_tensor_proto, graph.ModelPath(), check_outer_scope_true); + zero_point_type = zero_point_tensor_int.data_type(); + zero_point_tensor_float.set_name(graph.GenerateNodeArgName(zero_point_tensor_int.name())); zero_point_tensor_float.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - for (const auto val : zero_point_tensor_int->int32_data()) { - zero_point_tensor_float.add_float_data(static_cast(val)); + if (zero_point_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + for (const auto val : zero_point_tensor_int.DataAsSpan()) { + zero_point_tensor_float.add_float_data(static_cast(val)); + } + } else { + for (const auto val : zero_point_tensor_int.DataAsSpan()) { + zero_point_tensor_float.add_float_data(static_cast(val)); + } } - for (const auto& dim : zero_point_tensor_int->dims()) { + for (const auto dim : zero_point_tensor_int.dims()) { zero_point_tensor_float.add_dims(dim); } - graph.RemoveInitializedTensor(zero_point_tensor_int->name()); + graph.RemoveInitializedTensor(zero_point_tensor_int.name()); // Since the quantize node has the zero point initializer input, replace it - graph_utils::ReplaceNodeInput(quantize_node, 2, graph_utils::AddInitializer(graph, zero_point_tensor_float)); + graph_utils::ReplaceNodeInput(quantize_node, 2, + graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); } else { // The quantize node does not have the zero point optional input. // Create the zero point initializer to be 0. 
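// Editor's note: a compact sketch of the zero-point conversion in the hunk above. The integer
// zero-point initializer is read through the Graph-aware Initializer (so OrtValue-backed data
// is handled), copied element-wise into a float TensorProto, and re-added with
// graph_utils::AddInitializerWithExternalData. Only the uint8 branch is shown; the PR handles
// the signed case the same way, and names here are illustrative.
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"

namespace onnxruntime {
static NodeArg& MakeFloatZeroPoint(Graph& graph, const ONNX_NAMESPACE::TensorProto& zp_proto) {
  Initializer zp_int{graph, zp_proto, graph.ModelPath()};

  ONNX_NAMESPACE::TensorProto zp_float;
  zp_float.set_name(graph.GenerateNodeArgName(zp_int.name()));
  zp_float.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
  for (const auto val : zp_int.DataAsSpan<uint8_t>()) {
    zp_float.add_float_data(static_cast<float>(val));
  }
  for (const auto dim : zp_int.dims()) {
    zp_float.add_dims(dim);
  }

  // Drop the integer initializer; the caller re-points the Quantize node's input
  // at the returned NodeArg (see graph_utils::ReplaceNodeInput above).
  graph.RemoveInitializedTensor(zp_int.name());
  return graph_utils::AddInitializerWithExternalData(graph, zp_float);
}
}  // namespace onnxruntime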
@@ -45,7 +54,8 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { zero_point_tensor_float.add_float_data(0.0f); // Since the input did not exist, add the newly created initializer as an input - graph_utils::AddNodeInput(quantize_node, 2, graph_utils::AddInitializer(graph, zero_point_tensor_float)); + graph_utils::AddNodeInput(quantize_node, 2, + graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); } return zero_point_type; diff --git a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc index e719a21118028..e6319952dfae7 100644 --- a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc +++ b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc @@ -74,7 +74,7 @@ bool IsScaleOperator(Graph& graph, Node& node, return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { const MLFloat16* val = init_const.data(); diff --git a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc index 84bf715c7c85a..8c9c12ceb4497 100644 --- a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc +++ b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc @@ -83,7 +83,7 @@ Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ ignore_index_initializer_proto.set_name(graph.GenerateNodeArgName("sce_grad_ignore_index")); ignore_index_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); ignore_index_initializer_proto.add_int64_data(static_cast(-1)); - new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializer(graph, ignore_index_initializer_proto)); + new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializerWithExternalData(graph, ignore_index_initializer_proto)); } new_scegrad_node_inputs.emplace_back(bias_def); if (!p_reshape) { diff --git a/orttraining/orttraining/core/optimizer/triton_fusion.cc b/orttraining/orttraining/core/optimizer/triton_fusion.cc index f2cb3c2b8c6db..026f39712ffe6 100644 --- a/orttraining/orttraining/core/optimizer/triton_fusion.cc +++ b/orttraining/orttraining/core/optimizer/triton_fusion.cc @@ -64,7 +64,7 @@ bool CheckAxes(const Graph& graph, const Node& node, bool single_axis, const std if (!axes_const) { return false; } - Initializer initializer{*axes_const, graph.ModelPath()}; + Initializer initializer{graph, *axes_const, graph.ModelPath()}; axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); } else { diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index b03f1b1eadb3b..650ed69578210 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -4,9 +4,11 @@ #include "orttraining/core/session/training_session.h" #include "core/framework/data_transfer_utils.h" +#include "core/graph/graph_utils.h" #include "core/graph/model.h" #include "core/graph/model_saving_options.h" #include "core/session/IOBinding.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/providers/cpu/controlflow/utils.h" #include 
"core/providers/cpu/cpu_execution_provider.h" @@ -977,22 +979,18 @@ static Status UpdateWeightsBeforeSaving( if (!graph.GetInitializedTensor(name_and_ml_value.first, old_tensor_proto)) { continue; } - ONNX_NAMESPACE::TensorProto new_tensor_proto = *old_tensor_proto; - if (new_tensor_proto.has_raw_data()) { - auto* const raw_data = new_tensor_proto.mutable_raw_data(); - auto dst_span = gsl::make_span(&(*raw_data)[0], raw_data->size()); - ORT_RETURN_IF_ERROR(CopyTensorDataToByteSpan( - data_transfer_manager, src_tensor, cpu_alloc_info, dst_span)); - } else { - ORT_ENFORCE(new_tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT); - auto* const float_data = new_tensor_proto.mutable_float_data(); - auto dst_span = gsl::make_span(float_data->mutable_data(), float_data->size()); - ORT_RETURN_IF_ERROR(CopyTensorDataToSpan( - data_transfer_manager, src_tensor, cpu_alloc_info, dst_span)); - } + + Initializer initializer{graph, *old_tensor_proto, graph.ModelPath()}; + const auto chars_span = ReinterpretAsSpan(initializer.MutableDataAsByteSpan()); + ORT_RETURN_IF_ERROR(CopyTensorDataToByteSpan( + data_transfer_manager, src_tensor, cpu_alloc_info, chars_span)); + + TensorProto new_tensor_proto; + OrtValue ort_value; + initializer.ToProtoWithOrtValue(new_tensor_proto, ort_value); // Replace the TensorProto in the model. - ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(new_tensor_proto)); + ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(new_tensor_proto, ort_value)); } return Status::OK(); } diff --git a/orttraining/orttraining/test/framework/checkpointing_test.cc b/orttraining/orttraining/test/framework/checkpointing_test.cc index a7ee776b9bc39..615a3e86a8a2f 100644 --- a/orttraining/orttraining/test/framework/checkpointing_test.cc +++ b/orttraining/orttraining/test/framework/checkpointing_test.cc @@ -52,7 +52,7 @@ void CompareOrtValuesToTensorProtoValues( ASSERT_EQ(name_to_ort_value.size(), name_to_tensor_proto.size()); NameMLValMap name_to_ort_value_from_tensor_proto{}; - AllocatorPtr tmp_allocator = std::make_shared(); + AllocatorPtr tmp_allocator = CPUAllocator::DefaultInstance(); for (const auto& name_and_tensor_proto : name_to_tensor_proto) { const auto& name = name_and_tensor_proto.first; diff --git a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc index 1cb9518a06193..9461f751aecdd 100644 --- a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc +++ b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc @@ -183,10 +183,13 @@ Status GetDataAndShapeFromTensorProto(const Graph& graph, const NodeArg* input_a } const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - graph.GetInitializedTensor(input_arg->Name(), tensor_proto); - auto init_const = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* data_float = init_const->data(); - data.insert(data.end(), data_float, data_float + element_count); + if (!graph.GetInitializedTensor(input_arg->Name(), tensor_proto)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to get tensor proto for ", input_arg->Name()); + } + auto init_const = Initializer{graph, *tensor_proto, graph.ModelPath()}; + auto data_float = init_const.DataAsSpan(); + data.insert(data.end(), data_float.begin(), data_float.end()); return Status::OK(); } diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 
60708b05626c5..9e12fdcd2bb53 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -108,6 +108,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, ORT_ENFORCE(!inference_graph.IsInitializedTensor(named_parameter_it->first), "The eval graph is invalid. Expected model parameter ", named_parameter_it->first, " to be a graph input, not a graph initializer."); + inference_graph.AddInitializedTensor(utils::CopyTensorToTensorProto( named_parameter_it->second->Data().Get<Tensor>(), named_parameter_it->first, data_transfer_manager));
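// Editor's note: a minimal sketch of the updated weight-saving flow from training_session.cc
// above: the initializer is materialized through the Graph-aware Initializer, its bytes are
// overwritten from the trained tensor, and the result is handed back to the Graph as a
// TensorProto plus OrtValue pair via the new two-argument ReplaceInitializedTensor. Assumes
// the trained data is already in CPU memory and matches the initializer's byte size; names
// are illustrative and error handling is trimmed.
#include <algorithm>
#include <string>
#include "core/common/common.h"
#include "core/common/span_utils.h"
#include "core/graph/graph.h"
#include "core/optimizer/initializer.h"

namespace onnxruntime {
static common::Status WriteBackWeight(Graph& graph, const std::string& name,
                                      gsl::span<const char> trained_bytes) {
  const ONNX_NAMESPACE::TensorProto* old_proto = nullptr;
  ORT_RETURN_IF_NOT(graph.GetInitializedTensor(name, old_proto),
                    "Expecting initializer present: ", name);

  Initializer initializer{graph, *old_proto, graph.ModelPath()};
  auto dst = ReinterpretAsSpan<char>(initializer.MutableDataAsByteSpan());
  ORT_RETURN_IF_NOT(dst.size() == trained_bytes.size(), "Size mismatch for initializer ", name);
  std::copy(trained_bytes.begin(), trained_bytes.end(), dst.begin());

  // Produce a TensorProto that references data now owned by an OrtValue,
  // then swap it into the graph.
  ONNX_NAMESPACE::TensorProto new_proto;
  OrtValue ort_value;
  initializer.ToProtoWithOrtValue(new_proto, ort_value);
  return graph.ReplaceInitializedTensor(new_proto, ort_value);
}
}  // namespace onnxruntime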