# under the License.

import os
import os.path
import pathlib
import sys
import pytest
import numpy as np
import logging
import tempfile

import tvm.testing
from tvm import te

# triggering TIME_WAIT state on the server socket. This prevents another
# server from binding to the same port until the wait time elapses.
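#
# (An illustrative sketch of that convention, not this file's actual
# assignments: each test claims its own server port, e.g. test_add_hvx might
# use RPC_SERVER_PORT + 0 and test_add RPC_SERVER_PORT + 1, so back-to-back
# test runs don't contend for a socket stuck in TIME_WAIT.)
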
@requires_hexagon_toolchain
def test_add_hvx(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket):
    """
    Starting with an elementwise-add computation, try various schedules / optimizations to
    see the impact they have on performance.

    The main motivation for this test is to explore the relationship between these
    schedules / optimizations and how effectively the resulting primfunc uses the
    Hexagon's HVX units.
    """
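
    # The test sweeps a small configuration matrix:
    #   dtype x schedule type x memory scope x tensor size,
    # recording one timing result (or failure) per combination.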

    host_output_dir = tempfile.mkdtemp()

    print("-" * 80)
    print(f"OUTPUT DIRECTORY: {host_output_dir}")
    print("-" * 80)
    print()

    class benchmark_results_collection:
        def __init__(self):
            # For simplicity, store each field of the results in its own,
            # index-aligned list.
            self.dtypes = []
            self.sched_types = []
            self.mem_scopes = []
            self.nums_vecs_per_tensor = []
            self.benchmark_results = []
            self.failure_texts = []

        def record_success(
            self, dtype, sched_type, mem_scope, num_vecs_per_tensor, benchmark_result
        ):
            self.dtypes.append(dtype)
            self.sched_types.append(sched_type)
            self.mem_scopes.append(mem_scope)
            self.nums_vecs_per_tensor.append(num_vecs_per_tensor)
            self.benchmark_results.append(benchmark_result)
            self.failure_texts.append(None)

        def record_failure(self, dtype, sched_type, mem_scope, num_vecs_per_tensor, failure_text):
            self.dtypes.append(dtype)
            self.sched_types.append(sched_type)
            self.mem_scopes.append(mem_scope)
            self.nums_vecs_per_tensor.append(num_vecs_per_tensor)
            self.benchmark_results.append(None)
            self.failure_texts.append(failure_text)

        def dump(self, f):
            # Tab-separated output.  (The file written at the end of the test
            # uses a .csv extension, but the delimiter is a tab.)
            delim = "\t"

            f.write(
                delim.join(
                    [
                        "dtype",
                        "sched_type",
                        "mem_scope",
                        "num_vecs_per_tensor",
                        "median(µsec)",
                        "min(µsec)",
                        "max(µsec)",
                        "comment",
                    ]
                )
            )
            f.write("\n")

            for i in range(len(self.dtypes)):
                row = [
                    str(self.dtypes[i]),
                    str(self.sched_types[i]),
                    str(self.mem_scopes[i]),
                    str(self.nums_vecs_per_tensor[i]),
                ]

                r = self.benchmark_results[i]
                if r is None:
                    row += ["", "", ""]
                else:
                    # BenchmarkResult statistics are in seconds; report microseconds.
                    row += [f"{r.median * 1e6:.3}", f"{r.min * 1e6:.3}", f"{r.max * 1e6:.3}"]

                ft = self.failure_texts[i]
                row.append("OK" if ft is None else f"FAILURE: {ft}")

                f.write(delim.join(row))
                f.write("\n")

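        # dump() output looks roughly like this (tab-separated; the numbers
        # below are illustrative, not real measurements):
        #
        #   dtype  sched_type  mem_scope    num_vecs_per_tensor  median(µsec)  min(µsec)  max(µsec)  comment
        #   int8   2           global.vtcm  64                   12.3          11.9       14.1       OK
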
    br = benchmark_results_collection()

    # Hexagon v68 supports additional dtypes, but we're sticking with int8 for now.
    for dtype in ["int8"]:
        for sched_type in [1, 2]:
            # "global.vtcm" places the tensors in Hexagon's Vector Tightly-Coupled
            # Memory; None uses ordinary global memory.
            for mem_scope in [None, "global.vtcm"]:

                # These numbers are fairly arbitrary, but they're meant to stress
                # memory / caches to various extents.
                for num_vectors_per_tensor in [1, 16, 64, 512, 2048]:

                    version_name = (
                        f"dtype:{dtype}-schedtype:{sched_type}"
                        f"-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}"
                    )
                    print(f"CONFIGURATION: {version_name}")

                    # This is a fixed detail of the v68 architecture.
                    HVX_VECTOR_BYTES = 128

                    dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits
                    assert dtype_bits % 8 == 0
                    dtype_bytes = dtype_bits // 8

                    elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes

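                    # E.g., for "int8": 8 bits -> 1 byte per element, so one
                    # 128-byte HVX vector holds 128 // 1 = 128 elements.
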
                    # Note! We're providing the complete input-tensor shapes now,
                    # whereas the original code only reveals the exact shape when
                    # about to call the kernel.

                    shape = [num_vectors_per_tensor, elem_per_hvx_vector]

                    A = tvm.te.placeholder(shape, dtype=dtype)
                    B = tvm.te.placeholder(shape, dtype=dtype)
                    C = tvm.te.compute(A.shape, lambda i, j: A[i, j] + B[i, j], name="C")

                    sched = tvm.te.create_schedule(C.op)

                    if sched_type == 1:
                        # Baseline: leave the default (scalar) schedule as-is.
                        pass
                    elif sched_type == 2:
                        # Vectorize the innermost axis, i.e. the elements within
                        # one HVX vector.
                        sched[C].vectorize(C.op.axis[1])
                    else:
                        raise ValueError(f"Unknown schedule type: {sched_type}")
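
                    # For reference, a rough sketch of the effect (not output from
                    # this exact build): with sched_type == 2, the lowered TIR
                    # replaces the scalar inner loop over j with a single vector
                    # statement, along the lines of
                    #
                    #   C[ramp(i * 128, 1, 128)] = A[ramp(i * 128, 1, 128)]
                    #                            + B[ramp(i * 128, 1, 128)]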

                    # This module is built only so humans can inspect its IR.
                    module_for_ir_dump = tvm.lower(sched, [A, B, C], "foo")

                    report_path = os.path.join(host_output_dir, f"{version_name}.txt")

                    with open(report_path, "w") as f:
                        f.write("LOWERED IR MODULE:\n")
                        f.write(str(module_for_ir_dump))
                        f.write("\n")

                    target_hexagon = tvm.target.hexagon("v68", link_params=True)
                    func = tvm.build(
                        sched,
                        [A, B, C],
                        tvm.target.Target(target_hexagon, host=target_hexagon),
                        name="add_hvx",
                    )

                    host_dso_binary_path = os.path.join(
                        host_output_dir, f"test_binary-{version_name}.so"
                    )
                    target_dso_binary_filename = "test_binary.so"

                    func.save(host_dso_binary_path)
                    print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}")

                    if not android_serial_number:
                        pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")

                    rpc_info = {
                        "rpc_tracker_host": tvm_tracker_host,
                        "rpc_tracker_port": tvm_tracker_port,
                        # See the note near the beginning of this file.
                        "rpc_server_port": RPC_SERVER_PORT + 0,
                        "adb_server_socket": adb_server_socket,
                    }
                    launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
                    launcher.upload(host_dso_binary_path, target_dso_binary_filename)
                    launcher.start_server()

                    try:
                        with launcher.start_session() as sess:
                            mod = launcher.load_module(target_dso_binary_filename, sess)

                            # Allocate host-side buffers; every element is filled
                            # in below, so uninitialized contents are fine.
                            host_numpy_A_data = np.empty(shape, dtype=dtype)
                            host_numpy_B_data = np.empty(shape, dtype=dtype)
                            host_numpy_C_data_expected = np.empty(shape, dtype=dtype)

                            def intended_val_A(i, j):
                                return i + j

                            def intended_val_B(i, j):
                                return (i + 1) * (j + 1)

                            for i in range(shape[0]):
                                for j in range(shape[1]):
                                    host_numpy_A_data[i, j] = intended_val_A(i, j)
                                    host_numpy_B_data[i, j] = intended_val_B(i, j)
                                    host_numpy_C_data_expected[i, j] = intended_val_A(
                                        i, j
                                    ) + intended_val_B(i, j)

                            A_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)
                            A_data.copyfrom(host_numpy_A_data)

                            B_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)
                            B_data.copyfrom(host_numpy_B_data)

                            C_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope)

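                            # time_evaluator runs the kernel `number` times per
                            # measurement and records `repeat` measurements; with
                            # min_repeat_ms=1000, `number` is increased as needed so
                            # that each measurement lasts at least one second. The
                            # resulting median/min/max statistics are in seconds.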
                            timer = mod.time_evaluator(
                                "add_hvx", sess.device, number=100, repeat=1, min_repeat_ms=1000
                            )
                            timing_result = timer(A_data, B_data, C_data)

                            print(f"TIMING RESULT: {timing_result}")

                            # Verify that the computation actually happened and
                            # produced the correct result.
                            result = C_data.numpy()
                            assert (result == host_numpy_C_data_expected).all()

                            br.record_success(
                                dtype, sched_type, mem_scope, num_vectors_per_tensor, timing_result
                            )
                    except Exception as err:
                        # A bare `except:` would also swallow KeyboardInterrupt and
                        # pytest's control-flow exceptions; record the error text so
                        # the report shows why this configuration failed.
                        br.record_failure(
                            dtype, sched_type, mem_scope, num_vectors_per_tensor, str(err)
                        )
                    finally:
                        launcher.stop_server()

    br.dump(sys.stdout)

    print("-" * 80)
    print(f"OUTPUT DIRECTORY: {host_output_dir}")
    print("-" * 80)
    print()

    tabular_output_filename = os.path.join(host_output_dir, "benchmark-results.csv")
    with open(tabular_output_filename, "w") as csv_file:
        br.dump(csv_file)
    print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}")


@requires_hexagon_toolchain
def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket):