Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ inline std::vector<std::string> Session::GetInputNames() const {
size_t node_count = GetInputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
char* tmp = GetInputName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetInputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand All @@ -47,9 +46,8 @@ inline std::vector<std::string> Session::GetOutputNames() const {
size_t node_count = GetOutputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
char* tmp = GetOutputName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetOutputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand All @@ -59,9 +57,8 @@ inline std::vector<std::string> Session::GetOverridableInitializerNames() const
size_t init_count = GetOverridableInitializerCount();
std::vector<std::string> out(init_count);
for (size_t i = 0; i < init_count; i++) {
char* tmp = GetOverridableInitializerName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetOverridableInitializerNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand Down
185 changes: 177 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ struct TypeInfo;
struct Value;
struct ModelMetadata;

namespace detail {
// Light functor to release memory with OrtAllocator
struct AllocatedFree {
OrtAllocator* allocator_;
explicit AllocatedFree(OrtAllocator* allocator)
: allocator_(allocator) {}
void operator()(void* ptr) const { if(ptr) allocator_->Free(allocator_, ptr); }
};
} // namespace detail

/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
* and release them at the end of the scope. The lifespan of the given allocator
* must eclipse the lifespan of AllocatedStringPtr instance
*/
using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down Expand Up @@ -385,13 +401,108 @@ struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API

char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
/** \deprecated use GetProducerNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \brief Returns a copy of the producer name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \deprecated use GetGraphNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \brief Returns a copy of the graph name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \deprecated use GetDomainAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \brief Returns a copy of the domain name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \deprecated use GetDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \brief Returns a copy of the description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \deprecated use GetGraphDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \brief Returns a copy of the graph description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \deprecated use GetCustomMetadataMapKeysAllocated()
* [[deprecated]]
* This interface produces multiple pointers that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

/** \deprecated use LookupCustomMetadataMapAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

/** \brief Looks up a value by a key in the Custom Metadata map
*
* \param zero terminated string key to lookup
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* maybe nullptr if key is not found.
*
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
};

Expand Down Expand Up @@ -436,12 +547,70 @@ struct Session : Base<OrtSession> {
size_t GetOutputCount() const; ///< Returns the number of model outputs
size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden

char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName
/** \deprecated use GetInputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName

/** \brief Returns a copy of input name at the specified index.
*
* \param index must less than the value returned by GetInputCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;

/** \deprecated use GetOutputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName

/** \brief Returns a copy of output name at then specified index.
*
* \param index must less than the value returned by GetOutputCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;

/** \deprecated use GetOverridableInitializerNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata

/** \brief Returns a copy of the overridable initializer name at then specified index.
*
* \param index must less than the value returned by GetOverridableInitializerCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName

/** \deprecated use EndProfilingAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling

/** \brief Returns a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata

TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
Expand Down
Loading