-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Topi] Cortex-M DSP support #9233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 158 commits
ee5c4bf
6800b8d
30169d0
a34ddcb
5560645
3fe5aa8
ae4d9e3
5049d86
0e24218
407b8ff
f853977
ae4c851
0493721
a892bca
d1a33e0
ecf7bb2
0ceb612
5f257ca
4e8f56e
34b1b6f
a81bf16
32364a7
820969c
e6ac0d7
dce7874
4327ce2
176d2c4
c19d943
1860689
0f01b71
6dc7e11
c1c9f97
4f8ab09
25156b0
0f68a39
a73df72
f4a4d82
aeef970
47a0711
02ff764
1ebe1d8
9196909
690dd17
aea91dd
5f8f587
b4b01d4
fcbe0a3
3212f44
e7da13e
5351360
83e254d
1b03c68
c97fcf3
facd7df
49976a2
f1ebae9
2abea65
f044427
500ef88
4617429
b094dbf
795c1e2
e911e5c
8a05a55
86f4906
be1aa86
414b431
ed695fd
fc5556a
4023481
3dd6cb8
2f99b9e
a26c811
dad1cbd
5fcec5e
2ef763f
a1794c3
750f0e3
1fb5e14
b28719e
e442e11
9e186ab
38c05e9
d69b725
5531f6b
4181a73
327179f
caa3f66
c12fb05
b024a1a
6908a95
7b50e1d
5755f17
2e71d7a
e333d0e
daf4743
4d73775
dd0c0df
0854f40
080a7a9
748d1a6
c198db9
74d693f
f4f410f
1640f0f
68cfdb3
66bcec9
c2698c8
5d1e0d2
4cfcad0
8307915
7620afc
f29b131
647e3ff
64efb87
a7adf41
37d3c81
a6c4e07
5963b4b
7a6217c
eafab3e
0482186
32ede71
4398d5f
fd657b9
8f036c3
2811f99
0b1c836
c18164a
950c5ea
ab6f111
a52e09e
1190245
f015b88
5aaba80
3e80c0c
73569f2
38f6c5a
66a4587
3c4e2e2
766db57
71d8ff6
b7dc932
977559b
6986fe1
cab48e7
7acbb63
a52a48a
771ae0d
feacab6
9270261
6a0c573
7006824
0f4e1ea
cba97aa
a99440b
c082335
94c8f3a
3e54254
cc42c4d
66dbd5d
638846a
4de1c16
3c81bea
388b62a
53c847c
3d577ca
9c403fe
fb17329
4803e19
975aff4
6cec82c
a593538
bc2b4ec
be7078a
9d25cf6
1407614
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target): | |
| return topi.arm_cpu.schedule_concatenate(outs) | ||
|
|
||
|
|
||
@schedule_pool.register(["arm_cpu", "micro_dev"])
def schedule_pool_arm_cpu(attrs, outs, target):
    """Select the pooling schedule for arm_cpu / micro_dev targets.

    Pooling is treated as average pooling when ``attrs`` exposes a
    ``count_include_pad`` attribute. The ``topi.arm_cpu`` schedule is chosen when

    * average pooling, layout is ``"NCW"``/``"NCHW"`` and the ISA reports ``"SMLAD"``, or
    * any other pooling, the ISA reports ``"SSUB8"`` and ``"SEL"`` and the layout is
      ``"NWC"``/``"NHWC"``.

    Every other combination gets the generic schedule.

    Parameters
    ----------
    attrs : object
        Pooling attributes; ``attrs.layout`` is read.
    outs : object
        Forwarded unchanged to the selected schedule function.
    target : object
        Entered as a context manager and handed to ``arm_isa.IsaAnalyzer``.

    Returns
    -------
    The value returned by the selected ``schedule_pool`` function.
    """
    layout = attrs.layout
    isa = arm_isa.IsaAnalyzer(target)
    avg_pool = hasattr(attrs, "count_include_pad")

    with target:
        # The checks inside each branch keep the evaluation order of the
        # single combined condition they replace.
        if avg_pool:
            use_arm_schedule = layout in ("NCW", "NCHW") and "SMLAD" in isa
        else:
            use_arm_schedule = "SSUB8" in isa and "SEL" in isa and layout in ("NWC", "NHWC")

        if use_arm_schedule:
            return topi.arm_cpu.schedule_pool(outs, layout)
        return topi.generic.schedule_pool(outs, layout)
|
|
||
|
|
||
| @conv2d_strategy.register(["arm_cpu", "micro_dev"]) | ||
| def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): | ||
| """conv2d arm cpu strategy""" | ||
|
|
@@ -415,3 +435,67 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target): | |
| name="bitserial_dense.arm_cpu", | ||
| ) | ||
| return strategy | ||
|
|
||
|
|
||
@dense_strategy.register(["arm_cpu", "micro_dev"])
def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
    """Build the dense op strategy for arm_cpu / micro_dev targets.

    ``topi.nn.dense`` is always the compute. The direct SIMD schedule is
    registered when the ISA analyzer reports ``"SMLAD"`` for ``target``;
    otherwise the generic dense schedule is registered.

    Returns
    -------
    The ``_op.OpStrategy`` holding the single registered implementation.
    """
    strategy = _op.OpStrategy()

    if "SMLAD" in arm_isa.IsaAnalyzer(target):
        schedule = topi.arm_cpu.schedule_dense_direct_simd
        impl_name = "dense_direct_simd.micro_dev"
    else:
        schedule = topi.generic.schedule_dense
        impl_name = "dense.generic"

    strategy.add_implementation(
        wrap_compute_dense(topi.nn.dense),
        wrap_topi_schedule(schedule),
        name=impl_name,
    )
    return strategy
|
|
||
|
|
||
@conv1d_strategy.register("arm_cpu")
def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
    """Build the conv1d op strategy for arm_cpu targets.

    A ``"WOI"`` kernel layout is served only by the direct SIMD implementation,
    which also requires ``"NWC"`` data and ``"SMLAD"`` reported by the ISA
    analyzer. Any other kernel layout uses the generic ``"NCW"`` or ``"NWC"``
    implementation, chosen by data layout alone.

    Raises
    ------
    ValueError
        If the first dilation value is below 1.
    RuntimeError
        If no implementation matches the kernel/data layout combination.
    """
    strategy = _op.OpStrategy()
    data_layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    if get_const_tuple(attrs.dilation)[0] < 1:
        raise ValueError("dilation should be a positive value")

    def unsupported_layout_error():
        # Both rejection paths report the same message.
        return RuntimeError(
            "Unsupported kernel layout {} for conv1d {} for arm cpu.".format(
                kernel_layout, data_layout
            )
        )

    isa = arm_isa.IsaAnalyzer(target)

    if kernel_layout == "WOI":
        if not (data_layout == "NWC" and "SMLAD" in isa):
            raise unsupported_layout_error()
        compute = topi.arm_cpu.conv1d_direct_simd
        schedule = topi.arm_cpu.schedule_conv1d_direct_simd
        impl_name = "conv1d_direct_simd.micro_dev"
    elif data_layout == "NCW":
        compute = topi.nn.conv1d_ncw
        schedule = topi.generic.schedule_conv1d_ncw
        impl_name = "conv1d_ncw.generic"
    elif data_layout == "NWC":
        compute = topi.nn.conv1d_nwc
        schedule = topi.generic.schedule_conv1d_nwc
        impl_name = "conv1d_nwc.generic"
    else:
        raise unsupported_layout_error()

    strategy.add_implementation(
        wrap_compute_conv1d(compute),
        wrap_topi_schedule(schedule),
        name=impl_name,
    )
    return strategy
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -374,6 +374,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): | |
| attrs["kernel_layout"], | ||
| attrs["groups"], | ||
| ) | ||
|
|
||
| # Use int8 for Cortex-M7 | ||
|
||
| use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC" | ||
| if use_int8_on_arm or is_fast_int8_on_arm(): | ||
| return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,18 +16,24 @@ | |||||||||
| # under the License. | ||||||||||
| """Defines functions to analyze available opcodes in the ARM ISA.""" | ||||||||||
|
|
||||||||||
| import argparse | ||||||||||
|
|
||||||||||
# Maps an ARM architecture name to the opcodes listed as available for it.
# Each architecture appears exactly once: a repeated key in a dict literal is
# silently overwritten by its last occurrence, which would hide the stale entry.
ARM_ISA_MAP = {
    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
}
|
|
||||||||||
|
|
||||||||||
| class IsaAnalyzer(object): | ||||||||||
| """Checks ISA support for given target""" | ||||||||||
|
|
||||||||||
| def __init__(self, target): | ||||||||||
| self.target = target | ||||||||||
| # TODO: actually parse -mcpu | ||||||||||
| arch = "armv7e-m" | ||||||||||
| self._isa_map = ARM_ISA_MAP[arch] | ||||||||||
| parser = argparse.ArgumentParser() | ||||||||||
|
||||||||||
| parser = argparse.ArgumentParser() | |
| target = tvm.target.Target(target) | |
| march = target.attrs.get("-march", None) | |
| self._isa_map = ARM_ISA_MAP[march] if march is not None else [] |
(Lines 33-36, which follow, also need to be deleted; the suggestion didn't quite capture the whole diff.)
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -674,6 +674,18 @@ def requires_opencl(*args): | |||||||||||||||
| return _compose(args, _requires_opencl) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def requires_corstone300(*args): | ||||||||||||||||
| """Mark a test as requiring the corstone300 FVP | ||||||||||||||||
|
|
||||||||||||||||
| Parameters | ||||||||||||||||
| ---------- | ||||||||||||||||
| f : function | ||||||||||||||||
| Function to mark | ||||||||||||||||
| """ | ||||||||||||||||
| _requires_corstone300 = [pytest.mark.corstone300] | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I think it also needs a `skipif()` in here; the mark is on by default, if I understand correctly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I'm adding a `pytest.mark.skip()` marker, in tests/python/conftest.py, to the tests marked with `requires_corstone300`, depending on the value of the `--enable-corstone300-tests` flag. I haven't found a better way to control the tests' behavior, since we don't have `device_enabled()` for corstone300.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I think we need a better way of controlling this - possibly something @Mousius could comment on here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I discussed this with @grant-arm a bit, and the consensus seemed to be that there isn't a good way to auto-detect the FVP. Perhaps we've missed something, though?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Depending on how these tests are run, we could use the slightly icky AOT skip logic: tvm/tests/python/relay/aot/aot_test_utils.py, lines 196 to 201 in f4dae23.
This would at least automate it if these tests are designed to run in CPU containers. Otherwise, we should just be able to check for the path since we know exactly where we're checking it out in the container:
|
||||||||||||||||
| return _compose(args, _requires_corstone300) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def requires_rocm(*args): | ||||||||||||||||
| """Mark a test as requiring the rocm runtime. | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,3 +18,4 @@ | |
|
|
||
|
|
||
| from . import conv2d | ||
| from . import dense | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Conv1d implementations for cortex-m7.""" | ||
|
||
|
|
||
| from . import direct_simd | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do we mean by `micro_dev` here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was wrongly copied from a similar schedule declaration; fixed in fb17329.