Merged
Changes from all commits
44 changes: 25 additions & 19 deletions onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
@@ -93,15 +93,16 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   return Status::OK();
 }
 
-static std::vector<uint32_t> AmendOutputShapeForRank3Pool(
+Status AmendOutputShapeForRank3Pool(
     gsl::span<const uint32_t> input_shape,   // {N, H, W, C}
     gsl::span<const uint32_t> kernel_shape,  // {k_h, k_w}
     gsl::span<const uint32_t> strides,       // {s_h, s_w}
-    gsl::span<const uint32_t> pads) {
-  assert(input_shape.size() == 4 &&
-         kernel_shape.size() == 2 &&
-         strides.size() == 2 &&
-         pads.size() == 4);
+    gsl::span<const uint32_t> pads,
+    std::vector<uint32_t>& output_shape) {
+  ORT_RETURN_IF_NOT(input_shape.size() == 4, "Expecting input rank 4 for amending 1D Pool output shape.");
+  ORT_RETURN_IF_NOT(kernel_shape.size() == 2, "Expecting kernel size 2 for amending 1D Pool output shape.");
+  ORT_RETURN_IF_NOT(strides.size() == 2, "Expecting strides size 2 for amending 1D Pool output shape.");
+  ORT_RETURN_IF_NOT(pads.size() == 4, "Expecting pad size 4 for amending 1D Pool output shape.");
 
   const uint32_t N = input_shape[0];
   const uint32_t H = input_shape[1];
@@ -120,7 +121,13 @@ static std::vector<uint32_t> AmendOutputShapeForRank3Pool(
                               ? 0
                               : (padded_W - kernel_shape[1]) / strides[1] + 1;
 
-  return {N, out_H, out_W, C};
+  output_shape.resize(4);
+  output_shape[0] = N;
+  output_shape[1] = out_H;
+  output_shape[2] = out_W;
+  output_shape[3] = C;
+
+  return Status::OK();
 }
 
 Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper,
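As a quick sanity check of the shape arithmetic in the amended helper, the following is a minimal standalone sketch (not part of this PR) that applies the same formula to a hypothetical rank-3 MaxPool input promoted to NHWC. The concrete shape, kernel, strides, and all-zero pads are illustrative only; with non-zero pads the begin/end ordering inside pads would matter, and that ordering is not visible in this hunk.

// Standalone sketch (not from the PR): the pooled-shape arithmetic the helper performs,
// for a hypothetical input {N=1, L=10, C=3} promoted to NHWC {1, 1, 10, 3},
// kernel {1, 3}, strides {1, 2}, and all-zero pads (so pad ordering does not matter here).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<uint32_t> input_shape = {1, 1, 10, 3};  // {N, H, W, C}
  const std::vector<uint32_t> kernel_shape = {1, 3};        // {k_h, k_w}
  const std::vector<uint32_t> strides = {1, 2};             // {s_h, s_w}
  const std::vector<uint32_t> pads = {0, 0, 0, 0};          // illustrative zero padding

  const uint32_t padded_H = input_shape[1] + pads[0] + pads[2];
  const uint32_t padded_W = input_shape[2] + pads[1] + pads[3];
  const uint32_t out_H = (padded_H < kernel_shape[0]) ? 0 : (padded_H - kernel_shape[0]) / strides[0] + 1;
  const uint32_t out_W = (padded_W < kernel_shape[1]) ? 0 : (padded_W - kernel_shape[1]) / strides[1] + 1;

  // Prints "1 1 4 3": the rank-4 pooled shape {N, out_H, out_W, C}.
  std::cout << input_shape[0] << " " << out_H << " " << out_W << " " << input_shape[3] << std::endl;
  return 0;
}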
@@ -177,10 +184,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper,
   if (auto_pad.compare("NOTSET") != 0) {
     if (output_shape.size() == 3) {
       // Calculate rank-4 output shape for rank-3 input.
-      output_shape = AmendOutputShapeForRank3Pool(input_shape,
-                                                  filter_size,
-                                                  stride,
-                                                  pad_amount);
+      ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(input_shape, filter_size, stride, pad_amount, output_shape));
     }
 
     for (size_t axis = 0; axis < rank - 2; ++axis) {
@@ -365,14 +369,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                   std::move(output_shape)));
   }
 
-  // Calculate rank-4 output shape for rank-3 input.
-  std::vector<uint32_t> onnx_in_shape;
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape");
-  if (onnx_in_shape.size() == 3) {
-    onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]};
-  }
-  auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount);
-
   // Construct param wrappers.
   ORT_RETURN_IF_NOT(SetPoolParam(node_unit,
                                  param_filter_size,
@@ -443,6 +439,16 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
 
     return Status::OK();
   }
+
+  // Calculate rank-4 output shape for rank-3 input.
+  std::vector<uint32_t> onnx_in_shape;
+  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape");
+  if (onnx_in_shape.size() == 3) {
+    onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]};
+  }
+  std::vector<uint32_t> pooled_shape;
+  ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount, pooled_shape));
+
   const auto& outputs = node_unit.Outputs();
   const std::string real_out = outputs[0].node_arg.Name();
   const std::string pool_out = utils::GetUniqueName(real_out, "_reshape_after");
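For context on the behavior change, a minimal caller sketch (hypothetical, not from the diff) is shown below. It is a fragment assumed to live inside a builder method that returns Status, mirroring the identifiers used in the hunks above: with the new Status-returning signature, a rank mismatch surfaces as a recoverable error rather than a debug-only assert.

// Hypothetical caller fragment (identifiers mirror the diff; the enclosing function
// is assumed to return onnxruntime Status, as the builder methods above do).
std::vector<uint32_t> bad_input_shape = {1, 10, 3};  // rank 3: rejected by the helper's rank-4 check
std::vector<uint32_t> filter_size = {1, 3};
std::vector<uint32_t> stride = {1, 2};
std::vector<uint32_t> pad_amount = {0, 0, 0, 0};
std::vector<uint32_t> pooled_shape;

Status status = AmendOutputShapeForRank3Pool(bad_input_shape, filter_size, stride, pad_amount, pooled_shape);
if (!status.IsOK()) {
  // Previously this condition only tripped an assert (a no-op in release builds);
  // now the message ("Expecting input rank 4 ...") reaches the caller and can be
  // propagated with ORT_RETURN_IF_ERROR.
}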