From 7d4543defc75db60d042546d40324dfcbf962ef2 Mon Sep 17 00:00:00 2001 From: Tom Allsop Date: Mon, 12 May 2025 13:56:35 +0100 Subject: [PATCH] Arm backend: Adjust AvgPool2d padding when window is not divisible by stride * AvgPool2dVisitor will adjust the padding so the pooling window is divisible by the stride * Improve tests in test_max_pool.py Signed-off-by: Tom Allsop Change-Id: I068f025a961a4671ad6727ff156cbad87bec2c08 --- backends/arm/operators/op_avg_pool2d.py | 29 +++++++++++++++ backends/arm/operators/op_max_pool2d.py | 27 +++----------- .../operators/operator_validation_utils.py | 37 +++++++++++++++++++ backends/arm/test/ops/test_avg_pool2d.py | 12 ++++++ backends/arm/test/ops/test_max_pool.py | 2 + 5 files changed, 85 insertions(+), 22 deletions(-) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 9eb533b7968..dc455206f75 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -17,6 +17,7 @@ register_node_visitor, ) from executorch.backends.arm.operators.operator_validation_utils import ( + adjust_pooling_pad_if_needed, validate_num_inputs, validate_same_dtype, ) @@ -63,6 +64,20 @@ def _build_generic_avgpool2d( except IndexError: pad_size_list = [0, 0, 0, 0] + # Adjust the padding as necessary + pad_size_list[1] = adjust_pooling_pad_if_needed( + input_tensor.shape[2], + kernel_size_list[0], + stride_size_list[0], + pad_size_list[1], + ) + pad_size_list[3] = adjust_pooling_pad_if_needed( + input_tensor.shape[3], + kernel_size_list[1], + stride_size_list[1], + pad_size_list[3], + ) + attr = ts.TosaSerializerAttribute() attr.PoolAttribute( kernel=kernel_size_list, @@ -192,6 +207,20 @@ def _build_generic_avgpool2d( except IndexError: pad_size_list = [0, 0, 0, 0] + # Adjust the padding as necessary + pad_size_list[1] = adjust_pooling_pad_if_needed( + input_tensor.shape[2], + kernel_size_list[0], + stride_size_list[0], + pad_size_list[1], + ) + pad_size_list[3] = 
adjust_pooling_pad_if_needed( + input_tensor.shape[3], + kernel_size_list[1], + stride_size_list[1], + pad_size_list[3], + ) + attr = ts.TosaSerializerAttribute() attr.AvgPool2dAttribute( kernel=kernel_size_list, diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 8a37627a416..170c16261a7 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -17,6 +17,7 @@ register_node_visitor, ) from executorch.backends.arm.operators.operator_validation_utils import ( + adjust_pooling_pad_if_needed, validate_num_inputs, validate_same_dtype, ) @@ -24,24 +25,6 @@ from executorch.backends.arm.tosa_specification import TosaSpecification -# Similarly to Conv2d, the TOSA spec requires that following is exactly divisible: -# `(input + 2 * pad - kernel_size) / stride` -# PyTorch however, does not require this, so as needed, we must adjust the padding. -def adjust_pad_if_needed( - input_size: int, kernel_size: int, stride: int, pad: int -) -> int: - if pad == 0: - return pad - - mod_remainder = (input_size + 2 * pad - kernel_size) % stride - - # No need to adjust - if mod_remainder == 0: - return pad - - return pad - mod_remainder - - @register_node_visitor class MaxPool2dVisitor_0_80(NodeVisitor): target = "aten.max_pool2d.default" @@ -82,13 +65,13 @@ def define_node( pad_size_list = [0, 0, 0, 0] # Adjust the padding as necessary - pad_size_list[1] = adjust_pad_if_needed( + pad_size_list[1] = adjust_pooling_pad_if_needed( input_tensor.shape[2], kernel_size[0], stride[0], pad_size_list[1], ) - pad_size_list[3] = adjust_pad_if_needed( + pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size[1], stride[1], @@ -167,13 +150,13 @@ def define_node( pad_size_list = [0, 0, 0, 0] # Adjust the padding as necessary - pad_size_list[1] = adjust_pad_if_needed( + pad_size_list[1] = adjust_pooling_pad_if_needed( input_tensor.shape[2], kernel_size[0], stride[0], pad_size_list[1], ) 
- pad_size_list[3] = adjust_pad_if_needed( + pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size[1], stride[1], diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index d15bb65ba77..f0c9af2a137 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -99,3 +99,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any]): f"{op_name}: Expected all tensors to have dtype {reference_dtype}, but " f"found inconsistent dtype {tensor.dtype}." ) + + +def adjust_pooling_pad_if_needed( + input_size: int, kernel_size: int, stride: int, pad: int +) -> int: + """ + Calculates the reduced padding for a pooling window so that the window becomes + divisible by the kernel's stride. All inputs should correspond to the same dimension. + + Parameters: + ----------- + input_size : int + The size of the input to the operator. + + kernel_size : int + The size of the kernel. + + stride : int + The size of the stride. + + pad : int + The amount of padding. + + Output: + ------- + An int, representing the padding to use (unchanged when it is 0 or already divisible). 
+ """ + if pad == 0: + return pad + + mod_remainder = (input_size + 2 * pad - kernel_size) % stride + + # No need to adjust + if mod_remainder == 0: + return pad + + return pad - mod_remainder diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 65c1830b9b2..9927a6d2895 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -59,6 +59,18 @@ def forward(self, x): AvgPool2d((4, 6), (1, 2), (2, 3)), (torch.rand(1, 16, 50, 32),), ), + "non_divisible_window": lambda: ( + AvgPool2d(3, 2, 1), + (torch.rand(1, 16, 112, 112),), + ), + "non_divisible_window_height": lambda: ( + AvgPool2d(3, (2, 1), 1), + (torch.rand(1, 16, 56, 56),), + ), + "non_divisible_window_width": lambda: ( + AvgPool2d(3, (1, 2), 1), + (torch.rand(1, 16, 56, 56),), + ), } diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index a1fd3ea30ec..7e9c90e983f 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -26,6 +26,8 @@ "ones": lambda: (torch.ones(1, 16, 50, 32), [4, 2, 0]), "rand": lambda: (torch.rand(1, 16, 52, 16), [4, 3, 0]), "non_divisible": lambda: (torch.rand(1, 16, 112, 112), [3, 2, 1]), + "non_divisible_window_height": lambda: (torch.rand(1, 16, 56, 56), [3, (2, 1), 1]), + "non_divisible_window_width": lambda: (torch.rand(1, 16, 56, 56), [3, (1, 2), 1]), } test_data_suite_mult_batches = {