Skip to content

Commit 76c78a9

Browse files
sergio-grovetySergey SmirnovEkaterina BernMikhail TrubnikovGermanTretiakov
authored
[Topi] Cortex-M DSP support (#9233)
Co-authored-by: Sergey Smirnov <[email protected]> Co-authored-by: Ekaterina Bern <[email protected]> Co-authored-by: Mikhail Trubnikov <[email protected]> Co-authored-by: German Tretiakov <[email protected]> Co-authored-by: Ilya Gozman <[email protected]> Co-authored-by: Alexey.Yazev <[email protected]> Co-authored-by: Ilya Gozman <[email protected]>
1 parent 13f54e0 commit 76c78a9

File tree

26 files changed

+1387
-275
lines changed

26 files changed

+1387
-275
lines changed

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import re
2020
import logging
2121

22-
from tvm import topi
22+
from tvm import relay, topi
2323
from ....target import arm_isa
2424
from ....topi.generic import conv2d as conv2d_generic
2525
from .generic import *
@@ -49,6 +49,25 @@ def schedule_concatenate_arm_cpu(_, outs, target):
4949
return topi.arm_cpu.schedule_concatenate(outs)
5050

5151

52+
@schedule_pool.register(["arm_cpu"])
53+
def schedule_pool_arm_cpu(attrs, outs, target):
54+
"""schedule pooling ops arm cpu"""
55+
layout = attrs.layout
56+
isa = arm_isa.IsaAnalyzer(target)
57+
avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
58+
with target:
59+
if (
60+
avg_pool
61+
and isa.has_dsp_support
62+
and layout in ("NCW", "NCHW")
63+
or not avg_pool
64+
and isa.has_dsp_support
65+
and layout in ("NWC", "NHWC")
66+
):
67+
return topi.arm_cpu.schedule_pool(outs, layout)
68+
return topi.generic.schedule_pool(outs, layout)
69+
70+
5271
@conv2d_strategy.register(["arm_cpu", "micro_dev"])
5372
def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
5473
"""conv2d arm cpu strategy"""
@@ -128,11 +147,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
128147
name="conv2d_hwcn.generic",
129148
)
130149
elif layout == "NHWC":
131-
if "SMLAD" in isa and kernel_layout == "HWOI":
150+
if isa.has_dsp_support and kernel_layout == "HWOI":
132151
strategy.add_implementation(
133-
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_direct_simd),
134-
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_direct_simd),
135-
name="conv2d_nhwc_direct_simd.micro_dev",
152+
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp),
153+
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp),
154+
name="conv2d_nhwc_dsp.micro_dev",
136155
)
137156
elif kernel_layout == "HWIO":
138157
is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
@@ -415,3 +434,67 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
415434
name="bitserial_dense.arm_cpu",
416435
)
417436
return strategy
437+
438+
439+
@dense_strategy.register(["arm_cpu"])
440+
def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
441+
"""dense arm cpu strategy"""
442+
strategy = _op.OpStrategy()
443+
isa = arm_isa.IsaAnalyzer(target)
444+
if isa.has_dsp_support:
445+
strategy.add_implementation(
446+
wrap_compute_dense(topi.nn.dense),
447+
wrap_topi_schedule(topi.arm_cpu.schedule_dense_dsp),
448+
name="dense_dsp",
449+
)
450+
else:
451+
strategy.add_implementation(
452+
wrap_compute_dense(topi.nn.dense),
453+
wrap_topi_schedule(topi.generic.schedule_dense),
454+
name="dense.generic",
455+
)
456+
return strategy
457+
458+
459+
@conv1d_strategy.register("arm_cpu")
460+
def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
461+
"""conv1d strategy"""
462+
strategy = _op.OpStrategy()
463+
layout = attrs.data_layout
464+
kernel_layout = attrs.kernel_layout
465+
dilation = get_const_tuple(attrs.dilation)
466+
if dilation[0] < 1:
467+
raise ValueError("dilation should be a positive value")
468+
469+
isa = arm_isa.IsaAnalyzer(target)
470+
471+
if kernel_layout == "WOI":
472+
if layout == "NWC" and isa.has_dsp_support:
473+
strategy.add_implementation(
474+
wrap_compute_conv1d(topi.arm_cpu.conv1d_nwc_dsp),
475+
wrap_topi_schedule(topi.arm_cpu.schedule_conv1d_nwc_dsp),
476+
name="conv1d_dsp",
477+
)
478+
else:
479+
raise RuntimeError(
480+
"Unsupported kernel layout {} for conv1d {} for arm cpu.".format(
481+
kernel_layout, layout
482+
)
483+
)
484+
elif layout == "NCW":
485+
strategy.add_implementation(
486+
wrap_compute_conv1d(topi.nn.conv1d_ncw),
487+
wrap_topi_schedule(topi.generic.schedule_conv1d_ncw),
488+
name="conv1d_ncw.generic",
489+
)
490+
elif layout == "NWC":
491+
strategy.add_implementation(
492+
wrap_compute_conv1d(topi.nn.conv1d_nwc),
493+
wrap_topi_schedule(topi.generic.schedule_conv1d_nwc),
494+
name="conv1d_nwc.generic",
495+
)
496+
else:
497+
raise RuntimeError(
498+
"Unsupported kernel layout {} for conv1d {} for arm cpu.".format(kernel_layout, layout)
499+
)
500+
return strategy

python/tvm/target/arm_isa.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,24 @@
1616
# under the License.
1717
"""Defines functions to analyze available opcodes in the ARM ISA."""
1818

19+
import tvm.target
1920

20-
ARM_ISA_MAP = {
21-
"armv7e-m": ["SMLAD"],
22-
}
21+
22+
ARM_MPROFILE_DSP_SUPPORT_LIST = [
23+
"cortex-m7",
24+
"cortex-m4",
25+
"cortex-m33",
26+
"cortex-m35p",
27+
"cortex-m55",
28+
]
2329

2430

2531
class IsaAnalyzer(object):
32+
"""Checks ISA support for given target"""
33+
2634
def __init__(self, target):
27-
self.target = target
28-
# TODO: actually parse -mcpu
29-
arch = "armv7e-m"
30-
self._isa_map = ARM_ISA_MAP[arch]
35+
self.target = tvm.target.Target(target)
3136

32-
def __contains__(self, instruction):
33-
return instruction in self._isa_map
37+
@property
38+
def has_dsp_support(self):
39+
return self.target.mcpu is not None and self.target.mcpu in ARM_MPROFILE_DSP_SUPPORT_LIST

python/tvm/testing/plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"llvm": "mark a test as requiring llvm",
5050
"ethosn": "mark a test as requiring ethosn",
5151
"hexagon": "mark a test as requiring hexagon",
52+
"corstone300": "mark a test as requiring Corstone300 FVP",
5253
}
5354

5455

python/tvm/testing/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,18 @@ def requires_opencl(*args):
674674
return _compose(args, _requires_opencl)
675675

676676

677+
def requires_corstone300(*args):
678+
"""Mark a test as requiring the corstone300 FVP
679+
680+
Parameters
681+
----------
682+
f : function
683+
Function to mark
684+
"""
685+
_requires_corstone300 = [pytest.mark.corstone300]
686+
return _compose(args, _requires_corstone300)
687+
688+
677689
def requires_rocm(*args):
678690
"""Mark a test as requiring the rocm runtime.
679691

python/tvm/topi/arm_cpu/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=wildcard-import
1818
"""Schedule for ARM CPU"""
1919

20+
from .conv1d import *
2021
from .conv2d import *
2122
from .depthwise_conv2d import *
2223
from .conv2d_transpose import *
@@ -25,5 +26,6 @@
2526
from .bitserial_conv2d import *
2627
from .bitserial_dense import *
2728
from .injective import *
28-
from . import cortex_m7
2929
from .group_conv2d import *
30+
from .pooling import *
31+
from .dense import *

python/tvm/topi/arm_cpu/conv1d.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
18+
"""Conv1D schedule for ARM CPU"""
19+
from __future__ import absolute_import as _abs
20+
21+
from tvm import autotvm
22+
23+
from .mprofile.dsp.conv1d import (
24+
conv1d_nwc_dsp_compute,
25+
conv1d_nwc_dsp_schedule,
26+
)
27+
28+
29+
@autotvm.register_topi_compute("conv1d_nwc_dsp.arm_cpu")
30+
def conv1d_nwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
31+
"""Compute conv1d with v7e-m DSP instructions."""
32+
return conv1d_nwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype)
33+
34+
35+
@autotvm.register_topi_schedule("conv1d_nwc_dsp.arm_cpu")
36+
def schedule_conv1d_nwc_dsp(cfg, outs):
37+
return conv1d_nwc_dsp_schedule(cfg, outs)

python/tvm/topi/arm_cpu/conv2d.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
schedule_conv2d_spatial_pack_nchw,
3434
schedule_conv2d_spatial_pack_nhwc,
3535
)
36-
from .cortex_m7.conv2d import direct_simd
36+
from .mprofile.dsp.conv2d import (
37+
conv2d_nhwc_dsp_compute,
38+
conv2d_nhwc_dsp_schedule,
39+
)
3740

3841

3942
@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
@@ -505,15 +508,13 @@ def _callback(op):
505508
return s
506509

507510

508-
@autotvm.register_topi_compute("conv2d_nhwc_direct_simd.arm_cpu")
509-
def conv2d_nhwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):
510-
"""Compute conv2d_nhwc with SIMD (v7e-m)."""
511-
return direct_simd.conv2d_nhwc_direct_simd_compute(
512-
cfg, data, kernel, strides, padding, dilation, out_dtype
513-
)
511+
@autotvm.register_topi_compute("conv2d_nhwc_dsp.arm_cpu")
512+
def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
513+
"""Compute conv2d_nhwc with v7e-m DSP instructions."""
514+
return conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype)
514515

515516

516-
@autotvm.register_topi_schedule("conv2d_nhwc_direct_simd.arm_cpu")
517-
def schedule_conv2d_nhwc_direct_simd(cfg, outs):
518-
"""Create schedule for conv2d_nhwc_direct_simd"""
519-
return direct_simd.conv2d_nhwc_direct_simd_schedule(cfg, outs)
517+
@autotvm.register_topi_schedule("conv2d_nhwc_dsp.arm_cpu")
518+
def schedule_conv2d_nhwc_dsp(cfg, outs):
519+
"""Create schedule for conv2d_nhwc_dsp"""
520+
return conv2d_nhwc_dsp_schedule(cfg, outs)

0 commit comments

Comments
 (0)