diff --git a/backends/xnnpack/operators/op_squeeze.py b/backends/xnnpack/operators/op_squeeze.py index e857b6c68bb..8ed5aa36ae6 100644 --- a/backends/xnnpack/operators/op_squeeze.py +++ b/backends/xnnpack/operators/op_squeeze.py @@ -7,7 +7,6 @@ from typing import cast, Dict import torch -from executorch.backends.transforms import get_shape from executorch.backends.xnnpack.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -53,7 +52,21 @@ def define_node( "val" in input_node.meta, "Missing val in tensor metadata for input when serializing XNNStaticReshape node", ) - new_shape = get_shape(input_node)[:-1] + dynamic_shape = node.meta["val"].shape + new_shape = [] + + num_dynamic_dims = 0 + for dim in dynamic_shape: + if isinstance(dim, torch.SymInt): + num_dynamic_dims += 1 + new_shape.append(0) + else: + new_shape.append(dim) + + check_or_raise( + num_dynamic_dims <= 1, + "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ", + ) ser_node = XNode( xnode_union=XNNStaticReshape( @@ -101,7 +114,21 @@ def define_node( "val" in input_node.meta, "Missing val in tensor metadata for input when serializing XNNStaticReshape node", ) - new_shape = get_shape(input_node) + [1] + dynamic_shape = node.meta["val"].shape + new_shape = [] + + num_dynamic_dims = 0 + for dim in dynamic_shape: + if isinstance(dim, torch.SymInt): + num_dynamic_dims += 1 + new_shape.append(0) + else: + new_shape.append(dim) + + check_or_raise( + num_dynamic_dims <= 1, + "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ", + ) ser_node = XNode( xnode_union=XNNStaticReshape( diff --git a/backends/xnnpack/test/models/w2l.py b/backends/xnnpack/test/models/w2l.py index c95fc29d8cc..7f63d0b15f1 100644 --- a/backends/xnnpack/test/models/w2l.py +++ b/backends/xnnpack/test/models/w2l.py @@ -15,13 +15,15 @@ class TestW2L(unittest.TestCase): batch_size = 10 input_frames = 700 vocab_size = 4096 + num_features = 1 wav2letter = models.Wav2Letter(num_classes=vocab_size).eval() - model_inputs = (torch.randn(batch_size, 1, input_frames),) + model_inputs = (torch.randn(batch_size, num_features, input_frames),) + dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},) def test_fp32_w2l(self): ( - Tester(self.wav2letter, self.model_inputs) + Tester(self.wav2letter, self.model_inputs, self.dynamic_shape) .export() .to_edge() .partition() @@ -34,12 +36,12 @@ def test_fp32_w2l(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method_and_compare_outputs() + .run_method_and_compare_outputs(num_runs=10) ) def test_qs8_w2l(self): ( - Tester(self.wav2letter.eval(), self.model_inputs) + Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape) .quantize() .export() .to_edge() @@ -53,5 +55,5 @@ def test_qs8_w2l(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method_and_compare_outputs() + .run_method_and_compare_outputs(num_runs=10) ) diff --git a/backends/xnnpack/test/ops/conv1d.py b/backends/xnnpack/test/ops/conv1d.py index 50f9aa3a996..6558fd673ff 100644 --- a/backends/xnnpack/test/ops/conv1d.py +++ b/backends/xnnpack/test/ops/conv1d.py @@ -81,9 +81,15 @@ def forward(self, x): z = torch.add(y, z) return z - def _test_conv1d(self, module, inputs, conv_count, quantized=False): + def _test_conv1d( + self, module, inputs, conv_count, quantized=False, dynamic_shape=None + ): ( - (Tester(module, inputs).quantize() if quantized else Tester(module, inputs)) + ( + Tester(module, inputs, dynamic_shape).quantize() + if quantized + else Tester(module, inputs) + ) .export() .check_count({"torch.ops.aten.convolution.default": conv_count}) .to_edge() @@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False): ) def test_fp16_conv1d(self): - inputs = (torch.randn(1, 2, 4).to(torch.float16),) - self._test_conv1d(self.Conv1d(dtype=torch.float16), inputs, conv_count=1) + inputs = (torch.randn(2, 2, 4).to(torch.float16),) + dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},) + self._test_conv1d( + self.Conv1d(dtype=torch.float16), + inputs, + conv_count=1, + dynamic_shape=dynamic_shapes, + ) def test_fp32_conv1d(self): - inputs = (torch.randn(1, 2, 4),) - self._test_conv1d(self.Conv1d(), inputs, 1) + inputs = (torch.randn(2, 2, 4),) + dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},) + self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes) def test_fp32_conv1d_batchnorm_seq(self): - inputs = (torch.randn(1, 2, 4),) - self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2) + inputs = (torch.randn(2, 2, 4),) + dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},) + self._test_conv1d( + self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes + ) def test_qs8_conv1d(self): - inputs = (torch.randn(1, 2, 4),) - self._test_conv1d(self.Conv1d(), inputs, 1, quantized=True) + inputs = (torch.randn(2, 2, 4),) + dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},) + self._test_conv1d( + self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes + ) def test_qs8_conv1d_batchnorm_seq(self): - inputs = (torch.randn(1, 2, 4),) - self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2, quantized=True) + inputs = (torch.randn(2, 2, 4),) + dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},) + self._test_conv1d( + self.Conv1dBatchNormSequential(), + inputs, + 2, + quantized=True, + dynamic_shape=dynamic_shapes, + )