Skip to content

Commit eb69c2f

Browse files
committed
add hexagon schedule tests
1 parent 0e1a2a2 commit eb69c2f

File tree

9 files changed

+1262
-4
lines changed

9 files changed

+1262
-4
lines changed

python/tvm/script/tir/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def alloc_buffer(
226226
"""
227227
special_stmt - Reads/Writes
228228
"""
229+
229230
@overload
230231
def reads(read_regions: List[BufferSlice]) -> None: ...
231232
@overload
@@ -337,6 +338,7 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr:
337338
"""
338339
Scope handler - Loops
339340
"""
341+
340342
@overload
341343
def serial(
342344
begin: Union[PrimExpr, int],

src/runtime/hexagon/rpc/hexagon/rpc_server.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ extern "C" {
4141
#include "hexagon_rpc.h"
4242

4343
// TODO(mehrdadh): make this configurable.
44-
#define TVM_HEXAGON_RPC_BUFF_SIZE_BYTES 2 * 1024 * 1024
44+
#define TVM_HEXAGON_RPC_BUFF_SIZE_BYTES 5 * 1024 * 1024
4545

4646
// TODO(csulivan,adstraw,kparzysz-quic) This should be set on a TVM-wide basis.
4747
#if defined(__hexagon__)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test code for matmul"""
18+
import numpy as np
19+
import pytest
20+
import sys
21+
22+
import tvm
23+
from tvm import topi
24+
from tvm import te
25+
import tvm.topi.testing
26+
from tvm.topi.utils import get_const_tuple
27+
28+
from .conftest import requires_hexagon_toolchain
29+
30+
dtype = tvm.testing.parameter(
31+
"float32",
32+
"float16",
33+
)
34+
35+
x_batch, y_batch, M, N, K = tvm.testing.parameters(
36+
(1, 1, 16, 16, 32),
37+
(5, 5, 16, 16, 32),
38+
(5, 5, 16, 20, 32),
39+
(30, 30, 16, 20, 32),
40+
# Test batch broadcasting.
41+
(1, 5, 16, 16, 32),
42+
(5, 1, 16, 16, 32),
43+
)
44+
45+
# TODO(mehrdadh): add dynamic testing
46+
@requires_hexagon_toolchain
47+
def test_batch_matmul(hexagon_session, x_batch, y_batch, M, N, K, dtype):
48+
if dtype == "float16":
49+
pytest.xfail("float16 is not supported.")
50+
51+
x = te.placeholder((x_batch, M, K), name="x")
52+
y = te.placeholder((y_batch, N, K), name="y")
53+
54+
def get_ref_data():
55+
a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
56+
b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
57+
c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
58+
return (a_np, b_np, c_np)
59+
60+
# get the test data
61+
a_np, b_np, c_np = get_ref_data()
62+
63+
target_hexagon = tvm.target.hexagon("v68")
64+
with tvm.target.Target(target_hexagon):
65+
fcompute = topi.nn.batch_matmul
66+
fschedule = topi.hexagon.schedule_batch_matmul
67+
out = fcompute(x, y)
68+
s = fschedule([out])
69+
out_shape = out.shape
70+
71+
func = tvm.build(
72+
s, [x, y, out], tvm.target.Target(target_hexagon, host=target_hexagon), name="batch_matmul"
73+
)
74+
mod = hexagon_session.load_module(func)
75+
76+
dev = hexagon_session.device
77+
a = tvm.nd.array(a_np, dev)
78+
b = tvm.nd.array(b_np, dev)
79+
c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype), dev)
80+
mod["batch_matmul"](a, b, c)
81+
82+
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
83+
84+
85+
x_batch_1, y_batch_1, M_1, N_1, K_1 = tvm.testing.parameters(
86+
(1, 1, 2, 3, 1),
87+
(1, 1, 16, 24, 32),
88+
(5, 5, 24, 16, 32),
89+
(30, 30, 16, 20, 32),
90+
(1, 5, 16, 16, 32),
91+
(5, 1, 16, 16, 32),
92+
)
93+
94+
95+
@requires_hexagon_toolchain
96+
def test_batch_matmul_int8(hexagon_session, x_batch_1, y_batch_1, M_1, N_1, K_1):
97+
dtype = "int8"
98+
out_dtype = "int8"
99+
assert x_batch_1 == y_batch_1 or x_batch_1 == 1 or y_batch_1 == 1
100+
x = te.placeholder((x_batch_1, M_1, K_1), name="x", dtype=dtype)
101+
y = te.placeholder((y_batch_1, N_1, K_1), name="y", dtype=dtype)
102+
103+
def get_ref_data():
104+
a_np = np.random.randint(low=-128, high=127, size=(x_batch_1, M_1, K_1)).astype(dtype)
105+
b_np = np.random.randint(low=-128, high=127, size=(y_batch_1, N_1, K_1)).astype(dtype)
106+
c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype=out_dtype)
107+
return (a_np, b_np, c_np)
108+
109+
# get the test data
110+
a_np, b_np, c_np = get_ref_data()
111+
112+
target_hexagon = tvm.target.hexagon("v68")
113+
with tvm.target.Target(target_hexagon):
114+
fcompute = topi.nn.batch_matmul
115+
fschedule = topi.hexagon.schedule_batch_matmul
116+
out = fcompute(x, y)
117+
s = fschedule([out])
118+
119+
func = tvm.build(
120+
s,
121+
[x, y, out],
122+
tvm.target.Target(target_hexagon, host=target_hexagon),
123+
name="batch_matmul_int8",
124+
)
125+
mod = hexagon_session.load_module(func)
126+
127+
dev = hexagon_session.device
128+
a = tvm.nd.array(a_np, dev)
129+
b = tvm.nd.array(b_np, dev)
130+
c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev)
131+
mod["batch_matmul_int8"](a, b, c)
132+
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
133+
134+
135+
if __name__ == "__main__":
136+
sys.exit(pytest.main(sys.argv))
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test code for dense"""
18+
import numpy as np
19+
import pytest
20+
import sys
21+
22+
import tvm
23+
from tvm import topi
24+
from tvm import te
25+
import tvm.topi.testing
26+
from tvm.topi.utils import get_const_tuple
27+
28+
from .conftest import requires_hexagon_toolchain
29+
30+
random_seed = tvm.testing.parameter(0)
31+
32+
use_bias = tvm.testing.parameter(True, False)
33+
34+
# batch_size more than 8 would break
35+
batch_size = tvm.testing.parameter(1, 2, 8)
36+
37+
in_dim, out_dim = tvm.testing.parameters((1024, 1000))
38+
39+
in_dtype, out_dtype = tvm.testing.parameters(
40+
("float32", "float32"),
41+
("float16", "float32"),
42+
("int8", "int32"),
43+
)
44+
45+
46+
@tvm.testing.fixture(cache_return_value=True)
47+
def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype):
48+
np.random.seed(random_seed)
49+
50+
if "float" in in_dtype:
51+
a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype)
52+
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype)
53+
c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype)
54+
elif in_dtype == "int8":
55+
a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype)
56+
b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype)
57+
c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
58+
else:
59+
raise ValueError("No method to generate test data for data type '{}'".format(in_dtype))
60+
61+
matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
62+
63+
if use_bias:
64+
matmul += c_np
65+
66+
d_np = np.maximum(matmul, 0)
67+
return (a_np, b_np, c_np, d_np)
68+
69+
70+
@requires_hexagon_toolchain
71+
def test_dense(
72+
hexagon_session, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype, dense_ref_data
73+
):
74+
if in_dtype == "float16":
75+
pytest.xfail("float16 is not supported.")
76+
77+
if "int" in in_dtype:
78+
tol = {"atol": 0, "rtol": 0}
79+
elif in_dtype == "float32":
80+
tol = {"rtol": 1e-5, "atol": 1e-5}
81+
82+
A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype)
83+
B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype)
84+
C = te.placeholder((out_dim,), name="C", dtype=out_dtype)
85+
86+
a_np, b_np, c_np, d_np = dense_ref_data
87+
88+
fcompute = topi.nn.dense
89+
fschedule = topi.hexagon.schedule_dense
90+
91+
target_hexagon = tvm.target.hexagon("v68")
92+
with tvm.target.Target(target_hexagon):
93+
D = fcompute(A, B, C if use_bias else None, out_dtype)
94+
D = topi.nn.relu(D)
95+
s = fschedule([D])
96+
97+
func = tvm.build(
98+
s, [A, B, C, D], tvm.target.Target(target_hexagon, host=target_hexagon), name="dense"
99+
)
100+
mod = hexagon_session.load_module(func)
101+
102+
dev = hexagon_session.device
103+
a = tvm.nd.array(a_np, dev)
104+
b = tvm.nd.array(b_np, dev)
105+
c = tvm.nd.array(c_np, dev)
106+
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)
107+
mod["dense"](a, b, c, d)
108+
tvm.testing.assert_allclose(d.numpy(), d_np, **tol)
109+
110+
111+
if __name__ == "__main__":
112+
sys.exit(pytest.main(sys.argv))

tests/python/contrib/test_hexagon/test_launcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
# under the License.
1717

1818
import os
19-
import pathlib
2019
import sys
2120
import pytest
2221
import numpy as np
23-
import logging
2422

2523
import tvm.testing
2624
from tvm import te

0 commit comments

Comments
 (0)