Skip to content

Commit 701a504

Browse files
3l1facebook-github-bot
authored andcommitted
Enable int16 for op permute (#15256)
Summary: Enable int16 for op permute Reviewed By: Ninja91 Differential Revision: D84948536
1 parent 296e07f commit 701a504

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

backends/arm/operators/op_permute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def define_node(
117117
validate_valid_dtype(
118118
self.target,
119119
[inputs[0], output],
120-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
120+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
121121
output.tosa_spec,
122122
)
123123

backends/arm/test/ops/test_permute.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
TosaPipelineINT,
2020
VgfPipeline,
2121
)
22-
from torchvision.ops import Permute
2322

2423
input_t1 = Tuple[torch.Tensor] # Input x
2524

@@ -42,10 +41,10 @@ class SimplePermute(torch.nn.Module):
4241
def __init__(self, dims: list[int]):
4342
super().__init__()
4443

45-
self.permute = Permute(dims=dims)
44+
self.dims = dims
4645

4746
def forward(self, x):
48-
return self.permute(x)
47+
return torch.permute(x, self.dims)
4948

5049

5150
@common.parametrize("test_data", test_data_suite)

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_linear.py",
2222
"ops/test_mul.py",
23+
"ops/test_permute.py",
2324
"ops/test_slice.py",
2425
"ops/test_sigmoid.py",
2526
"ops/test_sub.py",

0 commit comments

Comments
 (0)