Skip to content

Commit 23383bf

Browse files
mehrdadhSergey Shtin
authored andcommitted
[Hexagon] Add test for registered schedules (apache#11016)
* add hexagon schedule tests * moved tests to sub-directories
1 parent 2b0d817 commit 23383bf

17 files changed

+1665
-31
lines changed

python/tvm/contrib/hexagon/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
remote_kw: dict,
5858
session_name: str = "hexagon-rpc",
5959
remote_stack_size_bytes: int = 256 * 1024, # Min size for main thread in QuRT/sim
60-
rpc_receive_buffer_size_bytes: int = 2 * 1024 * 1024,
60+
rpc_receive_buffer_size_bytes: int = 5 * 1024 * 1024, # Size for passing hexagon tests
6161
):
6262
self._launcher = launcher
6363
self._session_name: str = session_name
Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env bash
21
# Licensed to the Apache Software Foundation (ASF) under one
32
# or more contributor license agreements. See the NOTICE file
43
# distributed with this work for additional information
@@ -16,25 +15,4 @@
1615
# specific language governing permissions and limitations
1716
# under the License.
1817

19-
set -e
20-
set -u
21-
22-
source tests/scripts/setup-pytest-env.sh
23-
24-
make cython3
25-
26-
export TVM_TRACKER_PORT=9190
27-
export TVM_TRACKER_HOST=0.0.0.0
28-
env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" &
29-
TRACKER_PID=$!
30-
sleep 5 # Wait for tracker to bind
31-
32-
# Temporary workaround for symbol visibility
33-
export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim"
34-
35-
# HEXAGON_TOOLCHAIN is already set
36-
export HEXAGON_SDK_ROOT=${HEXAGON_SDK_PATH}
37-
export ANDROID_SERIAL_NUMBER=simulator
38-
run_pytest ctypes python-contrib-hexagon-simulator tests/python/contrib/test_hexagon
39-
40-
kill ${TRACKER_PID}
18+
""" Testing infrastructure for Hexagon/TOPI/Conv2d """

tests/python/contrib/test_hexagon/test_conv2d_blocked.md renamed to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md

File renamed without changes.

tests/python/contrib/test_hexagon/test_conv2d_blocked.py renamed to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tvm import topi
2424
from tvm.topi import testing
2525

