|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring |
| 18 | + |
| 19 | +import sys |
| 20 | +import pytest |
| 21 | +import tvm |
| 22 | +from tvm import tir |
| 23 | +from tvm.meta_schedule import TuneContext |
| 24 | +from tvm.meta_schedule.postproc import VerifyGPUCode |
| 25 | +from tvm.script import tir as T |
| 26 | +from tvm.target import Target |
| 27 | + |
| 28 | + |
| 29 | +def _target() -> Target: |
| 30 | + return Target("nvidia/geforce-rtx-3080") |
| 31 | + |
| 32 | + |
| 33 | +def _create_context(mod, target) -> TuneContext: |
| 34 | + ctx = TuneContext( |
| 35 | + mod=mod, |
| 36 | + target=target, |
| 37 | + postprocs=[ |
| 38 | + VerifyGPUCode(), |
| 39 | + ], |
| 40 | + task_name="test", |
| 41 | + ) |
| 42 | + for rule in ctx.postprocs: |
| 43 | + rule.initialize_with_tune_context(ctx) |
| 44 | + return ctx |
| 45 | + |
| 46 | + |
| 47 | +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant |
| 48 | +# fmt: off |
| 49 | + |
| 50 | +@tvm.script.ir_module |
| 51 | +class Conv2dCuda0: |
| 52 | + @T.prim_func |
| 53 | + def main(a: T.handle, b: T.handle) -> None: |
| 54 | + # function attr dict |
| 55 | + T.func_attr({"global_symbol": "main", "T.noalias": True}) |
| 56 | + # var definition |
| 57 | + threadIdx_x = T.env_thread("threadIdx.x") |
| 58 | + threadIdx_y = T.env_thread("threadIdx.y") |
| 59 | + blockIdx_x = T.env_thread("blockIdx.x") |
| 60 | + blockIdx_y = T.env_thread("blockIdx.y") |
| 61 | + blockIdx_z = T.env_thread("blockIdx.z") |
| 62 | + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") |
| 63 | + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") |
| 64 | + # body |
| 65 | + T.launch_thread(blockIdx_z, 196) |
| 66 | + B_local = T.allocate([64], "float32", "local") |
| 67 | + Apad_shared = T.allocate([512], "float32", "shared") |
| 68 | + Apad_shared_local = T.allocate([8], "float32", "local") |
| 69 | + T.launch_thread(blockIdx_y, 8) |
| 70 | + T.launch_thread(blockIdx_x, 4) |
| 71 | + T.launch_thread(threadIdx_y, 8) |
| 72 | + T.launch_thread(threadIdx_x, 8) |
| 73 | + for ff_c_init, nn_c_init in T.grid(8, 8): |
| 74 | + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) |
| 75 | + for rc_outer, ry, rx in T.grid(32, 3, 3): |
| 76 | + for ax3_inner_outer in T.serial(0, 2): |
| 77 | + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) |
| 78 | + for rc_inner in T.serial(0, 8): |
| 79 | + for ax3 in T.serial(0, 8): |
| 80 | + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) |
| 81 | + for ff_c, nn_c in T.grid(8, 8): |
| 82 | + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) |
| 83 | + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): |
| 84 | + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on |
| 85 | + |
| 86 | + |
| 87 | +@tvm.script.ir_module |
| 88 | +class Conv2dCuda1: |
| 89 | + @T.prim_func |
| 90 | + def main(a: T.handle, b: T.handle) -> None: |
| 91 | + # function attr dict |
| 92 | + T.func_attr({"global_symbol": "main", "T.noalias": True}) |
| 93 | + # var definition |
| 94 | + threadIdx_x = T.env_thread("threadIdx.x") |
| 95 | + threadIdx_y = T.env_thread("threadIdx.y") |
| 96 | + blockIdx_x = T.env_thread("blockIdx.x") |
| 97 | + blockIdx_y = T.env_thread("blockIdx.y") |
| 98 | + blockIdx_z = T.env_thread("blockIdx.z") |
| 99 | + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") |
| 100 | + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") |
| 101 | + # body |
| 102 | + T.launch_thread(blockIdx_z, 196) |
| 103 | + B_local = T.allocate([6400000], "float32", "local") |
| 104 | + Apad_shared = T.allocate([512], "float32", "shared") |
| 105 | + Apad_shared_local = T.allocate([8], "float32", "local") |
| 106 | + T.launch_thread(blockIdx_y, 8) |
| 107 | + T.launch_thread(blockIdx_x, 4) |
| 108 | + T.launch_thread(threadIdx_y, 8) |
| 109 | + T.launch_thread(threadIdx_x, 8) |
| 110 | + for ff_c_init, nn_c_init in T.grid(8, 8): |
| 111 | + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) |
| 112 | + for rc_outer, ry, rx in T.grid(32, 3, 3): |
| 113 | + for ax3_inner_outer in T.serial(0, 2): |
| 114 | + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) |
| 115 | + for rc_inner in T.serial(0, 8): |
| 116 | + for ax3 in T.serial(0, 8): |
| 117 | + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) |
| 118 | + for ff_c, nn_c in T.grid(8, 8): |
| 119 | + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) |
| 120 | + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): |
| 121 | + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on |
| 122 | + |
| 123 | + |
| 124 | +@tvm.script.ir_module |
| 125 | +class Conv2dCuda2: |
| 126 | + @T.prim_func |
| 127 | + def main(a: T.handle, b: T.handle) -> None: |
| 128 | + # function attr dict |
| 129 | + T.func_attr({"global_symbol": "main", "T.noalias": True}) |
| 130 | + # var definition |
| 131 | + threadIdx_x = T.env_thread("threadIdx.x") |
| 132 | + threadIdx_y = T.env_thread("threadIdx.y") |
| 133 | + blockIdx_x = T.env_thread("blockIdx.x") |
| 134 | + blockIdx_y = T.env_thread("blockIdx.y") |
| 135 | + blockIdx_z = T.env_thread("blockIdx.z") |
| 136 | + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") |
| 137 | + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") |
| 138 | + # body |
| 139 | + T.launch_thread(blockIdx_z, 196) |
| 140 | + B_local = T.allocate([64], "float32", "local") |
| 141 | + Apad_shared = T.allocate([512000], "float32", "shared") |
| 142 | + Apad_shared_local = T.allocate([8], "float32", "local") |
| 143 | + T.launch_thread(blockIdx_y, 8) |
| 144 | + T.launch_thread(blockIdx_x, 4) |
| 145 | + T.launch_thread(threadIdx_y, 8) |
| 146 | + T.launch_thread(threadIdx_x, 8) |
| 147 | + for ff_c_init, nn_c_init in T.grid(8, 8): |
| 148 | + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) |
| 149 | + for rc_outer, ry, rx in T.grid(32, 3, 3): |
| 150 | + for ax3_inner_outer in T.serial(0, 2): |
| 151 | + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) |
| 152 | + for rc_inner in T.serial(0, 8): |
| 153 | + for ax3 in T.serial(0, 8): |
| 154 | + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) |
| 155 | + for ff_c, nn_c in T.grid(8, 8): |
| 156 | + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) |
| 157 | + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): |
| 158 | + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on |
| 159 | + |
| 160 | + |
| 161 | +@tvm.script.ir_module |
| 162 | +class Conv2dCuda3: |
| 163 | + @T.prim_func |
| 164 | + def main(a: T.handle, b: T.handle) -> None: |
| 165 | + # function attr dict |
| 166 | + T.func_attr({"global_symbol": "main", "T.noalias": True}) |
| 167 | + # var definition |
| 168 | + threadIdx_x = T.env_thread("threadIdx.x") |
| 169 | + threadIdx_y = T.env_thread("threadIdx.y") |
| 170 | + blockIdx_x = T.env_thread("blockIdx.x") |
| 171 | + blockIdx_y = T.env_thread("blockIdx.y") |
| 172 | + blockIdx_z = T.env_thread("blockIdx.z") |
| 173 | + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") |
| 174 | + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") |
| 175 | + # body |
| 176 | + T.launch_thread(blockIdx_z, 196) |
| 177 | + B_local = T.allocate([64], "float32", "local") |
| 178 | + Apad_shared = T.allocate([512], "float32", "shared") |
| 179 | + Apad_shared_local = T.allocate([8], "float32", "local") |
| 180 | + T.launch_thread(blockIdx_y, 8) |
| 181 | + T.launch_thread(blockIdx_x, 4) |
| 182 | + T.launch_thread(threadIdx_y, 8) |
| 183 | + T.launch_thread(threadIdx_x, 800000) |
| 184 | + for ff_c_init, nn_c_init in T.grid(8, 8): |
| 185 | + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) |
| 186 | + for rc_outer, ry, rx in T.grid(32, 3, 3): |
| 187 | + for ax3_inner_outer in T.serial(0, 2): |
| 188 | + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) |
| 189 | + for rc_inner in T.serial(0, 8): |
| 190 | + for ax3 in T.serial(0, 8): |
| 191 | + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) |
| 192 | + for ff_c, nn_c in T.grid(8, 8): |
| 193 | + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) |
| 194 | + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): |
| 195 | + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on |
| 196 | + |
| 197 | + |
| 198 | +# fmt: on |
| 199 | +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant |
| 200 | + |
| 201 | + |
| 202 | +def test_postproc_verify_gpu_0(): |
| 203 | + mod = Conv2dCuda0 |
| 204 | + ctx = _create_context(mod, target=_target()) |
| 205 | + sch = tir.Schedule(mod, debug_mask="all") |
| 206 | + assert ctx.postprocs[0].apply(sch) |
| 207 | + |
| 208 | + |
| 209 | +def test_postproc_verify_gpu_1(): |
| 210 | + mod = Conv2dCuda1 |
| 211 | + ctx = _create_context(mod, target=_target()) |
| 212 | + sch = tir.Schedule(mod, debug_mask="all") |
| 213 | + assert not ctx.postprocs[0].apply(sch) |
| 214 | + |
| 215 | + |
| 216 | +def test_postproc_verify_gpu_2(): |
| 217 | + mod = Conv2dCuda2 |
| 218 | + ctx = _create_context(mod, target=_target()) |
| 219 | + sch = tir.Schedule(mod, debug_mask="all") |
| 220 | + assert not ctx.postprocs[0].apply(sch) |
| 221 | + |
| 222 | + |
| 223 | +def test_postproc_verify_gpu_3(): |
| 224 | + mod = Conv2dCuda3 |
| 225 | + ctx = _create_context(mod, target=_target()) |
| 226 | + sch = tir.Schedule(mod, debug_mask="all") |
| 227 | + assert not ctx.postprocs[0].apply(sch) |
| 228 | + |
| 229 | + |
| 230 | +if __name__ == "__main__": |
| 231 | + sys.exit(pytest.main([__file__] + sys.argv[1:])) |
0 commit comments