
Commit 6aa6dc0

ethansfng authored and facebook-github-bot committed

Add support for conv1d

Summary: Add explicit support for 1D conv with dtype-selective builds.

Differential Revision: D82160616

1 parent f294074 commit 6aa6dc0

12 files changed: +1133 additions, -4 deletions
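
The variants added across these files follow the backend's type-dispatch naming scheme: a layout tag (nchw or nhwc), the new 1d marker, and a dtype triple such as asym8sxsym8s_asym8s (presumably asymmetric int8 activations x symmetric int8 weights producing asymmetric int8 output). A minimal sketch of how such a name composes, using a hypothetical helper that is not part of the commit:

    def conv1d_variant_name(layout: str, dtype_triple: str) -> str:
        # Mirrors the pattern visible in the diff: quantized_conv_<layout>_1d_<dtypes>
        return f"quantized_conv_{layout}_1d_{dtype_triple}"

    assert (
        conv1d_variant_name("nchw", "asym8sxsym8s_asym8s")
        == "quantized_conv_nchw_1d_asym8sxsym8s_asym8s"
    )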

backends/cadence/aot/functions.yaml

Lines changed: 20 additions & 0 deletions

@@ -359,6 +359,26 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_1d_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_1d_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_1d_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_1d_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
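
For orientation, a sketch of how one of the out-variants registered above might be invoked once the cadence op library is loaded. The shapes and quantization parameters are illustrative (chosen to match the unit tests later in this commit), not prescribed by the diff:

    import torch

    # (N, C, L) input for the NCHW-layout 1D variant; kernel size 3, stride 1, no padding.
    x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)
    w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
    b = torch.randint(-1000, 1000, (16,), dtype=torch.int32)
    out = torch.empty((1, 16, 6), dtype=torch.int8)  # L_out = (8 - 3) // 1 + 1 = 6

    # Arguments mirror the schema: stride, padding, dilation, groups, zero points,
    # scales, out_multiplier, out_shift, then the out tensor as a keyword argument.
    torch.ops.cadence.quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor_out(
        x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1, out=out
    )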

backends/cadence/aot/functions_hifi.yaml

Lines changed: 20 additions & 0 deletions

@@ -370,6 +370,26 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_1d_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_1d_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_1d_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_1d_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 168 additions & 0 deletions

@@ -169,6 +169,30 @@
 lib.define(
     "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
 )

@@ -2149,6 +2173,150 @@ def roi_align_box_processor_meta(
     return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)
 
 
+@register_fake("cadence::quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nchw_1d_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
+    out_channels, _, kernel_size = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        False,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nchw_1d_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
+    out_channels, _, kernel_size = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        False,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nhwc_1d_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
+    out_channels, kernel_size, _ = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        True,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nhwc_1d_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    assert input.dim() == 3 and weight.dim() == 3
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
+    out_channels, kernel_size, _ = weight.shape
+    output_size = get_conv1d_output_size(
+        input.shape,
+        out_channels,
+        stride[1],
+        padding[1],
+        dilation[1],
+        kernel_size,
+        True,
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::_softmax_f32_f32")
 def softmax_f32_f32_meta(
     self: torch.Tensor,
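
Each meta function above defers shape inference to get_conv1d_output_size, passing the second element of the stride/padding/dilation pairs plus a channel-last flag. That helper's definition is not part of this diff; a minimal sketch of the standard conv1d output-length arithmetic it presumably implements:

    from typing import List, Sequence

    def conv1d_output_size_sketch(
        in_shape: Sequence[int],  # (N, L, C) if channel_last else (N, C, L)
        out_channels: int,
        stride: int,
        padding: int,
        dilation: int,
        kernel_size: int,
        channel_last: bool,
    ) -> List[int]:
        n = in_shape[0]
        length = in_shape[1] if channel_last else in_shape[2]
        # Standard convolution arithmetic: floor((L + 2p - d*(k - 1) - 1) / s) + 1
        out_len = (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        return [n, out_len, out_channels] if channel_last else [n, out_channels, out_len]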

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 104 additions & 0 deletions

@@ -446,6 +446,110 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
             1,
         )
 
+    def test_int8_dispatch_quantized_conv_nchw_1d(self) -> None:
+        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nchw"""
+        x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nchw_1d_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nchw_1d(self) -> None:
+        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nchw"""
+        x = torch.randint(0, 255, (1, 3, 8), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nchw_1d_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
+
+    def test_int8_dispatch_quantized_conv_nhwc_1d(self) -> None:
+        """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nhwc"""
+        x = torch.randint(-128, 127, (1, 8, 3), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nhwc_1d_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nhwc_1d(self) -> None:
+        """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nhwc"""
+        x = torch.randint(0, 255, (1, 8, 3), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with 1D uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nhwc_1d_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
+
     def test_int8_dispatch_quantized_add(self) -> None:
         """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
         x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
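
Assuming a standard pytest setup for the repository, the four new dispatch tests can presumably be run in isolation with:

    python -m pytest backends/cadence/aot/tests/test_type_dispatch_passes.py -k "conv_nchw_1d or conv_nhwc_1d"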

backends/cadence/aot/type_dispatch.py

Lines changed: 5 additions & 4 deletions

@@ -140,12 +140,13 @@ def call_operator(
             else args[0].to_tensor().shape[-1]
         )
         is_depthwise = groups == input_channels
-
-        dilation = args[5]
         # pyre-ignore[16]: None has no attribute '__iter__'.
-        is_dilated = any(d > 1 for d in dilation)
+        is_dilated = any(d > 1 for d in args[5])
+        is_1d = len(args[0].to_tensor().shape) == 3
 
-        if is_dilated:
+        if is_1d:
+            type_suffix = f"1d_{type_suffix}"
+        elif is_dilated:
             type_suffix = f"dilated_{type_suffix}"
         elif is_depthwise:
             type_suffix = f"depthwise_{type_suffix}"
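
The change gives the 1D check precedence over the dilated and depthwise checks when building the kernel suffix. A standalone paraphrase of the resulting selection logic (simplified names; the real code reads these values out of the op's args tuple):

    from typing import Sequence

    def select_type_suffix(
        input_rank: int,
        groups: int,
        input_channels: int,
        dilation: Sequence[int],
        type_suffix: str,  # e.g. "asym8sxsym8s_asym8s"
    ) -> str:
        # A rank-3 input (N, C, L) identifies a 1D convolution.
        if input_rank == 3:
            return f"1d_{type_suffix}"
        if any(d > 1 for d in dilation):
            return f"dilated_{type_suffix}"
        if groups == input_channels:
            return f"depthwise_{type_suffix}"
        return type_suffix

    # e.g. select_type_suffix(3, 1, 3, [1, 1], "asym8sxsym8s_asym8s")
    #   -> "1d_asym8sxsym8s_asym8s"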
