Skip to content

Commit a05d81f

Browse files
committed
add test
1 parent d47d90f commit a05d81f

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

python/tvm/topi/arm_cpu/conv2d_alter_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
355355
out_dtype,
356356
False,
357357
data_layout,
358-
int32_lanes=32
358+
int32_lanes=32,
359359
)
360360

361361
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)

tests/python/unittest/test_meta_schedule_integration.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# under the License.
1717
import sys
1818
from typing import List
19+
import numpy as np
1920

2021
import pytest
2122
import tvm
23+
from tvm import relay
2224
from tvm import meta_schedule as ms
2325
from tvm.ir.module import IRModule
2426
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
@@ -149,5 +151,49 @@ def extract_task_qbert():
149151
assert "vnni" in annotations["schedule_rule"]
150152

151153

154+
def extract_task_arm_conv2d_nchwc():
155+
data_shape = (1, 32, 128, 128)
156+
weight_shape = (32, 32, 1, 1)
157+
bias_shape = (weight_shape[0],)
158+
padding = (1, 1)
159+
160+
data = relay.var("data", shape=data_shape, dtype="int8")
161+
weight = relay.var("weight", shape=weight_shape, dtype="int8")
162+
bias = relay.var("bias", shape=bias_shape, dtype="int32")
163+
conv2d = relay.nn.conv2d(
164+
data=data,
165+
weight=weight,
166+
kernel_size=weight_shape[2:],
167+
channels=weight_shape[0],
168+
padding=padding,
169+
strides=(1, 1),
170+
out_dtype="int32",
171+
)
172+
bias_add = relay.nn.bias_add(conv2d, bias)
173+
relay_mod = tvm.IRModule.from_expr(bias_add)
174+
175+
weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
176+
bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")
177+
178+
params = {"weight": weight_np, "bias": bias_np}
179+
180+
target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
181+
extracted_tasks = extract_task_from_relay(relay_mod, target, params)
182+
tune_tasks = list(
183+
filter(
184+
lambda task: "conv2d" in task.task_name,
185+
extracted_tasks,
186+
)
187+
)
188+
189+
assert len(tune_tasks) == 1
190+
191+
relay_func = list(tune_tasks[0].mod.functions.values())[0]
192+
out_type = relay_func.body.checked_type
193+
194+
# Check that the output is in NCHWc layout
195+
assert list(out_type.shape) == [1, 1, 130, 130, 32]
196+
197+
152198
if __name__ == "__main__":
153199
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)