|
21 | 21 | import pytest |
22 | 22 | import numpy as np |
23 | 23 | import logging |
| 24 | +#import time |
24 | 25 |
|
25 | 26 | import tvm.testing |
26 | 27 | from tvm import te |
|
40 | 41 | # triggering TIME_WAIT state on the server socket. This prevents another |
41 | 42 | # server to bind to the same port until the wait time elapses. |
42 | 43 |
|
| 44 | +@requires_hexagon_toolchain |
| 45 | +def test_add_hvx(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket): |
| 46 | + #for dtype in ['float16', 'int8', 'float32',]: |
| 47 | + for dtype in ['int8', 'float32',]: |
| 48 | + #for dtype in ['float32',]: |
| 49 | + for sched_type in [1,2,]: |
| 50 | + for mem_scope in [None, "global.vtcm"]: |
| 51 | + version_name = 'dtype:{}-schedtype:{}-memscope:{}'.format(dtype, str(sched_type), str(mem_scope)) |
| 52 | + |
| 53 | + #Eliminate potentially problematic characters... |
| 54 | + #version_name = version_name.replace('.', '_') |
| 55 | + #version_name = version_name.replace(':', '_') |
| 56 | + |
| 57 | + print("CONFIGURATION: {}".format(version_name)) |
| 58 | + #time.sleep(5) |
| 59 | + |
| 60 | + dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits |
| 61 | + |
| 62 | + HVX_VECTOR_BYTES=128 |
| 63 | + |
| 64 | + assert dtype_bits % 8 == 0 |
| 65 | + dtype_bytes = dtype_bits // 8 |
| 66 | + |
| 67 | + elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes |
| 68 | + |
| 69 | + # Note! We're providing the complete input tensor shapes now, |
| 70 | + # whereas the original code only reveals the exact shape when |
| 71 | + # about to call the kernel. |
| 72 | + |
| 73 | + shape = [4, elem_per_hvx_vector,] |
| 74 | + |
| 75 | + A = tvm.te.placeholder(shape, dtype=dtype) |
| 76 | + B = tvm.te.placeholder(shape, dtype=dtype) |
| 77 | + C = tvm.te.compute(A.shape, lambda i,j: A[i,j] + B[i,j], name="C") |
| 78 | + |
| 79 | + #TODO: see if anyone cares that this segfaults |
| 80 | + #foozle = tvm.lower(sched1) |
| 81 | + |
| 82 | + sched = tvm.te.create_schedule(C.op) |
| 83 | + |
| 84 | + if sched_type == 1: |
| 85 | + pass |
| 86 | + elif sched_type == 2: |
| 87 | + sched[C].vectorize(C.op.axis[1]) |
| 88 | + else: |
| 89 | + raise Exception("Unknown schedule type") |
| 90 | + |
| 91 | + foozle = tvm.lower(sched, [A,B,C], "foo") |
| 92 | + |
| 93 | + report_path = "/tmp/cconvey-report-dtype-{}-sched{}.txt".format(dtype, sched_type) |
| 94 | + with open(report_path, 'w') as f: |
| 95 | + f.write("LOWERED IR MODULE:\n") |
| 96 | + f.write(str(foozle)) |
| 97 | + f.write('\n') |
| 98 | + |
| 99 | + target_hexagon = tvm.target.hexagon("v68", link_params=True) |
| 100 | + func = tvm.build( |
| 101 | + sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add_hvx" |
| 102 | + ) |
| 103 | + |
| 104 | + temp = utils.tempdir() |
| 105 | + if True: |
| 106 | + dso_binary = "test_binary.so".format(version_name) |
| 107 | + dso_binary_path = temp.relpath(dso_binary) |
| 108 | + else: |
| 109 | + dso_binary = "test_binary-{}.so".format(version_name) |
| 110 | + dso_binary_path = "/tmp/cconvey-{}.so".format(version_name) |
| 111 | + func.save(dso_binary_path) |
| 112 | + |
| 113 | + print("SAVED BINARY TO HOST PATH: {}".format(dso_binary_path)) |
| 114 | + |
| 115 | + #import pdb; pdb.set_trace() |
| 116 | + |
| 117 | + if not android_serial_number: |
| 118 | + pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") |
| 119 | + |
| 120 | + rpc_info = { |
| 121 | + "rpc_tracker_host": tvm_tracker_host, |
| 122 | + "rpc_tracker_port": tvm_tracker_port, |
| 123 | + "rpc_server_port": RPC_SERVER_PORT + 0, # See note at the beginning of the file |
| 124 | + "adb_server_socket": adb_server_socket, |
| 125 | + } |
| 126 | + launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) |
| 127 | + launcher.upload(dso_binary_path, dso_binary) |
| 128 | + launcher.start_server() |
| 129 | + |
| 130 | + with launcher.start_session() as sess: |
| 131 | + mod = launcher.load_module(dso_binary, sess) |
| 132 | + |
| 133 | + # TODO: I think I was hitting an error because I tried to write into |
| 134 | + # "B_data.numpy()[...]"; i.e. going through the TVM wrapper object |
| 135 | + # after it was created. Not sure if this is necessary. |
| 136 | + host_numpy_A_data = np.ndarray(shape, dtype=dtype) |
| 137 | + host_numpy_B_data = np.ndarray(shape, dtype=dtype) |
| 138 | + host_numpy_C_data = np.ndarray(shape, dtype=dtype) |
| 139 | + host_numpy_C_data_expected = np.ndarray(shape, dtype=dtype) |
| 140 | + |
| 141 | + def intended_val_A(i,j): |
| 142 | + return i + j |
| 143 | + |
| 144 | + def intended_val_B(i,j): |
| 145 | + return (i+1) * (j+1) |
| 146 | + |
| 147 | + for i in range(shape[0]): |
| 148 | + for j in range(shape[1]): |
| 149 | + host_numpy_A_data[i,j] = intended_val_A(i,j) |
| 150 | + host_numpy_B_data[i,j] = intended_val_B(i,j) |
| 151 | + host_numpy_C_data_expected[i,j] = intended_val_A(i,j) + intended_val_B(i,j) |
| 152 | + |
| 153 | + #A_data = tvm.nd.array(host_numpy_A_data, device=sess.device, mem_scope=mem_scope) |
| 154 | + #B_data = tvm.nd.array(host_numpy_B_data, device=sess.device, mem_scope=mem_scope) |
| 155 | + #C_data = tvm.nd.array(np.ndarray(shape, dtype=dtype), device=sess.device, mem_scope=mem_scope) |
| 156 | + |
| 157 | + A_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) |
| 158 | + A_data.copyfrom(host_numpy_A_data) |
| 159 | + |
| 160 | + B_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) |
| 161 | + B_data.copyfrom(host_numpy_B_data) |
| 162 | + |
| 163 | + C_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) |
| 164 | + #C_data.copyfrom(host_numpy_C_data) |
| 165 | + |
| 166 | + #import pdb; pdb.set_trace() |
| 167 | + |
| 168 | + mod["add_hvx"](A_data, B_data, C_data) |
| 169 | + |
| 170 | + result = C_data.numpy() |
| 171 | + assert (result == host_numpy_C_data_expected).all() |
| 172 | + |
| 173 | + launcher.stop_server() |
| 174 | + |
43 | 175 |
|
44 | 176 | @requires_hexagon_toolchain |
45 | 177 | def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket): |
@@ -69,6 +201,7 @@ def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_serv |
69 | 201 | "adb_server_socket": adb_server_socket, |
70 | 202 | } |
71 | 203 | launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) |
| 204 | + import pdb; pdb.set_trace() |
72 | 205 | launcher.upload(dso_binary_path, dso_binary) |
73 | 206 | launcher.start_server() |
74 | 207 |
|
|
0 commit comments