26-
from .infrastructure import (
26+
from ..infrastructure import (
2727
build_and_run,
2828
conv2d_compute,
2929
conv2d_verify,

tests/python/contrib/test_hexagon/test_conv2d_conv2d.md renamed to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md

File renamed without changes.

tests/python/contrib/test_hexagon/test_conv2d_conv2d.py renamed to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tvm import topi
2424
from tvm.topi import testing
2525

26-
from .infrastructure import (
26+
from ..infrastructure import (
2727
build_and_run,
2828
conv2d_compute,
2929
conv2d_verify,

tests/python/contrib/test_hexagon/test_2d_physical_buffers.py

100755100644
File mode changed.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
""" Testing infrastructure for Hexagon/TOPI """
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test code for matmul"""
18+
import numpy as np
19+
import pytest
20+
import sys
21+
22+
import tvm
23+
from tvm import topi
24+
from tvm import te
25+
import tvm.topi.testing
26+
from tvm.topi.utils import get_const_tuple
27+
28+
from ..conftest import requires_hexagon_toolchain
29+
30+
dtype = tvm.testing.parameter(
31+
"float32",
32+
"float16",
33+
)
34+
35+
36+
class TestMatMulFloat:
37+
x_batch, y_batch, M, N, K = tvm.testing.parameters(
38+
(1, 1, 16, 16, 32),
39+
(5, 5, 16, 16, 32),
40+
(5, 5, 16, 20, 32),
41+
(30, 30, 16, 20, 32),
42+
# Test batch broadcasting.
43+
(1, 5, 16, 16, 32),
44+
(5, 1, 16, 16, 32),
45+
)
46+
47+
# TODO(mehrdadh): add dynamic testing
48+
@requires_hexagon_toolchain
49+
def test_batch_matmul(self, hexagon_session, x_batch, y_batch, M, N, K, dtype):
50+
if dtype == "float16":
51+
pytest.xfail("float16 is not supported.")
52+
53+
x = te.placeholder((x_batch, M, K), name="x")
54+
y = te.placeholder((y_batch, N, K), name="y")
55+
56+
def get_ref_data():
57+
a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
58+
b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
59+
c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
60+
return (a_np, b_np, c_np)
61+
62+
# get the test data
63+
a_np, b_np, c_np = get_ref_data()
64+
65+
target_hexagon = tvm.target.hexagon("v68")
66+
with tvm.target.Target(target_hexagon):
67+
fcompute = topi.nn.batch_matmul
68+
fschedule = topi.hexagon.schedule_batch_matmul
69+
out = fcompute(x, y)
70+
s = fschedule([out])
71+
out_shape = out.shape
72+
73+
func = tvm.build(
74+
s,
75+
[x, y, out],
76+
tvm.target.Target(target_hexagon, host=target_hexagon),
77+
name="batch_matmul",
78+
)
79+
mod = hexagon_session.load_module(func)
80+
81+
dev = hexagon_session.device
82+
a = tvm.nd.array(a_np, dev)
83+
b = tvm.nd.array(b_np, dev)
84+
c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype), dev)
85+
mod["batch_matmul"](a, b, c)
86+
87+
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
88+
89+
90+
class TestMatMulInt8:
91+
x_batch, y_batch, M, N, K = tvm.testing.parameters(
92+
(1, 1, 2, 3, 1),
93+
(1, 1, 16, 24, 32),
94+
(5, 5, 24, 16, 32),
95+
(30, 30, 16, 20, 32),
96+
(1, 5, 16, 16, 32),
97+
(5, 1, 16, 16, 32),
98+
)
99+
100+
@requires_hexagon_toolchain
101+
def test_batch_matmul_int8(self, hexagon_session, x_batch, y_batch, M, N, K):
102+
dtype = "int8"
103+
out_dtype = "int8"
104+
assert x_batch == y_batch or x_batch == 1 or y_batch == 1
105+
x = te.placeholder((x_batch, M, K), name="x", dtype=dtype)
106+
y = te.placeholder((y_batch, N, K), name="y", dtype=dtype)
107+
108+
def get_ref_data():
109+
a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype)
110+
b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype)
111+
c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype=out_dtype)
112+
return (a_np, b_np, c_np)
113+
114+
# get the test data
115+
a_np, b_np, c_np = get_ref_data()
116+
117+
target_hexagon = tvm.target.hexagon("v68")
118+
with tvm.target.Target(target_hexagon):
119+
fcompute = topi.nn.batch_matmul
120+
fschedule = topi.hexagon.schedule_batch_matmul
121+
out = fcompute(x, y)
122+
s = fschedule([out])
123+
124+
func = tvm.build(
125+
s,
126+
[x, y, out],
127+
tvm.target.Target(target_hexagon, host=target_hexagon),
128+
name="batch_matmul_int8",
129+
)
130+
mod = hexagon_session.load_module(func)
131+
132+
dev = hexagon_session.device
133+
a = tvm.nd.array(a_np, dev)
134+
b = tvm.nd.array(b_np, dev)
135+
c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev)
136+
mod["batch_matmul_int8"](a, b, c)
137+
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
138+
139+
140+
if __name__ == "__main__":
141+
sys.exit(pytest.main(sys.argv))

tests/python/contrib/test_hexagon/test_cache_read_write.py renamed to tests/python/contrib/test_hexagon/topi/test_cache_read_write.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020

2121
import tvm.testing
2222
from tvm import te
23-
from tvm.contrib import utils
24-
from tvm.contrib.hexagon.build import HexagonLauncher
25-
import tvm.contrib.hexagon as hexagon
2623

27-
from .conftest import requires_hexagon_toolchain
24+
from ..conftest import requires_hexagon_toolchain
2825

2926

3027
def intrin_mem_copy(shape, dtype, dst_scope, src_scope):

0 commit comments

Comments
 (0)