Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
// Insert cast node.
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
QNN_OP_CAST,
{cast_node_info.input_name},
{cast_node_info.output_name},
{}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ static bool FixStaticIndices(const std::vector<uint8_t>& onnx_bytes,
}

// Gets the size of input0 on the axis dimension.
static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
int64_t default_axis_value,
/*out*/ int64_t& axis_dim_value) {
static Status GetInput0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
int64_t default_axis_value,
/*out*/ int64_t& axis_dim_value) {
const auto& input0 = node_unit.Inputs()[0];
std::vector<uint32_t> input0_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape),
Expand All @@ -111,8 +111,9 @@ static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper,

// Processes the indices input to Gather operators.
//
// In general, QNN only supports int32/uint32 indices. QNN EP has to add Cast for dynamic int64 indices or
// convert static int64 indices to int32.
// QNN only supports int32 / uint32 as the data type of the indices tensor.
// When the indices tensor is an initializer, its int64 values are statically converted to int32.
// When the indices tensor is a dynamic input, an explicit QNN Cast node is added to convert int64 -> int32.
//
// The HTP backend only supports dynamic int64 indices if they are a graph input.
static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
Expand All @@ -121,7 +122,7 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) {
const auto& input_name = indices_input.node_arg.Name();
const auto& indices_tensor_name = indices_input.node_arg.Name();

TensorInfo indices_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));
Expand All @@ -146,28 +147,44 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
}
}

Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name);
std::vector<uint32_t> cast_output_shape(indices_info.shape);
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_tensor_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_tensor_name;
} else {
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape), std::move(qnn_indices_bytes));
QnnTensorWrapper input_tensorwrapper(indices_tensor_name,
qnn_model_wrapper.GetTensorType(indices_tensor_name),
indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape),
std::move(qnn_indices_bytes));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}

// Insert QNN Cast op to convert dynamic indices from int64 to int32.
std::string indices_input_name(input_name);
std::string indices_casted_name{indices_tensor_name};
if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
assert(!indices_info.is_initializer);

ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddInt64CastNode(input_name, indices_input_name,
std::move(cast_output_shape),
do_op_validation));
indices_casted_name += "_int32";
if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_casted_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_casted_name;
} else {
QnnTensorWrapper indices_cast_tensor(indices_casted_name,
QNN_TENSOR_TYPE_NATIVE,
QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(),
std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)),
"Failed to add gather indices cast tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_CAST,
{indices_tensor_name},
{indices_casted_name},
{},
do_op_validation),
"Failed to add gather indices cast node.");
}
}

input_names.push_back(indices_input_name);

input_names.push_back(indices_casted_name);
return Status::OK();
}

Expand All @@ -181,8 +198,7 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

int64_t input0_axis_dim = 0;
ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim));

ORT_RETURN_IF_ERROR(GetInput0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis_value=*/0, input0_axis_dim));
return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation);
}

Expand Down Expand Up @@ -312,7 +328,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
QnnTensorWrapper gather_output_wrapper(gather_output_name, tensor_type, qnn_data_type, quantize_param.Copy(),
std::move(qnn_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_wrapper)), "Failed to add tensor.");

ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
Expand All @@ -328,7 +343,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type, std::move(quantize_param),
std::move(target_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_output)), "Failed to add tensor.");
const static std::string qnn_node_type = "Reshape";
std::string node_output_name = output_name;

if (needs_int64_cast) {
Expand All @@ -337,7 +351,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
}
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
qnn_node_type,
QNN_OP_RESHAPE,
{gather_output_name},
{node_output_name},
{},
Expand All @@ -350,7 +364,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
// Insert cast node.
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
QNN_OP_CAST,
{cast_node_info.input_name},
{cast_node_info.output_name},
{}),
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
"Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
QNN_OP_CAST,
{cast_input_name},
{output_name},
{}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode
// Insert cast node.
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
QNN_OP_CAST,
{cast_node_info.input_name},
{cast_node_info.output_name},
{}),
Expand Down
18 changes: 0 additions & 18 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,24 +236,6 @@ class QnnModelWrapper {
tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output);
}

// Adds a QNN Cast node whose output tensor has data type QNN_DATATYPE_INT_32.
// NOTE(review): per the function name this is meant for int64 inputs; the input's data type is
// not checked here - confirm at the call sites.
//
// input_name:        name of the tensor fed into the Cast node.
// cast_output_name:  (out) overwritten with input_name + "_ort_qnn_ep_cast". The same string is
//                    used as both the Cast node name and its output tensor name. It is assigned
//                    before any fallible step, so it is set even when an error is returned.
// cast_output_shape: shape of the Cast output tensor; moved from.
// do_op_validation:  forwarded unchanged to CreateQnnNode.
//
// Returns Status::OK() on success; otherwise the error produced by ORT_RETURN_IF_NOT when the
// output tensor wrapper or the Cast node could not be added.
Status AddInt64CastNode(const std::string& input_name, std::string& cast_output_name,
                        std::vector<uint32_t>&& cast_output_shape, bool do_op_validation) {
  cast_output_name = input_name + "_ort_qnn_ep_cast";

  // Output tensor: QNN_TENSOR_TYPE_NATIVE, int32, default-constructed quantization params.
  QnnTensorWrapper cast_output(cast_output_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
                               QnnQuantParamsWrapper(), std::move(cast_output_shape));
  ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");

  // Use the QNN_OP_CAST constant instead of a hard-coded "Cast" literal so the op type stays
  // consistent with the other Cast node call sites.
  ORT_RETURN_IF_NOT(CreateQnnNode(cast_output_name,
                                  QNN_OP_PACKAGE_NAME_QTI_AISW,
                                  QNN_OP_CAST,
                                  {input_name},
                                  {cast_output_name},
                                  {},
                                  do_op_validation),
                    "Failed to add node.");

  return Status::OK();
}

Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
std::vector<uint8_t>& unpacked_tensor) const;

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
ORT_TSTR("sce_none_weights_log_prob_expanded"),
ORT_TSTR("sce_none_weights_expanded"),
ORT_TSTR("convtranspose_3d"),
ORT_TSTR("gather_elements_negative_indices")};
ORT_TSTR("gather_elements_negative_indices"),
ORT_TSTR("rotary_embedding_3d_input_expanded"),
ORT_TSTR("rotary_embedding_expanded"),
ORT_TSTR("rotary_embedding_interleaved_expanded")};

std::unordered_set<std::basic_string<ORTCHAR_T>> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests));

Expand Down
Loading