Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info));

bool needs_reshape = false;
const std::string reshape4d = input_names[0] + "_pre_reshape";
const std::string reshape_prior_out = input_names[0] + "_prior_reshape";
if (input_shape.size() == 3) {
needs_reshape = true;
// build new_shape = {N, 1, C, L}
Expand All @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
input_shape[1],
input_shape[2]};

const std::string reshape_node_name = "pre_reshape";
QnnTensorWrapper rw(
reshape4d,
QnnTensorWrapper reshape_prior_tensor(
reshape_prior_out,
QNN_TENSOR_TYPE_NATIVE,
reshape_input_info.qnn_data_type,
reshape_input_info.quant_param.Copy(),
std::move(new_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)),
"Failed to add reshape-4d tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)),
"Failed to add reshape prior tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
reshape_node_name,
utils::GetNodeName(node_unit) + "_reshape_prior",
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Reshape",
QNN_OP_RESHAPE,
{input_names[0]},
{reshape4d},
{reshape_prior_out},
{},
do_op_validation),
"Failed to create reshape-4d node.");
input_names[0] = reshape4d;
"Failed to create reshape prior node for pool op.");
input_names[0] = reshape_prior_out;
input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]};
}

Expand Down Expand Up @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
}
const auto& outputs = node_unit.Outputs();
const std::string real_out = outputs[0].node_arg.Name();
const std::string pool_name = "poolmax2d";
const std::string pool_out = real_out + "_post_reshape";
const std::string post_reshape_node_name = "post_reshape";
const std::string pool_out = real_out + "_reshape_after";
const std::string qnn_op = GetQnnOpType(op_type);
TensorInfo output_info{};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
Expand All @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
"Failed to add tensor for pool_out");

ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
pool_name,
utils::GetNodeName(node_unit) + "_pool2d",
QNN_OP_PACKAGE_NAME_QTI_AISW,
qnn_op,
{reshape4d},
{reshape_prior_out},
{pool_out},
std::move(param_tensor_names),
do_op_validation),
"Failed to create QNN Pool node for rank-3 input.");
"Failed to create pool node for rank-3 input.");

std::vector<uint32_t> final_shape3d = output_info.shape;
QnnTensorWrapper reshape_back_tensor(
QnnTensorWrapper reshape_after_tensor(
real_out,
tensor_type,
output_info.qnn_data_type,
output_info.quant_param.Copy(),
std::move(final_shape3d));

ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)),
"Failed to add reshape after tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
post_reshape_node_name,
utils::GetNodeName(node_unit) + "_reshape_after",
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Reshape",
QNN_OP_RESHAPE,
{pool_out},
{real_out},
{},
do_op_validation),
"Failed to create reshape-back node.");
"Failed to create reshape after node for pool op.");

return Status::OK();
}
Expand Down
Loading