|
19 | 19 | import re |
20 | 20 | import logging |
21 | 21 |
|
22 | | -from tvm import topi |
| 22 | +from tvm import relay, topi |
23 | 23 | from ....target import arm_isa |
24 | 24 | from ....topi.generic import conv2d as conv2d_generic |
25 | 25 | from .generic import * |
@@ -49,6 +49,25 @@ def schedule_concatenate_arm_cpu(_, outs, target): |
49 | 49 | return topi.arm_cpu.schedule_concatenate(outs) |
50 | 50 |
|
51 | 51 |
|
| 52 | +@schedule_pool.register(["arm_cpu"]) |
| 53 | +def schedule_pool_arm_cpu(attrs, outs, target): |
| 54 | + """schedule pooling ops arm cpu""" |
| 55 | + layout = attrs.layout |
| 56 | + isa = arm_isa.IsaAnalyzer(target) |
| 57 | + avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs) |
| 58 | + with target: |
| 59 | + if ( |
| 60 | + avg_pool |
| 61 | + and isa.has_dsp_support |
| 62 | + and layout in ("NCW", "NCHW") |
| 63 | + or not avg_pool |
| 64 | + and isa.has_dsp_support |
| 65 | + and layout in ("NWC", "NHWC") |
| 66 | + ): |
| 67 | + return topi.arm_cpu.schedule_pool(outs, layout) |
| 68 | + return topi.generic.schedule_pool(outs, layout) |
| 69 | + |
| 70 | + |
52 | 71 | @conv2d_strategy.register(["arm_cpu", "micro_dev"]) |
53 | 72 | def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): |
54 | 73 | """conv2d arm cpu strategy""" |
@@ -128,11 +147,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): |
128 | 147 | name="conv2d_hwcn.generic", |
129 | 148 | ) |
130 | 149 | elif layout == "NHWC": |
131 | | - if "SMLAD" in isa and kernel_layout == "HWOI": |
| 150 | + if isa.has_dsp_support and kernel_layout == "HWOI": |
132 | 151 | strategy.add_implementation( |
133 | | - wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_direct_simd), |
134 | | - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_direct_simd), |
135 | | - name="conv2d_nhwc_direct_simd.micro_dev", |
| 152 | + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp), |
| 153 | + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp), |
| 154 | + name="conv2d_nhwc_dsp.micro_dev", |
136 | 155 | ) |
137 | 156 | elif kernel_layout == "HWIO": |
138 | 157 | is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() |
@@ -415,3 +434,67 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target): |
415 | 434 | name="bitserial_dense.arm_cpu", |
416 | 435 | ) |
417 | 436 | return strategy |
| 437 | + |
| 438 | + |
| 439 | +@dense_strategy.register(["arm_cpu"]) |
| 440 | +def schedule_dense_arm_cpu(attrs, inputs, out_type, target): |
| 441 | + """dense arm cpu strategy""" |
| 442 | + strategy = _op.OpStrategy() |
| 443 | + isa = arm_isa.IsaAnalyzer(target) |
| 444 | + if isa.has_dsp_support: |
| 445 | + strategy.add_implementation( |
| 446 | + wrap_compute_dense(topi.nn.dense), |
| 447 | + wrap_topi_schedule(topi.arm_cpu.schedule_dense_dsp), |
| 448 | + name="dense_dsp", |
| 449 | + ) |
| 450 | + else: |
| 451 | + strategy.add_implementation( |
| 452 | + wrap_compute_dense(topi.nn.dense), |
| 453 | + wrap_topi_schedule(topi.generic.schedule_dense), |
| 454 | + name="dense.generic", |
| 455 | + ) |
| 456 | + return strategy |
| 457 | + |
| 458 | + |
| 459 | +@conv1d_strategy.register("arm_cpu") |
| 460 | +def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target): |
| 461 | + """conv1d strategy""" |
| 462 | + strategy = _op.OpStrategy() |
| 463 | + layout = attrs.data_layout |
| 464 | + kernel_layout = attrs.kernel_layout |
| 465 | + dilation = get_const_tuple(attrs.dilation) |
| 466 | + if dilation[0] < 1: |
| 467 | + raise ValueError("dilation should be a positive value") |
| 468 | + |
| 469 | + isa = arm_isa.IsaAnalyzer(target) |
| 470 | + |
| 471 | + if kernel_layout == "WOI": |
| 472 | + if layout == "NWC" and isa.has_dsp_support: |
| 473 | + strategy.add_implementation( |
| 474 | + wrap_compute_conv1d(topi.arm_cpu.conv1d_nwc_dsp), |
| 475 | + wrap_topi_schedule(topi.arm_cpu.schedule_conv1d_nwc_dsp), |
| 476 | + name="conv1d_dsp", |
| 477 | + ) |
| 478 | + else: |
| 479 | + raise RuntimeError( |
| 480 | + "Unsupported kernel layout {} for conv1d {} for arm cpu.".format( |
| 481 | + kernel_layout, layout |
| 482 | + ) |
| 483 | + ) |
| 484 | + elif layout == "NCW": |
| 485 | + strategy.add_implementation( |
| 486 | + wrap_compute_conv1d(topi.nn.conv1d_ncw), |
| 487 | + wrap_topi_schedule(topi.generic.schedule_conv1d_ncw), |
| 488 | + name="conv1d_ncw.generic", |
| 489 | + ) |
| 490 | + elif layout == "NWC": |
| 491 | + strategy.add_implementation( |
| 492 | + wrap_compute_conv1d(topi.nn.conv1d_nwc), |
| 493 | + wrap_topi_schedule(topi.generic.schedule_conv1d_nwc), |
| 494 | + name="conv1d_nwc.generic", |
| 495 | + ) |
| 496 | + else: |
| 497 | + raise RuntimeError( |
| 498 | + "Unsupported kernel layout {} for conv1d {} for arm cpu.".format(kernel_layout, layout) |
| 499 | + ) |
| 500 | + return strategy |
0 commit comments