|
16 | 16 | # under the License. |
17 | 17 | import sys |
18 | 18 | from typing import List |
| 19 | +import numpy as np |
19 | 20 |
|
20 | 21 | import pytest |
21 | 22 | import tvm |
| 23 | +from tvm import relay |
22 | 24 | from tvm import meta_schedule as ms |
23 | 25 | from tvm.ir.module import IRModule |
24 | 26 | from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload |
@@ -149,5 +151,49 @@ def extract_task_qbert(): |
149 | 151 | assert "vnni" in annotations["schedule_rule"] |
150 | 152 |
|
151 | 153 |
|
| 154 | +def extract_task_arm_conv2d_nchwc(): |
| 155 | + data_shape = (1, 32, 128, 128) |
| 156 | + weight_shape = (32, 32, 1, 1) |
| 157 | + bias_shape = (weight_shape[0],) |
| 158 | + padding = (1, 1) |
| 159 | + |
| 160 | + data = relay.var("data", shape=data_shape, dtype="int8") |
| 161 | + weight = relay.var("weight", shape=weight_shape, dtype="int8") |
| 162 | + bias = relay.var("bias", shape=bias_shape, dtype="int32") |
| 163 | + conv2d = relay.nn.conv2d( |
| 164 | + data=data, |
| 165 | + weight=weight, |
| 166 | + kernel_size=weight_shape[2:], |
| 167 | + channels=weight_shape[0], |
| 168 | + padding=padding, |
| 169 | + strides=(1, 1), |
| 170 | + out_dtype="int32", |
| 171 | + ) |
| 172 | + bias_add = relay.nn.bias_add(conv2d, bias) |
| 173 | + relay_mod = tvm.IRModule.from_expr(bias_add) |
| 174 | + |
| 175 | + weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8") |
| 176 | + bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32") |
| 177 | + |
| 178 | + params = {"weight": weight_np, "bias": bias_np} |
| 179 | + |
| 180 | + target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon" |
| 181 | + extracted_tasks = extract_task_from_relay(relay_mod, target, params) |
| 182 | + tune_tasks = list( |
| 183 | + filter( |
| 184 | + lambda task: "conv2d" in task.task_name, |
| 185 | + extracted_tasks, |
| 186 | + ) |
| 187 | + ) |
| 188 | + |
| 189 | + assert len(tune_tasks) == 1 |
| 190 | + |
| 191 | + relay_func = list(tune_tasks[0].mod.functions.values())[0] |
| 192 | + out_type = relay_func.body.checked_type |
| 193 | + |
| 194 | + # Check that the output is in NCHWc layout |
| 195 | + assert list(out_type.shape) == [1, 1, 130, 130, 32] |
| 196 | + |
| 197 | + |
152 | 198 | if __name__ == "__main__": |
153 | 199 | sys.exit(pytest.main([__file__] + sys.argv[1:])) |
0 commit comments