diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); }