Skip to content

Commit 924f438

Browse files
adstrawpfk-beta
authored andcommitted
Lower cache_read and cache_write to Hexagon DMA via tensorize (apache#10365)
* Lower cache_read and cache_write to Hexagon DMA via tensorize * rework test to be compatible with launcher * remove cpu device api mem_copy implementation and test
1 parent ca99420 commit 924f438

File tree

6 files changed

+174
-1
lines changed

6 files changed

+174
-1
lines changed

include/tvm/tir/builtin.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,13 @@ TVM_DLL const Op& texture2d_store();
626626
*/
627627
TVM_DLL const Op& texture2d_load();
628628

629+
/*!
630+
* \brief Copy 1d memory from source to destination
631+
* Same semantics as memcpy(destination, source, size)
632+
* Allows for device specific implementations e.g. direct memory access (DMA)
633+
*/
634+
TVM_DLL const Op& mem_copy();
635+
629636
/*! \brief The kind of structure field info used in intrinsic */
630637
enum TVMStructFieldKind : int {
631638
// array head address

src/runtime/hexagon/hexagon/hexagon_buffer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace tvm {
3737
namespace runtime {
3838
namespace hexagon {
3939

40-
int hexagon_user_dma_1d_sync(void* src, void* dst, uint32_t length);
40+
int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);
4141

4242
struct Allocation {
4343
Allocation(size_t allocation_nbytes, size_t alignment)

src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ namespace tvm {
4040
namespace runtime {
4141
namespace hexagon {
4242

43+
int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);
44+
4345
HexagonDeviceAPIv2* HexagonDeviceAPIv2::Global() {
4446
static auto* inst = new HexagonDeviceAPIv2();
4547
return inst;
@@ -149,6 +151,16 @@ void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, vo
149151
memcpy(static_cast<char*>(to) + to_offset, static_cast<const char*>(from) + from_offset, size);
150152
}
151153

154+
TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
155+
void* dst = args[0];
156+
void* src = args[1];
157+
int size = args[2];
158+
159+
hexagon_user_dma_1d_sync(dst, src, size);
160+
161+
*rv = static_cast<int32_t>(0);
162+
});
163+
152164
TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) {
153165
DeviceAPI* ptr = HexagonDeviceAPIv2::Global();
154166
*rv = static_cast<void*>(ptr);

src/tir/op/builtin.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,9 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
260260
.set_attr<TVectorizable>("TVectorizable", true)
261261
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
262262

263+
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
264+
Integer(CallEffectKind::kOpaque));
265+
263266
} // namespace builtin
264267
} // namespace tir
265268
} // namespace tvm

src/tir/transforms/lower_tvm_builtin.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,26 @@ class BuiltinLower : public StmtExprMutator {
209209
return MakeArray(op);
210210
} else if (op->op.same_as(builtin::tvm_context_id())) {
211211
return make_zero(op->dtype);
212+
} else if (op->op.same_as(builtin::mem_copy())) {
213+
return MakeMemCopy(op);
212214
} else {
213215
return StmtExprMutator::VisitExpr_(op);
214216
}
215217
}
218+
219+
PrimExpr MakeMemCopy(const CallNode* op) {
220+
PrimExpr dst = op->args[0];
221+
PrimExpr src = op->args[1];
222+
PrimExpr size = op->args[2];
223+
224+
std::string fdevapi_prefix =
225+
"device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
226+
227+
Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
228+
{StringImm(fdevapi_prefix + ".mem_copy"), dst, src, size});
229+
return VisitExpr(call_packed);
230+
}
231+
216232
// call shape
217233
PrimExpr MakeShape(const CallNode* op) {
218234
// if args.size() == 0, it represents a scalar shape ()
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pytest
19+
import numpy as np
20+
21+
import tvm.testing
22+
from tvm import te
23+
from tvm.contrib import utils
24+
from tvm.contrib.hexagon.build import HexagonLauncher
25+
import tvm.contrib.hexagon.hexagon as hexagon
26+
27+
from .conftest import requires_hexagon_toolchain
28+
29+
30+
def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
31+
assert len(shape) == 1
32+
src = te.placeholder(shape=shape, dtype=dtype, name="src")
33+
dst = te.compute(shape, lambda i: src[i], name="dst")
34+
size = shape[0] * np.dtype(dtype).itemsize
35+
36+
src_buffer = tvm.tir.decl_buffer(
37+
shape,
38+
dtype,
39+
scope=src_scope,
40+
offset_factor=1,
41+
)
42+
43+
dst_buffer = tvm.tir.decl_buffer(
44+
shape,
45+
dtype,
46+
scope=dst_scope,
47+
offset_factor=1,
48+
)
49+
50+
def intrin_func(ins, outs):
51+
ib = tvm.tir.ir_builder.create()
52+
53+
_src = ins[0]
54+
_dst = outs[0]
55+
ib.emit(
56+
tvm.tir.call_intrin(
57+
"handle", "tir.mem_copy", _dst.access_ptr("w"), _src.access_ptr("r"), size
58+
)
59+
)
60+
return ib.get()
61+
62+
return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})
63+
64+
65+
@requires_hexagon_toolchain
66+
def test_cache_read_write(android_serial_number, tvm_tracker_host, tvm_tracker_port):
67+
size = 128
68+
outer_shape = (size,)
69+
factor = 16
70+
inner_shape = (factor,)
71+
dtype = "int8"
72+
73+
x = te.placeholder(shape=outer_shape, dtype=dtype, name="x")
74+
y = te.placeholder(shape=outer_shape, dtype=dtype, name="y")
75+
z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
76+
s = te.create_schedule(z.op)
77+
78+
x_global = s.cache_read(x, "global.vtcm", [z])
79+
y_global = s.cache_read(y, "global.vtcm", [z])
80+
z_global = s.cache_write(z, "global.vtcm")
81+
82+
zouter, zinner = s[z_global].split(z_global.op.axis[0], factor=factor)
83+
84+
s[x_global].compute_at(s[z_global], zouter)
85+
s[y_global].compute_at(s[z_global], zouter)
86+
87+
mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
88+
89+
(cache_read_x,) = s[x_global].op.axis
90+
s[x_global].tensorize(cache_read_x, mem_copy_read)
91+
92+
(cache_read_y,) = s[y_global].op.axis
93+
s[y_global].tensorize(cache_read_y, mem_copy_read)
94+
95+
mem_copy_write = intrin_mem_copy(outer_shape, dtype, "global", "global.vtcm")
96+
97+
(cache_write_z,) = s[z].op.axis
98+
s[z].tensorize(cache_write_z, mem_copy_write)
99+
100+
print(tvm.lower(s, [x, y, z]))
101+
102+
target_hexagon = tvm.target.hexagon("v68", link_params=True)
103+
func = tvm.build(
104+
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
105+
)
106+
temp = utils.tempdir()
107+
dso_binary = "test_binary.so"
108+
dso_binary_path = temp.relpath(dso_binary)
109+
func.save(dso_binary_path)
110+
111+
if not android_serial_number:
112+
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
113+
114+
launcher = HexagonLauncher(serial_number=android_serial_number)
115+
launcher.android_run_rpc(rpc_tracker_host=tvm_tracker_host, rpc_tracker_port=tvm_tracker_port)
116+
launcher.hexagon_setup()
117+
remote_kw = {
118+
"host": tvm_tracker_host,
119+
"port": tvm_tracker_port,
120+
"priority": 0,
121+
"timeout": 60,
122+
}
123+
launcher.hexagon_session_setup(remote_kw)
124+
launcher.upload(dso_binary_path, dso_binary)
125+
126+
with launcher.session as sess:
127+
mod = launcher.get_module(dso_binary)
128+
xt = tvm.nd.array(np.random.uniform(size=size).astype(x.dtype), device=sess.device)
129+
yt = tvm.nd.array(np.random.uniform(size=size).astype(y.dtype), device=sess.device)
130+
zt = tvm.nd.array(np.random.uniform(size=size).astype(z.dtype), device=sess.device)
131+
mod["dmacpy"](xt, yt, zt)
132+
launcher.close()
133+
134+
ref = xt.numpy() + yt.numpy()
135+
np.testing.assert_equal(zt.numpy(), ref)

0 commit comments

Comments
 (0)