diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index e248034f225ec..94ace606ac75a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -348,7 +348,7 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, // Insert cast node. ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", + QNN_OP_CAST, {cast_node_info.input_name}, {cast_node_info.output_name}, {}), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index f2992196f7811..e39f38fb020dc 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -86,10 +86,10 @@ static bool FixStaticIndices(const std::vector& onnx_bytes, } // Gets the size of input0 on the axis dimension. -static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - int64_t default_axis_value, - /*out*/ int64_t& axis_dim_value) { +static Status GetInput0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + int64_t default_axis_value, + /*out*/ int64_t& axis_dim_value) { const auto& input0 = node_unit.Inputs()[0]; std::vector input0_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape), @@ -111,8 +111,9 @@ static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper, // Processes the indices input to Gather operators. // -// In general, QNN only supports int32/uint32 indices. QNN EP has to add Cast for dynamic int64 indices or -// convert static int64 indices to int32. +// QNN only supports int32 / uint32 as indices tensor data types. +// When indices tensor is an initializer, statically cast values int64 -> int32. +// When dynamic input, add explicit QNN Cast node for int64 -> int32 conversion. // // The HTP backend only supports dynamic int64 indices if they are a graph input. static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, @@ -121,7 +122,7 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger, std::vector& input_names, bool do_op_validation) { - const auto& input_name = indices_input.node_arg.Name(); + const auto& indices_tensor_name = indices_input.node_arg.Name(); TensorInfo indices_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info)); @@ -146,28 +147,44 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, } } - Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name); std::vector cast_output_shape(indices_info.shape); - if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { - LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name; + if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_tensor_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_tensor_name; } else { - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(), - std::move(indices_info.shape), std::move(qnn_indices_bytes)); + QnnTensorWrapper input_tensorwrapper(indices_tensor_name, + qnn_model_wrapper.GetTensorType(indices_tensor_name), + indices_info.qnn_data_type, QnnQuantParamsWrapper(), + std::move(indices_info.shape), + std::move(qnn_indices_bytes)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); } // Insert QNN Cast op to convert dynamic indices from int64 to int32. - std::string indices_input_name(input_name); + std::string indices_casted_name{indices_tensor_name}; if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) { assert(!indices_info.is_initializer); - - ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddInt64CastNode(input_name, indices_input_name, - std::move(cast_output_shape), - do_op_validation)); + indices_casted_name += "_int32"; + if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_casted_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_casted_name; + } else { + QnnTensorWrapper indices_cast_tensor(indices_casted_name, + QNN_TENSOR_TYPE_NATIVE, + QNN_DATATYPE_INT_32, + QnnQuantParamsWrapper(), + std::move(cast_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)), + "Failed to add gather indices cast tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CAST, + {indices_tensor_name}, + {indices_casted_name}, + {}, + do_op_validation), + "Failed to add gather indices cast node."); + } } - - input_names.push_back(indices_input_name); - + input_names.push_back(indices_casted_name); return Status::OK(); } @@ -181,8 +198,7 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); int64_t input0_axis_dim = 0; - ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim)); - + ORT_RETURN_IF_ERROR(GetInput0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis_value=*/0, input0_axis_dim)); return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation); } @@ -312,7 +328,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w QnnTensorWrapper gather_output_wrapper(gather_output_name, tensor_type, qnn_data_type, quantize_param.Copy(), std::move(qnn_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_wrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), @@ -328,7 +343,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type, std::move(quantize_param), std::move(target_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_output)), "Failed to add tensor."); - const static std::string qnn_node_type = "Reshape"; std::string node_output_name = output_name; if (needs_int64_cast) { @@ -337,7 +351,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w } ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - qnn_node_type, + QNN_OP_RESHAPE, {gather_output_name}, {node_output_name}, {}, @@ -350,7 +364,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w // Insert cast node. ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", + QNN_OP_CAST, {cast_node_info.input_name}, {cast_node_info.output_name}, {}), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 1c22bf55c914d..b87cdd4e25f08 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -257,7 +257,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_input_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", + QNN_OP_CAST, {cast_input_name}, {output_name}, {}), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc index 5b26c20cef825..3498aa92032f3 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc @@ -154,7 +154,7 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode // Insert cast node. ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", + QNN_OP_CAST, {cast_node_info.input_name}, {cast_node_info.output_name}, {}), diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index fee6ecc8918a9..a44163563b430 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -236,24 +236,6 @@ class QnnModelWrapper { tensor_data_type, quantize_param, do_op_validation, is_for_input, is_for_output); } - Status AddInt64CastNode(const std::string& input_name, std::string& cast_output_name, - std::vector&& cast_output_shape, bool do_op_validation) { - cast_output_name = input_name + "_ort_qnn_ep_cast"; - QnnTensorWrapper cast_output(cast_output_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, - QnnQuantParamsWrapper(), std::move(cast_output_shape)); - ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(cast_output)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(CreateQnnNode(cast_output_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", - {input_name}, - {cast_output_name}, - {}, - do_op_validation), - "Failed to add node."); - - return Status::OK(); - } - Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, std::vector& unpacked_tensor) const; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index bef0bdd5295be..82b2b85ad6779 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -916,7 +916,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_TSTR("sce_none_weights_log_prob_expanded"), ORT_TSTR("sce_none_weights_expanded"), ORT_TSTR("convtranspose_3d"), - ORT_TSTR("gather_elements_negative_indices")}; + ORT_TSTR("gather_elements_negative_indices"), + ORT_TSTR("rotary_embedding_3d_input_expanded"), + ORT_TSTR("rotary_embedding_expanded"), + ORT_TSTR("rotary_embedding_interleaved_expanded")}; std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests));