# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm import tir
from tvm.script import tir as T


def _create_context(mod, target) -> ms.TuneContext:
    """Build a TuneContext whose only postprocessor is VerifyVTCMLimit."""
    space_gen = ms.space_generator.PostOrderApply(
        sch_rules=[],
        postprocs=[ms.postproc.VerifyVTCMLimit()],
        mutator_probs={},
    )
    return ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=space_gen,
        task_name="test",
    )
| 36 | + |
| 37 | + |
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
# fmt: off


| 41 | + |
# Hand-scheduled conv2d (NCHWc layout, uint8 x uint8 -> int32) TIR module.
# Both inputs are staged into "global.vtcm" scoped buffers, giving the
# VerifyVTCMLimit postprocessor a concrete VTCM footprint to check against
# the target's vtcm_capacity.
@tvm.script.ir_module
class Conv2dNCHWcVTCM:
    @T.prim_func
    def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(54), T.int64(54), T.int64(32)), "int32"]):
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # VTCM-resident copies of the two inputs; their combined allocation is
        # what the postprocessor measures.
        p0_global_vtcm = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)], dtype="uint8", scope="global.vtcm")
        p1_global_vtcm = T.alloc_buffer([T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)], dtype="uint8", scope="global.vtcm")
        for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
            for oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(1)):
                # Zero-initialize the output tile before the reduction loop.
                for oc_chunk_1_init, oh_1_init, ow_1_init, oc_chunk_2_init, oh_2_init, ow_2_init in T.grid(T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(1), T.int64(9)):
                    with T.block("conv2d_NCHWc_int8_o_init"):
                        v_n = T.axis.spatial(T.int64(1), T.int64(0))
                        v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1_init + oc_chunk_2_init + oc_chunk_0)
                        v_oh = T.axis.spatial(T.int64(54), oh_2_init + oh_0 * T.int64(27) + oh_1_init)
                        v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1_init * T.int64(9) + ow_2_init)
                        v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
                        for oc_block_1 in T.vectorized(T.int64(32)):
                            with T.block("conv2d_NCHWc_int8_init"):
                                v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
                                T.reads()
                                T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init])
                                conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0
                # Software-pipelined reduction: per the annotations, the two
                # VTCM copy blocks run in async stage 0 and the compute block
                # in stage 1.
                for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial(T.int64(2), annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}):
                    # Cache-read of p0 into VTCM for the current tile.
                    for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(26912)):
                        with T.block("p0_global.vtcm"):
                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                            v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_fused // T.int64(13456))
                            v2 = T.axis.spatial(T.int64(56), oh_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(13456) // T.int64(464))
                            v3 = T.axis.spatial(T.int64(56), ow_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(464) // T.int64(16))
                            v4 = T.axis.spatial(T.int64(32), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(16) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(16))
                            T.reads(p0[v0, v1, v2, v3, v4])
                            T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
                            p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
                    # Cache-read of the weights p1 into VTCM.
                    for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(9216)):
                        with T.block("p1_global.vtcm"):
                            v0 = T.axis.spatial(T.int64(2), oc_chunk_0)
                            v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(4608))
                            v2 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4608) // T.int64(1536))
                            v3 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(1536) // T.int64(512))
                            v4 = T.axis.spatial(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(512) // T.int64(128))
                            v5 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4))
                            v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
                            T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
                            T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
                            p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
                    # Main reduction: accumulate uint8*uint8 products into the
                    # int32 output, reading both operands from the VTCM copies.
                    for n_1, oc_chunk_1, oh_1, ow_1, oc_block_0_1, kh_1, kw_1, ic_outer_1, ic_f_inner_1, ic_s_inner_0_1, n_2, oc_chunk_2, oh_2, ow_2, oc_block_0_2 in T.grid(T.int64(1), T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(3), T.int64(3), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(9), T.int64(1)):
                        with T.block("conv2d_NCHWc_int8_o_update"):
                            v_n = T.axis.spatial(T.int64(1), T.int64(0))
                            v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1 + oc_chunk_2 + oc_chunk_0)
                            v_oh = T.axis.spatial(T.int64(54), oh_2 + oh_0 * T.int64(27) + oh_1)
                            v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1 * T.int64(9) + ow_2)
                            v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
                            v_kh, v_kw, v_ic_outer = T.axis.remap("RRR", [kh_1, kw_1, ic_outer_1])
                            v_ic_f_inner = T.axis.reduce(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ic_f_inner_1)
                            v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
                            T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4)], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, T.int64(0) : T.int64(32), T.int64(0) : T.int64(4)])
                            T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
                            for oc_block_1, ic_s_inner_1 in T.grid(T.int64(32), T.int64(4)):
                                with T.block("conv2d_NCHWc_int8"):
                                    v_oc_block_i, v_ic_s_inner_i = T.axis.remap("SR", [oc_block_1, ic_s_inner_1])
                                    T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
                                    T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i])
                                    T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                                    conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] + T.Cast("int32", p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i]) * T.Cast("int32", p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
| 108 | + |
# fmt: on
| 110 | + |
| 111 | + |
def test_conv2d_vtcm():
    """VerifyVTCMLimit fails the schedule at 70000 bytes of VTCM, passes at 75000."""

    def _hexagon_target(vtcm_bytes):
        dev = tvm.target.hexagon("v68", vtcm_capacity=vtcm_bytes)
        return tvm.target.Target(dev, host=dev)

    sch = tir.Schedule(Conv2dNCHWcVTCM, debug_mask="all")

    # The module's VTCM footprint sits between the two capacities below.
    for capacity, expected in ((70000, False), (75000, True)):
        ctx = _create_context(Conv2dNCHWcVTCM, target=_hexagon_target(capacity))
        result = ctx.space_generator.postprocs[0].apply(sch)
        assert bool(result) is expected
| 124 | + |
| 125 | + |
# Run all test_* functions in this file through TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments