Skip to content

Commit c2b2426

Browse files
author
Christian Convey
committed
WIP on hvx scheduling of TE for element-wise add
1 parent 7d5ef84 commit c2b2426

File tree

3 files changed

+152
-1
lines changed

3 files changed

+152
-1
lines changed

python/tvm/contrib/hexagon/build.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import datetime
2222
import multiprocessing as mp
2323
import os
24+
import os.path
2425
import pathlib
2526
import signal
2627
import stat
@@ -302,6 +303,7 @@ def __init__(
302303
def _copy_to_remote(
303304
self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path]
304305
):
306+
print("ZZZ: remote_path='{}'".format(remote_path))
305307
"""Abstract method implementation. See description in HexagonLauncherRPC."""
306308
subprocess.check_call(
307309
self._adb_device_sub_cmd + ["push", str(local_path), str(remote_path)]
@@ -311,6 +313,16 @@ def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]):
311313
"""Abstract method implementation. See description in HexagonLauncherRPC."""
312314
subprocess.check_call(self._adb_device_sub_cmd + ["shell", "mkdir", "-p", str(remote_path)])
313315

316+
#p = pathlib.Path(remote_path)
317+
318+
#if p != p.root:
319+
# #Create the parent directory using 'make -p ...', for convenience.
320+
# subprocess.check_call(self._adb_device_sub_cmd + ["shell", "mkdir", "-p", str(p.parent)])
321+
322+
## Use 'mkdir' *without* '-p' so that, if it already exists, we fail rather than
323+
## silently causing mayhem.
324+
#subprocess.check_call(self._adb_device_sub_cmd + ["shell", "mkdir", str(p)])
325+
314326
def _copy_binaries(self):
315327
"""Upload Android server binaries."""
316328

python/tvm/rpc/tracker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
import json
5252
from tvm.contrib.popen_pool import PopenWorker
5353

54+
import os
55+
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
56+
5457
try:
5558
from tornado import ioloop
5659
from . import tornado_util
@@ -64,7 +67,7 @@
6467
from .base import RPC_TRACKER_MAGIC, TrackerCode
6568

6669
logger = logging.getLogger("RPCTracker")
67-
70+
logger.setLevel(logging.DEBUG)
6871

6972
class Scheduler(object):
7073
"""Abstract interface of scheduler."""
@@ -231,6 +234,7 @@ def ret_value(self, data):
231234

232235
def call_handler(self, args):
233236
"""Event handler when json request arrives."""
237+
logger.debug('call_handler: args={}'.format(args))
234238
code = args[0]
235239
if code == TrackerCode.PUT:
236240
key = args[1]
@@ -287,10 +291,12 @@ def _cb(value):
287291

288292
def on_close(self):
289293
self._tracker.close(self)
294+
logger.debug('on_close')
290295

291296
def on_error(self, err):
292297
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
293298
self.close()
299+
logger.debug('on_error')
294300

295301

296302
class TrackerServerHandler(object):

tests/python/contrib/test_hexagon/test_launcher.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytest
2222
import numpy as np
2323
import logging
24+
#import time
2425

2526
import tvm.testing
2627
from tvm import te
@@ -40,6 +41,137 @@
4041
# triggering TIME_WAIT state on the server socket. This prevents another
4142
# server to bind to the same port until the wait time elapses.
4243

44+
@requires_hexagon_toolchain
45+
def test_add_hvx(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket):
46+
#for dtype in ['float16', 'int8', 'float32',]:
47+
for dtype in ['int8', 'float32',]:
48+
#for dtype in ['float32',]:
49+
for sched_type in [1,2,]:
50+
for mem_scope in [None, "global.vtcm"]:
51+
version_name = 'dtype:{}-schedtype:{}-memscope:{}'.format(dtype, str(sched_type), str(mem_scope))
52+
53+
#Eliminate potentially problematic characters...
54+
#version_name = version_name.replace('.', '_')
55+
#version_name = version_name.replace(':', '_')
56+
57+
print("CONFIGURATION: {}".format(version_name))
58+
#time.sleep(5)
59+
60+
dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits
61+
62+
HVX_VECTOR_BYTES=128
63+
64+
assert dtype_bits % 8 == 0
65+
dtype_bytes = dtype_bits // 8
66+
67+
elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes
68+
69+
# Note! We're providing the complete input tensor shapes now,
70+
# whereas the original code only reveals the exact shape when
71+
# about to call the kernel.
72+
73+
shape = [4, elem_per_hvx_vector,]
74+
75+
A = tvm.te.placeholder(shape, dtype=dtype)
76+
B = tvm.te.placeholder(shape, dtype=dtype)
77+
C = tvm.te.compute(A.shape, lambda i,j: A[i,j] + B[i,j], name="C")
78+
79+
#TODO: see if anyone cares that this segfaults
80+
#foozle = tvm.lower(sched1)
81+
82+
sched = tvm.te.create_schedule(C.op)
83+
84+
if sched_type == 1:
85+
pass
86+
elif sched_type == 2:
87+
sched[C].vectorize(C.op.axis[1])
88+
else:
89+
raise Exception("Unknown schedule type")
90+
91+
foozle = tvm.lower(sched, [A,B,C], "foo")
92+
93+
report_path = "/tmp/cconvey-report-dtype-{}-sched{}.txt".format(dtype, sched_type)
94+
with open(report_path, 'w') as f:
95+
f.write("LOWERED IR MODULE:\n")
96+
f.write(str(foozle))
97+
f.write('\n')
98+
99+
target_hexagon = tvm.target.hexagon("v68", link_params=True)
100+
func = tvm.build(
101+
sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add_hvx"
102+
)
103+
104+
temp = utils.tempdir()
105+
if True:
106+
dso_binary = "test_binary.so".format(version_name)
107+
dso_binary_path = temp.relpath(dso_binary)
108+
else:
109+
dso_binary = "test_binary-{}.so".format(version_name)
110+
dso_binary_path = "/tmp/cconvey-{}.so".format(version_name)
111+
func.save(dso_binary_path)
112+
113+
print("SAVED BINARY TO HOST PATH: {}".format(dso_binary_path))
114+
115+
#import pdb; pdb.set_trace()
116+
117+
if not android_serial_number:
118+
pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
119+
120+
rpc_info = {
121+
"rpc_tracker_host": tvm_tracker_host,
122+
"rpc_tracker_port": tvm_tracker_port,
123+
"rpc_server_port": RPC_SERVER_PORT + 0, # See note at the beginning of the file
124+
"adb_server_socket": adb_server_socket,
125+
}
126+
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
127+
launcher.upload(dso_binary_path, dso_binary)
128+
launcher.start_server()
129+
130+
with launcher.start_session() as sess:
131+
mod = launcher.load_module(dso_binary, sess)
132+
133+
# TODO: I think I was hitting an error because I tried to write into
134+
# "B_data.numpy()[...]"; i.e. going through the TVM wrapper object
135+
# after it was created. Not sure if this is necessary.
136+
host_numpy_A_data = np.ndarray(shape, dtype=dtype)
137+
host_numpy_B_data = np.ndarray(shape, dtype=dtype)
138+
host_numpy_C_data = np.ndarray(shape, dtype=dtype)
139+
host_numpy_C_data_expected = np.ndarray(shape, dtype=dtype)
140+
141+
def intended_val_A(i,j):
142+
return i + j
143+
144+
def intended_val_B(i,j):
145+
return (i+1) * (j+1)
146+
147+
for i in range(shape[0]):
148+
for j in range(shape[1]):
149+
host_numpy_A_data[i,j] = intended_val_A(i,j)
150+
host_numpy_B_data[i,j] = intended_val_B(i,j)
151+
host_numpy_C_data_expected[i,j] = intended_val_A(i,j) + intended_val_B(i,j)
152+
153+
#A_data = tvm.nd.array(host_numpy_A_data, device=sess.device, mem_scope=mem_scope)
154+
#B_data = tvm.nd.array(host_numpy_B_data, device=sess.device, mem_scope=mem_scope)
155+
#C_data = tvm.nd.array(np.ndarray(shape, dtype=dtype), device=sess.device, mem_scope=mem_scope)
156+
157+
A_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)
158+
A_data.copyfrom(host_numpy_A_data)
159+
160+
B_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)
161+
B_data.copyfrom(host_numpy_B_data)
162+
163+
C_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)
164+
#C_data.copyfrom(host_numpy_C_data)
165+
166+
#import pdb; pdb.set_trace()
167+
168+
mod["add_hvx"](A_data, B_data, C_data)
169+
170+
result = C_data.numpy()
171+
assert (result == host_numpy_C_data_expected).all()
172+
173+
launcher.stop_server()
174+
43175

44176
@requires_hexagon_toolchain
45177
def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket):
@@ -69,6 +201,7 @@ def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_serv
69201
"adb_server_socket": adb_server_socket,
70202
}
71203
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
204+
import pdb; pdb.set_trace()
72205
launcher.upload(dso_binary_path, dso_binary)
73206
launcher.start_server()
74207

0 commit comments

Comments
 (0)