Skip to content
Merged
70 changes: 58 additions & 12 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1286,9 +1286,15 @@ Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");

auto sys_ctx_handle_deleter = [&qnn_sys_interface = qnn_sys_interface_](void* handle) {
qnn_sys_interface.systemContextFree(reinterpret_cast<QnnSystemContext_Handle_t>(handle));
};

std::unique_ptr<void, decltype(sys_ctx_handle_deleter)> sys_ctx_handle_uptr(sys_ctx_handle, sys_ctx_handle_deleter);

const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle_uptr.get(),
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
static_cast<void*>(buffer),
buffer_length,
&binary_info,
Expand Down Expand Up @@ -1350,17 +1356,23 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
ORT_RETURN_IF(result, "Failed to get valid function pointer.");

void* bin_buffer = nullptr;
bool use_file_mapping = file_mapped_weights_enabled_;
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
// A nonzero buffer length implies an embedded context
if (file_mapped_weights_enabled_ && buffer_length == 0) {
ORT_RETURN_IF(!file_mapper_, "Attemping to use File Mapping feature but file_mapper_ is uninitialized");
if (use_file_mapping && buffer_length == 0) {
ORT_RETURN_IF(!file_mapper_, "Attempting to use File Mapping feature but file_mapper_ is uninitialized");

ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_length));

ORT_RETURN_IF(buffer_length == 0, "Context bin has a size of 0 bytes: ", context_bin_filepath);
ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &bin_buffer));

Comment thread
quic-calvnguy marked this conversation as resolved.
} else {
if (use_file_mapping) {
use_file_mapping = false;
LOGS(*logger_, WARNING) << "Node " << node_name << " is using an embedded cache."
<< " Disabling file mapping for this node.";
}
ORT_RETURN_IF(buffer == nullptr, "Attempting to load QNN context from buffer but buffer is null");
bin_buffer = static_cast<void*>(buffer);
}
Expand All @@ -1372,9 +1384,15 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");

auto sys_ctx_handle_deleter = [&qnn_sys_interface = qnn_sys_interface_](void* handle) {
qnn_sys_interface.systemContextFree(reinterpret_cast<QnnSystemContext_Handle_t>(handle));
};

std::unique_ptr<void, decltype(sys_ctx_handle_deleter)> sys_ctx_handle_uptr(sys_ctx_handle, sys_ctx_handle_deleter);

const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle_uptr.get(),
bin_buffer,
buffer_length,
&binary_info,
Expand All @@ -1385,28 +1403,60 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
// Convert binary info to graph info:
// retrieve the QNN graph info (graph count and graph array) from the binary info.
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");

#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
Qnn_Version_t blob_version = {0, 0, 0};
#endif

uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
graph_count = binary_info->contextBinaryInfoV1.numGraphs;
graphs_info = binary_info->contextBinaryInfoV1.graphs;
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
blob_version = binary_info->contextBinaryInfoV1.contextBlobVersion;
#endif
}
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 15) // starts from 2.22
else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
graph_count = binary_info->contextBinaryInfoV2.numGraphs;
graphs_info = binary_info->contextBinaryInfoV2.graphs;
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
blob_version = binary_info->contextBinaryInfoV2.contextBlobVersion;
#endif
}
#endif
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // starts from 2.28
else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
graph_count = binary_info->contextBinaryInfoV3.numGraphs;
graphs_info = binary_info->contextBinaryInfoV3.graphs;
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
blob_version = binary_info->contextBinaryInfoV3.contextBlobVersion;
#endif
}
#endif
else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
}

#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
if (use_file_mapping) {
if (blob_version.major == 0 && blob_version.minor == 0 && blob_version.patch == 0) {
LOGS(*logger_, WARNING) << "Failed to retrieve context binary version for " << node_name << ". Disabling file mapping.";
Comment thread
quic-calvnguy marked this conversation as resolved.
Outdated
use_file_mapping = false;

Comment thread
quic-calvnguy marked this conversation as resolved.
Outdated
// Cannot use contextCreateFromBinaryWithCallback() unless context bin version is >= 3.3.3
} else if (blob_version.major < 3 ||
(blob_version.major == 3 && blob_version.minor < 3) ||
(blob_version.major == 3 && blob_version.minor == 3 && blob_version.patch < 3)) {
LOGS(*logger_, WARNING) << "Context binary of " << node_name << " is v" << std::to_string(blob_version.major) << "."
<< std::to_string(blob_version.minor) << "." << std::to_string(blob_version.patch)
<< ". File mapping is only supported for versions >= 3.3.3. Disabling file mapping for this node.";
use_file_mapping = false;
}
}
#endif

ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;

Expand Down Expand Up @@ -1452,7 +1502,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
Qnn_ContextBinaryCallback_t callbacks;
if (file_mapped_weights_enabled_ && file_mapper_) {
if (use_file_mapping && file_mapper_) {
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinaryWithCallback,
"Invalid function pointer for contextCreateFromBinaryWithCallback.");

Expand All @@ -1479,7 +1529,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
std::vector<char> backup_buffer;
if (file_mapped_weights_enabled_ && file_mapper_) {
if (use_file_mapping && file_mapper_) {
rt = qnn_interface_.contextCreateFromBinaryWithCallback(backend_handle_,
device_handle_,
context_configs,
Expand All @@ -1501,9 +1551,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
bin_buffer = static_cast<void*>(backup_buffer.data());
}
}
#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE

if (!file_mapped_weights_enabled_ || rt != QNN_SUCCESS) {
#endif
Comment thread
quic-calvnguy marked this conversation as resolved.
if (!use_file_mapping || rt != QNN_SUCCESS) {
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
context_configs,
Expand Down Expand Up @@ -1544,10 +1593,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
}
}

qnn_sys_interface_.systemContextFree(sys_ctx_handle);
sys_ctx_handle = nullptr;
context_created_ = true;

LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed.";
return Status::OK();
}
Expand Down
Loading