Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class QnnTensorWrapper {
dimensions_.assign(shape_data, shape_data + shape_rank);
SetQnnTensorDim(qnn_tensor_, dimensions_);

SetQnnTensorMemType(qnn_tensor_, QNN_TENSORMEMTYPE_RAW);
SetQnnTensorMemType(qnn_tensor_, GetQnnTensorMemType(qnn_tensor));

return Status::OK();
}
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,9 @@ Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensor
ORT_RETURN_IF_ERROR(UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor));
}

Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW;
if (true == model_settings_.htp_shared_memory && (IsGraphInput(tensor_name) || IsGraphOutput(tensor_name))) {
mem_type = QNN_TENSORMEMTYPE_MEMHANDLE;
}
tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type,
std::move(tensor_info.quant_param), std::move(tensor_info.shape),
std::move(unpacked_tensor), mem_type);
std::move(unpacked_tensor));
return Status::OK();
}

Expand Down Expand Up @@ -105,6 +101,12 @@ bool QnnModelWrapper::AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper) {
return true;
}

Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW;
Comment thread
derdeljan-msft marked this conversation as resolved.
Outdated
if (true == model_settings_.htp_shared_memory && (IsGraphInput(tensor_name) || IsGraphOutput(tensor_name))) {
mem_type = QNN_TENSORMEMTYPE_MEMHANDLE;
}
SetQnnTensorMemType(tensor_wrapper.GetQnnTensor(), mem_type);

const Qnn_TensorType_t& qnn_tensor_type = tensor_wrapper.GetTensorType();
// save created tensors for later lookup to populate graph node construction
model_tensors_map_.emplace(tensor_name, std::move(tensor_wrapper));
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,11 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
enable_htp_shared_mem_allocator_ = ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map);
if (enable_htp_shared_mem_allocator_) {
// Initialize rpcmem_library_.
// This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available.
rpcmem_library_ = std::make_shared<qnn::RpcMemLibrary>();
// This library is only needed at inference time (for the shared memory allocator). During the context
// generation stage there is no need to load it, as no allocations will be made.
if (!context_cache_enabled_) {
rpcmem_library_ = std::make_shared<qnn::RpcMemLibrary>();
}
model_settings_.htp_shared_memory = enable_htp_shared_mem_allocator_;
}

Expand Down
Loading