Commit b1e3e28

Refactor cadence.convolution
Differential Revision: D86473542 Pull Request resolved: #15762
1 parent 76d43bc commit b1e3e28

5 files changed: +230 −335 lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 108 additions & 31 deletions
@@ -351,10 +351,6 @@ def register_fake(
     "quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
-lib.define(
-    "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
-    "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)"
-)
 lib.define(
     "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
     "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)"
@@ -489,8 +485,28 @@ def register_fake(
 # ------------------------------------ #
 # Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out)
 lib.define(
-    "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, "
-    "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+    "conv1d(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, "
+    "int groups) -> Tensor"
+)
+lib.define(
+    "conv1d.out(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, "
+    "int groups, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "conv2d(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, "
+    "int groups) -> Tensor"
+)
+lib.define(
+    "conv2d.out(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, "
+    "int groups, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "conv3d(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, "
+    "int groups) -> Tensor"
+)
+lib.define(
+    "conv3d.out(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, "
+    "int groups, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
     "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
@@ -2152,41 +2168,102 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
     return src.new_empty(out_size, dtype=src.dtype)
 
 
-@register_fake("cadence::convolution")
-def convolution_meta(
+@register_fake("cadence::conv1d")
+def conv1d_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
     stride: Tuple[int],
     padding: Tuple[int],
     dilation: Tuple[int],
     groups: int,
-    channel_last: bool = False,
 ) -> torch.Tensor:
-    if channel_last:
-        out_channels, *kernel_size, _ = weight.shape
-    else:
-        out_channels, _, *kernel_size = weight.shape
+    assert (
+        len(weight.shape) == 3
+    ), f"Conv1d expects a 3D weight, got {len(weight.shape)}D"
+    out_channels, _, kernel_size = weight.shape
     in_size = input.shape
-    # Assert that the input tensor has at least 3 dimensions, and at most 6
-    assert len(in_size) > 2
-    assert len(in_size) < 6
+    assert len(in_size) == 3, f"conv1d expects 3D input, got {len(in_size)}D"
 
-    # Compute the output tensor size
-    output_size = (
-        get_conv1d_output_size(
-            in_size,
-            out_channels,
-            stride[0],
-            padding[0],
-            dilation[0],
-            kernel_size[0],
-            channel_last,
-        )
-        if len(in_size) == 3
-        else get_conv2d_output_size(
-            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
-        )
+    output_size = get_conv1d_output_size(
+        in_size,
+        out_channels,
+        stride[0],
+        padding[0],
+        dilation[0],
+        kernel_size,
+        False,
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::conv2d")
+def conv2d_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+) -> torch.Tensor:
+    assert (
+        len(weight.shape) == 4
+    ), f"Conv2d expects a 4D weight, got {len(weight.shape)}D"
+    out_channels, _, *kernel_size = weight.shape
+    in_size = input.shape
+    assert len(in_size) == 4, f"conv2d expects 4D input, got {len(in_size)}D"
+
+    output_size = get_conv2d_output_size(
+        in_size, out_channels, stride, padding, dilation, kernel_size, False
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::conv3d")
+def conv3d_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int, int, int],
+    padding: Tuple[int, int, int],
+    dilation: Tuple[int, int, int],
+    groups: int,
+) -> torch.Tensor:
+    assert (
+        len(weight.shape) == 5
+    ), f"Conv3d expects a 5D weight, got {len(weight.shape)}D"
+    out_channels, _, *kernel_size = weight.shape
+    in_size = input.shape
+    assert len(in_size) == 5, f"conv3d expects 5D input, got {len(in_size)}D"
+
+    # Helper to compute 3D convolution output size
+    def get_conv3d_output_size(
+        in_size: torch.Size,
+        out_channels: int,
+        stride: Tuple[int, int, int],
+        padding: Tuple[int, int, int],
+        dilation: Tuple[int, int, int],
+        kernel_size: list[int],
+    ) -> torch.Size:
+        N, C, D, H, W = in_size
+
+        dout = (D + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[
+            0
+        ] + 1
+        hout = (H + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
+            1
+        ] + 1
+        wout = (W + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) // stride[
+            2
+        ] + 1
+
+        return torch.Size((N, out_channels, dout, hout, wout))
+
+    output_size = get_conv3d_output_size(
+        in_size, out_channels, stride, padding, dilation, kernel_size
     )
 
     return input.new_empty(output_size, dtype=input.dtype)
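
Note: the new conv1d/conv2d/conv3d fake kernels above only infer output shapes. A minimal sketch of the shape arithmetic they rely on, for the 2D case (assumption: the standard convolution output-size formula; expected_conv2d_size is a hypothetical helper written for illustration, not part of this diff):

    import torch

    def expected_conv2d_size(in_size, out_channels, stride, padding, dilation, kernel_size):
        # Per spatial dim: out = (in + 2*pad - dilation*(kernel - 1) - 1) // stride + 1
        n, _, h, w = in_size
        hout = (h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
        wout = (w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (n, out_channels, hout, wout)

    # Sanity check against an eager conv2d with the same configuration.
    x = torch.randn(1, 3, 16, 16)
    w = torch.randn(8, 3, 3, 3)
    out = torch.nn.functional.conv2d(x, w, None, stride=(2, 2), padding=(1, 1), dilation=(1, 1))
    assert tuple(out.shape) == expected_conv2d_size(x.shape, 8, (2, 2), (1, 1), (1, 1), (3, 3))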

backends/cadence/aot/ref_implementations.py

Lines changed: 36 additions & 31 deletions
@@ -1334,48 +1334,53 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl_tracked(m, "convolution")
-def convolution(
+@impl_tracked(m, "conv1d")
+def conv1d(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int],
+    padding: tuple[int],
+    dilation: tuple[int],
+    groups: int,
+) -> torch.Tensor:
+    conv_out = torch.nn.functional.conv1d(
+        input_tensor, weight, bias, stride[0], padding[0], dilation[0], groups
+    )
+
+    return conv_out
+
+
+@impl_tracked(m, "conv2d")
+def conv2d(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
     stride: tuple[int, int],
     padding: tuple[int, int],
     dilation: tuple[int, int],
     groups: int,
-    channel_last: bool = False,
 ) -> torch.Tensor:
-    conv_is_1d = len(input_tensor.shape) == 3
-    if channel_last:
-        if conv_is_1d:
-            input_tensor = input_tensor.movedim(-1, 1).contiguous()
-            if len(weight.shape) != 3:
-                raise ValueError("Weight tensor must be 3D if input is 3D")
-            weight = weight.movedim(-1, 1).contiguous()
-        else:
-            input_tensor = input_tensor.movedim(-1, -3)
-            if len(weight.shape) != 4:
-                raise ValueError("Weight tensor must be 4D if input is nd > 3")
-            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
+    conv_out = torch.nn.functional.conv2d(
+        input_tensor, weight, bias, stride, padding, dilation, groups
+    )
 
-    _stride: tuple[int, int] | int = stride
-    _padding: tuple[int, int] | int = padding
-    _dilation: tuple[int, int] | int = dilation
+    return conv_out
 
-    if conv_is_1d:
-        conv = torch.nn.functional.conv1d
-        _stride = stride[0]
-        _padding = padding[0]
-        _dilation = dilation[0]
-    else:
-        conv = torch.nn.functional.conv2d
 
-    conv_out = conv(input_tensor, weight, bias, _stride, _padding, _dilation, groups)
-    if channel_last:
-        if conv_is_1d:
-            conv_out = conv_out.movedim(1, -1).contiguous()
-        else:
-            conv_out = conv_out.movedim(-3, -1).contiguous()
+@impl_tracked(m, "conv3d")
+def conv3d(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int, int],
+    padding: tuple[int, int, int],
+    dilation: tuple[int, int, int],
+    groups: int,
+) -> torch.Tensor:
+    conv_out = torch.nn.functional.conv3d(
+        input_tensor, weight, bias, stride, padding, dilation, groups
+    )
 
     return conv_out
 
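Note: with channel_last handling removed, the reference implementations above are thin wrappers over the corresponding torch.nn.functional ops on channel-first tensors. A small standalone sketch of what the new conv1d reference does in effect (plain torch only; it does not go through the cadence op registry):

    import torch

    x = torch.randn(2, 4, 32)   # (N, C, L): conv1d now requires a 3D, channel-first input
    w = torch.randn(8, 4, 5)    # (out_channels, in_channels, kernel_size)
    b = torch.zeros(8)

    # The registered kernel unwraps the single-element stride/padding/dilation
    # lists and defers to ATen's conv1d.
    stride, padding, dilation, groups = [2], [1], [1], 1
    out = torch.nn.functional.conv1d(x, w, b, stride[0], padding[0], dilation[0], groups)
    print(out.shape)  # torch.Size([2, 8, 15])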
backends/cadence/aot/replace_ops.py

Lines changed: 41 additions & 26 deletions
@@ -452,14 +452,16 @@ class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass):
     def call_operator(self, op, args, kwargs, meta):
         op_packet = get_edge_overload_packet(op)
         if op_packet not in {
-            exir_ops.edge.cadence.convolution,
+            exir_ops.edge.cadence.conv1d,
+            exir_ops.edge.cadence.conv2d,
+            exir_ops.edge.cadence.conv3d,
             exir_ops.edge.cadence.transposed_convolution,
         }:
             return super().call_operator(op, args, kwargs, meta)
 
         is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
-        expected_args = 9 if is_transposed else 8
-        assert len(args) == expected_args
+        num_expected_args = 9 if is_transposed else 7
+        assert len(args) == num_expected_args
         # Check if the bias is already concrete
         if args[2] is not None:
             return super().call_operator(op, args, kwargs, meta)
@@ -684,20 +686,28 @@ def call_operator(self, op, args, kwargs, meta):
             output_padding,
             groups,
         ) = args
-        # Currently we only handle conversion to conv1d and conv2d, therefore
+        # Currently we only handle conversion to conv1d, conv2d, and conv3d, therefore
         # verify that the stride, padding, dilation, and output_padding have
-        # len <=2.
+        # len <=3.
         assert (
-            len(stride) == len(padding) == len(dilation) == len(output_padding) == 1
-        ) or (
-            len(stride) == len(padding) == len(dilation) == len(output_padding) == 2
-        ), "Can only map convolution to conv1d and conv2d at present"
-
-        target = (
-            exir_ops.edge.cadence.transposed_convolution.default
-            if transposed
-            else exir_ops.edge.cadence.convolution.default
-        )
+            (len(stride) == len(padding) == len(dilation) == len(output_padding) == 1)
+            or (
+                len(stride) == len(padding) == len(dilation) == len(output_padding) == 2
+            )
+            or (
+                len(stride) == len(padding) == len(dilation) == len(output_padding) == 3
+            )
+        ), "Can only map convolution to conv1d, conv2d, and conv3d at present"
+
+        # Determine if this is 1D, 2D, or 3D convolution based on parameter lengths
+        if transposed:
+            target = exir_ops.edge.cadence.transposed_convolution.default
+        elif len(stride) == 1:
+            target = exir_ops.edge.cadence.conv1d.default
+        elif len(stride) == 2:
+            target = exir_ops.edge.cadence.conv2d.default
+        else:  # len(stride) == 3
+            target = exir_ops.edge.cadence.conv3d.default
 
         if transposed:
             # Flip the height and width dimensions of weight, since we apply a
@@ -756,7 +766,6 @@ def call_operator(self, op, args, kwargs, meta):
                 padding,
                 dilation,
                 groups,
-                False,
             )
 
             return super().call_operator(target, new_args, kwargs, meta)
@@ -778,7 +787,9 @@ class ReplaceTrivialConvWithLinear(ExportPass):
     """
 
     trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
-        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default,
         exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
         exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
     }
@@ -795,7 +806,7 @@ def call_operator(self, op, args, kwargs, meta):
             op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
             or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         )
-        assert (len(args) == 8 and not quantized_op) or (
+        assert (len(args) == 7 and not quantized_op) or (
            len(args) >= 12 and quantized_op
         ), "Inconsistent args for convolution"
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
@@ -950,7 +961,9 @@ def call_operator(
         meta: NodeMetadata,
     ) -> ProxyValue:
         if op not in {
-            exir_ops.edge.cadence.convolution.default,
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
             exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
         }:
             return super().call_operator(op, args, kwargs, meta)
@@ -961,11 +974,11 @@ def call_operator(
             # Already in NHWC layout.
             return super().call_operator(op, args, kwargs, meta)
 
-        new_op = (
-            exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
-            if quantized_op
-            else exir_ops.edge.cadence.convolution.default
-        )
+        if quantized_op:
+            new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
+        else:
+            # Determine if 1D or 2D convolution based on op
+            new_op = op
 
         input_proxy = cast(ProxyValue, args[0])
         weight_proxy = cast(ProxyValue, args[1])
@@ -1038,7 +1051,9 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass):
     # A map from the convolution op to the linear op that it should
     # decompose to.
     conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
-        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default,
+        exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default,
         exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
         exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
     }
@@ -1052,7 +1067,7 @@ def call_operator(self, op, args, kwargs, meta):
             op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
             or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         )
-        assert (len(args) == 8 and not quantized_op) or (
+        assert (len(args) == 7 and not quantized_op) or (
            len(args) >= 12 and quantized_op
         ), "Inconsistent args for convolution"
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
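
Note: the pass changes above replace the single generic convolution target with a rank-based dispatch. A standalone sketch of that selection rule (illustrative only: pick_cadence_conv_target is a hypothetical name, and the real pass returns edge-dialect op overloads rather than strings):

    def pick_cadence_conv_target(stride, padding, dilation, output_padding, transposed):
        # Transposed convolutions keep their dedicated op; otherwise the
        # spatial rank of the parameters decides between conv1d/conv2d/conv3d.
        assert len(stride) == len(padding) == len(dilation) == len(output_padding)
        assert len(stride) in (1, 2, 3), "only 1D/2D/3D convolutions are supported"
        if transposed:
            return "cadence::transposed_convolution"
        return {1: "cadence::conv1d", 2: "cadence::conv2d", 3: "cadence::conv3d"}[len(stride)]

    print(pick_cadence_conv_target([1, 1], [0, 0], [1, 1], [0, 0], False))  # cadence::conv2d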