Merge remote-tracking branch 'origin/main' into snnn/vcpkg
snnn committed Jan 18, 2025
2 parents b261d70 + a05203b commit 4fa0412
Showing 84 changed files with 583 additions and 380 deletions.
6 changes: 3 additions & 3 deletions docs/python/_common/onnx_sphinx.py
@@ -282,7 +282,7 @@ def get_domain_list():
"""
Returns the list of available domains.
"""
-return list(sorted(set(map(lambda s: s.domain, get_all_schemas_with_history()))))
+return sorted({s.domain for s in get_all_schemas_with_history()})


def get_operator_schemas(op_name, version=None, domain=None):
@@ -779,9 +779,9 @@ def render(self, indent=""):
name = op["name"]
dom = self.domain.replace(".", "-")
table_dom.append(f" * - :ref:`l-onnx-doc{dom}-{name}`")
-versions = list(reversed(sorted((k, v) for k, v in op["links"].items() if isinstance(k, int))))
+versions = sorted(((k, v) for k, v in op["links"].items() if isinstance(k, int)), reverse=True)
col1 = ", ".join(f":ref:`{k} <{v}>`" for k, v in versions)
-diffs = list(reversed(sorted((k, v) for k, v in op["links"].items() if isinstance(k, tuple))))
+diffs = sorted(((k, v) for k, v in op["links"].items() if isinstance(k, tuple)), reverse=True)
col2 = ", ".join(f":ref:`{k[1]}/{k[0]} <{v}>`" for k, v in diffs)
table_dom.append(f" - {col1}")
table_dom.append(f" - {col2}")
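
The two rewrites above are pure modernizations. A minimal, runnable sketch of both idioms side by side (the Schema type and sample data here are invented for illustration):

```python
from collections import namedtuple

# Hypothetical stand-in for the ONNX schema objects used above.
Schema = namedtuple("Schema", ["domain"])
schemas = [Schema("ai.onnx"), Schema(""), Schema("ai.onnx.ml"), Schema("")]

# Old: list(sorted(set(map(lambda s: s.domain, schemas))))
# New: a set comprehension; sorted() already returns a list, so the
# outer list() was redundant.
domains = sorted({s.domain for s in schemas})
assert domains == ["", "ai.onnx", "ai.onnx.ml"]

# Old: list(reversed(sorted(pairs)))
# New: sorted(..., reverse=True) sorts descending in one pass with no
# intermediate list or reversed() wrapper.
links = {14: "ref-new", 7: "ref-old", ("x", 1): "ref-diff"}
versions = sorted(((k, v) for k, v in links.items() if isinstance(k, int)), reverse=True)
assert versions == [(14, "ref-new"), (7, "ref-old")]
```
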
@@ -45,7 +45,7 @@ std::vector<std::vector<NodeIndex>> IdenticalChildrenConsolidation::DivideIdenti
const Graph& graph,
Node* node,
const string_view& op) {
-unordered_map<string_view, std::vector<NodeIndex>> identical_children_map;
+unordered_map<std::string, std::vector<NodeIndex>> identical_children_map;
for (auto i = node->OutputEdgesBegin(); i != node->OutputEdgesEnd(); ++i) {
if (i->GetNode().OpType() == op) {
identical_children_map[IdentityBuilder(graph, i->GetNode())].push_back(i->GetNode().Index());
@@ -125,4 +125,4 @@ std::string IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph,

return identity.str();
}
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
342 changes: 281 additions & 61 deletions onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc

Large diffs are not rendered by default.

103 changes: 0 additions & 103 deletions onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -22,11 +22,6 @@ class SimpleOpBuilder : public BaseOpBuilder {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SimpleOpBuilder);

protected:
-Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
-const NodeUnit& node_unit,
-const logging::Logger& logger,
-std::vector<std::string>& input_names,
-bool do_op_validation) const override ORT_MUST_USE_RESULT;
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
@@ -53,91 +48,6 @@ class SimpleOpBuilder : public BaseOpBuilder {
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
};

-// Move to qnn_utils if it's re-usable
-Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
-const std::string& convert_input_name,
-const std::string& convert_output_name,
-Qnn_DataType_t input_qnn_data_type,
-Qnn_DataType_t output_qnn_data_type,
-int32_t input_offset,
-float input_scale,
-const std::vector<uint32_t>& output_shape,
-bool do_op_validation) {
-// Assume input is already handled.
-float qmin = 0.0f;
-float qmax = 255.0f;
-ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax));
-double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin);
-double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax);
-float scale = 0.0f;
-int32_t offset = 0;
-ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast<float>(value_min),
-static_cast<float>(value_max),
-output_qnn_data_type,
-scale,
-offset));
-
-std::vector<uint32_t> output_shape_copy = output_shape;
-QnnTensorWrapper convert_output_tensorwrapper(convert_output_name,
-QNN_TENSOR_TYPE_NATIVE,
-output_qnn_data_type,
-QnnQuantParamsWrapper(scale, offset),
-std::move(output_shape_copy));
-ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor.");
-
-ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name,
-QNN_OP_PACKAGE_NAME_QTI_AISW,
-"Convert",
-{convert_input_name},
-{convert_output_name},
-{},
-do_op_validation),
-"Failed to add node.");
-return Status::OK();
-}
-
-Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
-const NodeUnit& node_unit,
-const logging::Logger& logger,
-std::vector<std::string>& input_names,
-bool do_op_validation) const {
-const std::string& op_type = node_unit.OpType();
-ORT_RETURN_IF_ERROR(BaseOpBuilder::ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation));
-
-if (op_type == "MatMul") {
-const auto& inputs = node_unit.Inputs();
-TensorInfo input0_info = {};
-TensorInfo input1_info = {};
-ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info));
-ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input1_info));
-// Need to insert Convert op if both inputs are dynamic inputs and are ufixed_16
-if (!input0_info.is_initializer && !input1_info.is_initializer &&
-input0_info.qnn_data_type == input1_info.qnn_data_type &&
-input0_info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
-ORT_RETURN_IF_NOT(input1_info.quant_param.IsPerTensor(),
-"MatMul's activation inputs only support per-tensor quantization");
-const Qnn_QuantizeParams_t& quant_param = input1_info.quant_param.Get();
-// insert Convert op after input1
-std::string convert_input_name = input_names.back();
-input_names.pop_back();
-const std::string& matmul_output_name = node_unit.Outputs()[0].node_arg.Name();
-std::string convert_output_name = convert_input_name + "_convert_" + matmul_output_name;
-ORT_RETURN_IF_ERROR(InsertConvertOp(qnn_model_wrapper,
-convert_input_name,
-convert_output_name,
-input1_info.qnn_data_type,
-QNN_DATATYPE_UFIXED_POINT_8,
-quant_param.scaleOffsetEncoding.offset,
-quant_param.scaleOffsetEncoding.scale,
-input1_info.shape,
-do_op_validation));
-input_names.push_back(convert_output_name);
-}
-}
-
-return Status::OK();
-}

Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit) const {
const std::string& op_type = node_unit.OpType();
@@ -378,19 +288,6 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
ORT_RETURN_IF(norm_p_order != 2, "QNN EP only supports LpNormalization with 'p' attribute equal to 2.");
}

if (op_type == "MatMul") {
Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT;
scalar_param.dataType = QNN_DATATYPE_BOOL_8;
scalar_param.bool8Value = 0;
QnnParamWrapper transpose_in0_param(node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar_param);
param_tensor_names.push_back(transpose_in0_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(transpose_in0_param));

QnnParamWrapper transpose_in1_param(node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, scalar_param);
param_tensor_names.push_back(transpose_in1_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(transpose_in1_param));
}

if (op_type == "LeakyRelu") {
std::string input_name = "alpha";
ORT_RETURN_IF_ERROR(ProcessAlphaAttributeAsInput(qnn_model_wrapper, node_unit, input_name));
16 changes: 10 additions & 6 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -1097,35 +1097,39 @@ Status QnnBackendManager::TerminateQnnLog() {
}

void QnnBackendManager::ReleaseResources() {
+if (!backend_setup_completed_) {
+return;
+}

auto result = ReleaseContext();
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to ReleaseContext: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to ReleaseContext.";
}

result = ReleaseProfilehandle();
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to ReleaseProfilehandle: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to ReleaseProfilehandle.";
}

result = ReleaseDevice();
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to ReleaseDevice: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to ReleaseDevice.";
}

result = ShutdownBackend();
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to ShutdownBackend: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to ShutdownBackend.";
}

result = TerminateQnnLog();
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to TerminateQnnLog: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to TerminateQnnLog.";
}

if (backend_lib_handle_) {
result = UnloadLib(backend_lib_handle_);
if (Status::OK() != result) {
LOGS_DEFAULT(ERROR) << "Failed to unload backend library: " << result.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Failed to unload backend library.";
}
}

14 changes: 14 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -75,6 +75,20 @@ Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensor
return Status::OK();
}

+Status QnnModelWrapper::MakeTensorWrapper(const TensorInfo& tensor_info,
+const std::string& tensor_name,
+QnnTensorWrapper& tensor_wrapper) const {
+std::vector<uint8_t> unpacked_tensor;
+if (tensor_info.is_initializer) {
+ORT_RETURN_IF_ERROR(UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor));
+}
+
+tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type,
+tensor_info.quant_param.Copy(), std::vector<uint32_t>(tensor_info.shape),
+std::move(unpacked_tensor));
+return Status::OK();
+}

bool QnnModelWrapper::AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper) {
// Keep a copy of tensor name since it will be moved with the wrapper into model_tensors_map_
std::string tensor_name = tensor_wrapper.GetName();
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
@@ -66,6 +66,9 @@ class QnnModelWrapper {

// Make a QnnTensorWrapper from an onnx input or output.
Status MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensorWrapper& tensor_wrapper) const;
+Status MakeTensorWrapper(const TensorInfo& tensor_info,
+const std::string& tensor_name,
+QnnTensorWrapper& tensor_wrapper) const;

// Add to internal tensor wrapper table
bool AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper);
26 changes: 20 additions & 6 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -75,9 +75,18 @@ InitializedTensorSet ModelBuilder::GetInitializerTensors() {
}

void ModelBuilder::PreprocessInitializers() {
const auto& initializers = graph_viewer_.GetAllInitializedTensors();
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));

+// find all initializers consumed. AddInitializersToSkip will potentially decrement the usage count.
+for (const auto* input : node->InputDefs()) {
+if (input->Exists() && Contains(initializers, input->Name())) {
+initializer_usage_[input->Name()]++;
+}
+}

if (const auto* op_builder = GetOpBuilder(*node)) {
op_builder->AddInitializersToSkip(*this, *node);
}
@@ -90,12 +99,11 @@ Status ModelBuilder::RegisterInitializers() {
const auto& name = tensor.name();
const auto& shape = tensor.dims();

-// Ignore the following tensors:
-// 1. Empty tensors: optional tensors can be indicated by an empty name.
-// 2. Tensors in skipped_initializers_: These are tensors that are not used as WebNN Constants.
-// Note: Scalar tensors are excluded because ONNX Runtime will optimize same scalar initializers into one.
-if (name.empty() || (Contains(skipped_initializers_, name) && !shape.empty()))
+// skip initializer if there is no remaining usage
+auto usage_count = initializer_usage_[name];
+if (usage_count == 0) {
continue;
+}

std::vector<int32_t> dims;
// When the shape is empty, it is scalar initializer that dims = {};
@@ -385,7 +393,13 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op
}

void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
-skipped_initializers_.insert(tensor_name);
+// Decrement usage count if this is a known initializer.
+// For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names
+// without first checking if the value is an initializer.
+auto entry = initializer_usage_.find(tensor_name);
+if (entry != initializer_usage_.end()) {
+--entry->second;
+}
}

void ModelBuilder::AddInputToSkip(const std::string& input_name) {
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -81,7 +81,7 @@ class ModelBuilder {

InlinedHashMap<std::string, OnnxTensorInfo> input_output_info_;

InlinedHashSet<std::string> skipped_initializers_;
std::unordered_map<std::string, int> initializer_usage_;
InlinedHashSet<std::string> skipped_inputs_;

uint32_t name_token_{0};
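
The model_builder changes above replace an unconditional skip-set with a usage count, so an initializer is registered as a WebNN constant only while something still consumes it. A rough Python sketch of the counting scheme, with invented node/initializer structures rather than the real builder API:

```python
from collections import Counter

def plan_initializers(nodes, initializers, fold_plan):
    """nodes: (op_type, input_names) pairs; initializers: set of tensor names;
    fold_plan: op_type -> inputs the op builder consumes directly."""
    usage = Counter()
    for op_type, inputs in nodes:
        # Pre-pass: count every reference to an initializer.
        for name in inputs:
            if name in initializers:
                usage[name] += 1
        # Builders may "skip" inputs they fold into the op; decrementing is
        # safe even for names that were never counted (lookup returns 0).
        for name in fold_plan.get(op_type, ()):
            if name in usage:
                usage[name] -= 1
    # Only initializers with remaining usage become graph constants.
    return sorted(name for name in initializers if usage[name] > 0)

nodes = [("Reshape", ["x", "shape0"]), ("Add", ["x", "w"])]
# Reshape consumes its shape input directly, so "shape0" needs no constant.
assert plan_initializers(nodes, {"shape0", "w"}, {"Reshape": ["shape0"]}) == ["w"]
```
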
6 changes: 3 additions & 3 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -138,10 +138,10 @@ def set_provider_options(name, options):
if len(providers) != len(provider_options):
raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")

-if not all([isinstance(provider, str) for provider in providers]):
+if not all(isinstance(provider, str) for provider in providers):
raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")

-if not all([isinstance(options_for_provider, dict) for options_for_provider in provider_options]):
+if not all(isinstance(options_for_provider, dict) for options_for_provider in provider_options):
raise ValueError("'provider_options' values must be dicts.")

for name, options in zip(providers, provider_options, strict=False):
@@ -150,7 +150,7 @@ def set_provider_options(name, options):
else:
for provider in providers:
if isinstance(provider, str):
-set_provider_options(provider, dict())
+set_provider_options(provider, {})
elif (
isinstance(provider, tuple)
and len(provider) == 2
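
Aside from style, dropping the square brackets in `all([...])` lets `all()` consume a lazy generator and stop at the first failing element instead of building the full list first. A small demonstration:

```python
def record(x, seen):
    seen.append(x)
    return isinstance(x, str)

providers = ["CPUExecutionProvider", 42, "CUDAExecutionProvider"]

seen_list, seen_gen = [], []
all([record(p, seen_list) for p in providers])  # list form: evaluates every element
all(record(p, seen_gen) for p in providers)     # generator form: stops at 42
assert len(seen_list) == 3 and len(seen_gen) == 2
```
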
@@ -220,7 +220,7 @@ def set_dispatch(name):
from difflib import SequenceMatcher as Matcher

valid_names = list(_ke_context.dispatchable.keys())
-scored_names = list(reversed(sorted([(Matcher(None, name, a).ratio(), a) for a in valid_names])))
+scored_names = sorted([(Matcher(None, name, a).ratio(), a) for a in valid_names], reverse=True)
top10 = "\n ".join([a for _, a in scored_names[:10]])
msg = f"'{name}' is not registered for dispatch. Top 10 matches are:\n {top10}"
print(msg)
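
For context, the fuzzy-match ranking above uses difflib from the standard library. A self-contained sketch of the same technique, with made-up kernel names:

```python
from difflib import SequenceMatcher

def closest_names(name, valid_names, n=3):
    # (score, name) tuples sort by score first; reverse=True puts the best first.
    scored = sorted(((SequenceMatcher(None, name, a).ratio(), a) for a in valid_names),
                    reverse=True)
    return [a for _, a in scored[:n]]

names = ["GemmFloat", "GemmHalf", "SoftmaxFloat"]
assert closest_names("GemFloat", names)[0] == "GemmFloat"
```
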
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/profile_explorer/profile_explorer.py
@@ -200,7 +200,7 @@ def _print_op_kernel_mapping_info(cpu_df, gpu_df, num_runs, csv=None):
# Count op occurrences in the selected runs
op_counts = defaultdict(int)
for op in cpu_df.T.to_dict().values():
identifiers = tuple([op["name"], op["input_type_shape"]])
identifiers = (op["name"], op["input_type_shape"])
op_counts[identifiers] += 1

# Collect kernel stats: count/duration
@@ -212,15 +212,15 @@ def _print_op_kernel_mapping_info(cpu_df, gpu_df, num_runs, csv=None):
input_type_shape = kernel["input_type_shape"]
kernel_name = kernel["name"]
dimensions = kernel["dimensions"]
-identifiers = tuple([op_name, input_type_shape, kernel_name, dimensions])
+identifiers = (op_name, input_type_shape, kernel_name, dimensions)
stat_dict[identifiers]["count"] += 1
stat_dict[identifiers]["duration"] += kernel["duration"]

# Create the DataFrame for kernel entries with op correlation info
kernel_list = []
for identifiers, stat in stat_dict.items():
op_name, input_type_shape, kernel_name, dimensions = identifiers
-op_count = op_counts.get(tuple([op_name, input_type_shape]))
+op_count = op_counts.get((op_name, input_type_shape))
if op_count is None:
continue
kernel_list.append(
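
`tuple([a, b])` and `(a, b)` build the same hashable key; the literal just skips an intermediate list. A compact sketch of the counting pattern used above, with invented profile records:

```python
from collections import defaultdict

ops = [
    {"name": "MatMul", "input_type_shape": "float32[4,4]"},
    {"name": "MatMul", "input_type_shape": "float32[4,4]"},
    {"name": "Relu", "input_type_shape": "float32[4,4]"},
]

op_counts = defaultdict(int)
for op in ops:
    op_counts[(op["name"], op["input_type_shape"])] += 1  # tuple literal as dict key

assert op_counts[("MatMul", "float32[4,4]")] == 2
assert op_counts.get(("Gelu", "float32[4,4]")) is None
```
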
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/quantization/base_quantizer.py
@@ -118,9 +118,9 @@ def __init__(
'Conv_4:0': [np.float32(1), np.float32(3.5)]
}
"""
-if tensors_range is not None and any(map(lambda t: not isinstance(t, TensorData), tensors_range.values())):
+if tensors_range is not None and any(not isinstance(t, TensorData) for t in tensors_range.values()):
raise TypeError(
f"tensors_range contains unexpected types {set(type(v) for v in tensors_range.values())}, not TensorData."
f"tensors_range contains unexpected types { {type(v) for v in tensors_range.values()} }, not TensorData."
)
self.tensors_range = tensors_range
self.nodes_to_quantize = nodes_to_quantize # specific nodes to quantize
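
One subtlety in the f-string rewrite above: a set comprehension opens with `{`, and directly after the f-string's own `{` that would parse as the escape `{{` (a literal brace). The inner spaces keep it an expression, so they are not mere styling. A quick demonstration:

```python
values = {"a": 1, "b": "x"}

escaped = f"types {{type(v) for v in values.values()}}"      # literal text, not evaluated
evaluated = f"types { {type(v) for v in values.values()} }"  # set comprehension runs

assert escaped == "types {type(v) for v in values.values()}"
assert "int" in evaluated and "str" in evaluated
```
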
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/quantization/calibrate.py
@@ -504,9 +504,9 @@ def compute_data(self) -> TensorsData:

if self.symmetric:
max_absolute_value = np.max([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
-pairs.append(tuple([-max_absolute_value, max_absolute_value]))
+pairs.append((-max_absolute_value, max_absolute_value))
else:
-pairs.append(tuple([min_value_array, max_value_array]))
+pairs.append((min_value_array, max_value_array))

new_calibrate_tensors_range = TensorsData(
CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs, strict=False))
@@ -823,7 +823,7 @@ def collect_absolute_value(self, name_to_arr):
if isinstance(data_arr, list):
for arr in data_arr:
assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
-dtypes = set(a.dtype for a in data_arr)
+dtypes = {a.dtype for a in data_arr}
assert len(dtypes) == 1, (
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
)
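
For the symmetric branch above, the element-wise maximum of |min| and |max| gives the half-range, and the pair collapses to plus/minus that value. A small numpy sketch with invented per-channel ranges:

```python
import numpy as np

min_value_array = np.array([-0.5, -2.0])
max_value_array = np.array([1.5, 1.0])

# np.max over axis 0 takes the element-wise max of the two stacked arrays.
max_absolute_value = np.max([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
pair = (-max_absolute_value, max_absolute_value)  # tuple literal, as in the diff

assert np.allclose(max_absolute_value, [1.5, 2.0])
assert np.allclose(pair[0], [-1.5, -2.0])
```
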
@@ -178,7 +178,7 @@ def apply(
# Use type requests to "fix" tensor quantization overrides by adding
# quantization type conversions where necessary.
for tensor_name, type_req in type_requests.items():
-all_consumers = set([node.name for node in self.consumers.get(tensor_name, [])])
+all_consumers = {node.name for node in self.consumers.get(tensor_name, [])}
has_producer_req = type_req.producer is not None
has_consumer_req = bool(type_req.consumers)
