133 changes: 111 additions & 22 deletions onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
@@ -62,6 +62,8 @@ QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
return QLinearOpType::QLinearMatMul;
else if (op_type == "QLinearAdd")
return QLinearOpType::QLinearAdd;
else if (op_type == "QLinearSigmoid")
return QLinearOpType::QLinearSigmoid;

return QLinearOpType::Unknown;
}
@@ -232,8 +234,10 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co

std::unique_ptr<uint8_t[]> unpacked_tensor;
size_t tensor_byte_size;
-auto status = onnxruntime::utils::UnpackInitializerData(zero_tensor, node.ModelPath(),
-unpacked_tensor, tensor_byte_size);
+auto status = onnxruntime::utils::UnpackInitializerData(
+zero_tensor,
+node.ModelPath(),
+unpacked_tensor, tensor_byte_size);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "QLinearConv error when unpacking zero tensor: " << status.ErrorMessage();
return false;
@@ -264,6 +268,24 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co
return true;
}

float GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, size_t idx) {
const auto& scale_tensor = *initializers.at(node.InputDefs()[idx]->Name());
return GetTensorFloatData(scale_tensor)[0];
}

common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers,
const Node& node, size_t idx, int32_t& zero_point) {
std::unique_ptr<uint8_t[]> unpacked_tensor;
size_t tensor_byte_size;
const auto& zero_point_tensor = *initializers.at(node.InputDefs()[idx]->Name());
ORT_RETURN_IF_ERROR(
onnxruntime::utils::UnpackInitializerData(zero_point_tensor, node.ModelPath(),
unpacked_tensor, tensor_byte_size));
// ONNX quantization uses uint8 (int8 is not yet supported); cast to the int32_t used by NNAPI
zero_point = static_cast<int32_t>(unpacked_tensor.get()[0]);
return Status::OK();
}
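
For context, these helpers expose the standard ONNX linear quantization mapping, real = scale * (q - zero_point). Below is a minimal sketch (not part of this diff) of how an op support checker might combine them for QLinearSigmoid. The function name and the input indices are illustrative assumptions; the output constraint (scale 1/256, zero point 0) is the requirement NNAPI documents for quantized LOGISTIC.

// Hypothetical sketch, not part of this diff: validate QLinearSigmoid's
// output quantization parameters with the helpers above. A quantized value
// q maps to a real value via: real = scale * (q - zero_point). NNAPI's
// LOGISTIC requires output scale == 1/256 and output zero point == 0 for
// quint8 tensors, so a checker could verify that up front.
bool HasValidQLinearSigmoidOutputQuantization(const InitializedTensorSet& initializers,
                                              const Node& node) {
  // Assumed QLinearSigmoid inputs: 0:X, 1:X_scale, 2:X_zero_point, 3:Y_scale, 4:Y_zero_point
  const float y_scale = GetQuantizationScale(initializers, node, 3);
  int32_t y_zero_point = 0;
  if (!GetQuantizationZeroPoint(initializers, node, 4, y_zero_point).IsOK())
    return false;
  return y_scale == 1.0f / 256 && y_zero_point == 0;
}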

#define GET_TENSOR_DATA(FUNC_NAME, ELEMENT_TYPE, DATA) \
const ELEMENT_TYPE* GetTensor##FUNC_NAME(const ONNX_NAMESPACE::TensorProto& tensor) { \
return tensor.DATA().empty() \
@@ -348,13 +370,13 @@ void GetFlattenOutputShape(const Node& node, const Shape& input_shape, int32_t&
dim_2 = std::accumulate(input_shape.cbegin() + axis, input_shape.cend(), 1, std::multiplies<int32_t>());
}

-bool IsValidSupportedNodesVec(const std::vector<size_t>& supported_node_vec, const GraphViewer& graph_viewer) {
-if (supported_node_vec.empty())
+bool IsValidSupportedNodesGroup(const std::vector<size_t>& supported_node_group, const GraphViewer& graph_viewer) {
+if (supported_node_group.empty())
return false;

-if (supported_node_vec.size() == 1) {
+if (supported_node_group.size() == 1) {
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
-const auto* node(graph_viewer.GetNode(node_indices[supported_node_vec[0]]));
+const auto* node(graph_viewer.GetNode(node_indices[supported_node_group[0]]));
const auto& op = node->OpType();
// It is not worth running a single Reshape/Flatten/Identity operator in NNAPI,
// since it only copies the data
@@ -368,49 +390,116 @@ bool IsValidSupportedNodesVec(const std::vector<size_t>& supported_node_vec, con
return true;
}

bool IsInternalQuantizedNode(const Node& node) {
// These operators can use uint8 input without a specific QLinear version
// However, the node has to be internal to the graph/partition (it cannot consume graph inputs)
static const std::unordered_set<std::string> internal_quantized_op_types =
{
"Transpose",
"Resize",
"Concat",
"MaxPool",
};

if (!Contains(internal_quantized_op_types, node.OpType()))
return false;

int32_t input_type;
ORT_ENFORCE(GetType(*node.InputDefs()[0], input_type));

return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8;
}

Contributor:
Does this mean quantization isn't relevant here, and maybe this check should be outside of IsInternalQuantizationSupported?

// We support some operators running on uint8 internally
// These nodes cannot use a graph input as input, since an ONNX graph input does not carry scale/zero point info
bool IsInternalQuantizationSupported(const Node& node, const std::unordered_set<std::string>& node_outputs_in_group) {
const auto& op_type = node.OpType();

// The node's input(s) have to be an output of node(s) within the group
// If not, then this node is using graph/partition input(s) as input(s)
const auto& input_defs = node.InputDefs();

// We only need to check input0 for all operators except "Concat"
bool check_all_inputs = op_type == "Concat";

for (size_t i = 0; i < (check_all_inputs ? input_defs.size() : 1); i++) {
if (!Contains(node_outputs_in_group, input_defs[i]->Name())) {
LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type: [" << op_type
<< "] has input [" << input_defs[i]->Name()
<< "], but using a (quantized) graph input as node input is not supported";
return false;
}
}

return true;
}
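
To make the group-membership check concrete, here is a hypothetical walk-through of the two cases the function distinguishes; the graphs are illustrative and not from this PR.

// Hypothetical example graphs, for illustration only.
//
//   Case 1:  graph_input(uint8) -> QLinearConv -> Transpose -> QLinearAdd
//   Transpose's input is produced by QLinearConv inside the group, so its name
//   is present in node_outputs_in_group and the scale/zero point can be taken
//   from QLinearConv's output: internal quantization is supported.
//
//   Case 2:  graph_input(uint8) -> Transpose -> QLinearConv
//   Transpose reads the graph input directly. An ONNX graph input carries no
//   scale/zero point, so the check above fails and the node is rejected.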

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
const auto& op_support_checkers = GetOpSupportCheckers();
-if (Contains(op_support_checkers, node.OpType())) {
-const auto* op_support_checker = op_support_checkers.at(node.OpType());
-return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, params);
-} else {
+if (!Contains(op_support_checkers, node.OpType()))
return false;
-}
+
+const auto* op_support_checker = op_support_checkers.at(node.OpType());
+return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, params);
}

bool IsNodeSupportedInternal(const Node& node, const GraphViewer& graph_viewer,
const OpSupportCheckParams& params,
const std::unordered_set<std::string>& node_outputs_in_group) {
if (!IsNodeSupported(node, graph_viewer, params))
return false;

// We also want to check if the node is supported as an internal quantized node
if (IsInternalQuantizedNode(node))
return IsInternalQuantizationSupported(node, node_outputs_in_group);
else  // This is not an internal quantized node, so it is supported
return true;
}

std::vector<std::vector<size_t>> GetSupportedNodes(const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
-std::vector<std::vector<size_t>> supported_node_vecs;
+std::vector<std::vector<size_t>> supported_node_groups;
if (params.android_sdk_ver < ORT_NNAPI_MIN_API_LEVEL) {
LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because Android API level [" << params.android_sdk_ver
<< "] is lower than minimal supported API level [" << ORT_NNAPI_MIN_API_LEVEL
<< "] of this build for NNAPI";
-return supported_node_vecs;
+return supported_node_groups;
}

-std::vector<size_t> supported_node_vec;
+// This holds the topological indices of the supported nodes in the current group
+std::vector<size_t> supported_node_group;
+// This holds the output names of the nodes in the above group
+std::unordered_set<std::string> node_outputs_in_group;
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer.GetNode(node_indices[i]));
-bool supported = IsNodeSupported(*node, graph_viewer, params);
+bool supported = IsNodeSupportedInternal(*node, graph_viewer, params, node_outputs_in_group);
LOGS_DEFAULT(VERBOSE) << "Operator type: [" << node->OpType()
<< "] index: [" << i
<< "] name: [" << node->Name()
<< "] supported: [" << supported
<< "]";
if (supported) {
-supported_node_vec.push_back(i);
+supported_node_group.push_back(i);

// Record all the output names of nodes in the current group for easy lookup
// See IsInternalQuantizationSupported()
for (const auto* output : node->OutputDefs()) {
node_outputs_in_group.insert(output->Name());
}
} else {
-if (IsValidSupportedNodesVec(supported_node_vec, graph_viewer)) {
-supported_node_vecs.push_back(supported_node_vec);
-supported_node_vec.clear();
+if (IsValidSupportedNodesGroup(supported_node_group, graph_viewer)) {
+supported_node_groups.push_back(supported_node_group);
}
+
+supported_node_group.clear();
+node_outputs_in_group.clear();
}
}

-if (IsValidSupportedNodesVec(supported_node_vec, graph_viewer))
-supported_node_vecs.push_back(supported_node_vec);
+if (IsValidSupportedNodesGroup(supported_node_group, graph_viewer))
+supported_node_groups.push_back(supported_node_group);

-return supported_node_vecs;
+return supported_node_groups;
}
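
As a worked illustration of the loop above (hypothetical node names, not from the diff), this is how a topologically ordered graph gets partitioned:

// Suppose the nodes in topological order are [Conv, Relu, Loop, Add, Reshape]
// and Loop is not supported. The loop above would produce two groups:
//   {0, 1} -> Conv, Relu
//   {3, 4} -> Add, Reshape
// A group consisting of a single Reshape/Flatten/Identity would be dropped
// by IsValidSupportedNodesGroup(), since such a node only copies data and is
// not worth running through NNAPI.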

std::string Shape2String(const std::vector<uint32_t>& shape) {
onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h
@@ -76,6 +76,7 @@ enum class QLinearOpType : uint8_t {
QLinearConv,
QLinearMatMul,
QLinearAdd,
QLinearSigmoid,
// Not yet supported
// QLinearAveragePool,
// QLinearMul,
@@ -107,6 +108,11 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const
bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const Node& node,
const std::vector<size_t>& indices);

float GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, size_t idx);

common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers,
const Node& node, size_t idx, int32_t& zero_point) ORT_MUST_USE_RESULT;

// Get initializer tensor float/int32/int64 data without unpacking
// TODO: move to ORT framework
const float* GetTensorFloatData(const ONNX_NAMESPACE::TensorProto& tensor);
onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
@@ -143,7 +143,13 @@ std::unordered_map<std::string, vector<const Node*>> GetAllQuantizedOpInputs(con
for (const auto& node_idx : node_indices) {
const auto* node(graph_viewer.GetNode(node_idx));
auto qlinear_op_type = GetQLinearOpType(*node);
-if (qlinear_op_type == QLinearOpType::DequantizeLinear || IsQLinearBinaryOp(qlinear_op_type)) {
+
+// Not a qlinear op
+if (qlinear_op_type == QLinearOpType::Unknown)
+continue;
+
+// All qlinear ops EXCEPT QuantizeLinear have quantized input
+if (qlinear_op_type != QLinearOpType::QuantizeLinear) {
const auto& input_name = node->InputDefs()[0]->Name();
if (Contains(all_quantized_op_inputs, input_name))
all_quantized_op_inputs.at(input_name).push_back(node);
@@ -293,7 +299,7 @@ Status ModelBuilder::RegisterModelInputs() {
if (!Contains(all_quantized_op_inputs, input_name)) {
// We currently do not support uint8 input if it is not a quantized input
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The input of graph doesn't have valid type, name: ", input_name,
"The input of graph has unsupported quantized type, name: ", input_name,
" type: ", type_proto->tensor_type().elem_type());
}

@@ -305,7 +311,7 @@ Status ModelBuilder::RegisterModelInputs() {
default: {
// TODO: support other type
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The input of graph doesn't have valid type, name: ", input_name,
"The input of graph has unsupported type, name: ", input_name,
" type: ", type_proto->tensor_type().elem_type());
}
}
@@ -369,6 +375,7 @@ void ModelBuilder::RegisterModelShaper() {
Status ModelBuilder::AddNewOperand(const std::string& name,
const OperandType& operand_type,
bool is_nhwc, uint32_t& index) {
LOGS_DEFAULT(VERBOSE) << "operand name: " << name;
ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operand_type, index));
RegisterOperand(name, index, operand_type, is_nhwc);
return Status::OK();
@@ -535,6 +542,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {

int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;

// We do not support activation fusion for quantized operators for now
auto qlinear_op_type = GetQLinearOpType(node);
if (qlinear_op_type != QLinearOpType::Unknown)
return fuse_code;

for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
const auto& dst_node = it->GetNode();
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];