Skip to content

Commit 57dee39

Browse files
Lunderberg and adstraw
authored and committed
[Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (apache#10905)
* [Hexagon][LLVM] Enable/test tensorized Hexagon DMA - In the `CodeGenLLVM::CreateIntrinsic` handler for `builtin::address_of()`, pass N-d indices to `CodeGenLLVM::CreateBufferPtr`. The base class implementation still asserts that there is a flat memory space, while the `CodeGenHexagon::CreateBufferPtr` override allows 2-d memory. - Enable tensorization in `test_cache_read_write.py`, using `tir.address_of` to pass the lowered value. Co-authored-by: Adam Straw <[email protected]> * [TIR] Allow buffer_bind_scope of N-d buffers Previously, any `buffer_bind_scope` attribute that provides a view into a non-flat buffer would result in an error. After this commit, `buffer_bind_scope` may be used for non-flat buffers, but use of `arg_buffer->elem_offset` within the body of the bind statement is still an error. The `BufferNode::elem_offset` field represents the offset between the pointer of the backing allocation and the first element of the buffer. This offset is only well-defined for flat memory spaces. * update test to tensorize cache_read `y` (works) and cache_write `z` (fails) * add `split` to allow for tensorization of cache_write of `z` * fix typo and cleanup comment * add back original 1d test_cache_read_write * update comments * format error Co-authored-by: Adam Straw <[email protected]>
1 parent 2009224 commit 57dee39

File tree

5 files changed

+143
-58
lines changed

5 files changed

+143
-58
lines changed

src/target/llvm/codegen_llvm.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,13 +1006,19 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
10061006
} else if (op->op.same_as(builtin::address_of())) {
10071007
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
10081008
ICHECK(op->args.size() == 1 && load);
1009-
ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations.";
1010-
PrimExpr index = load->indices[0];
1011-
if (const RampNode* r = index.as<RampNode>()) {
1012-
index = r->base;
1009+
1010+
Array<PrimExpr> indices = load->indices;
1011+
if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
1012+
indices.Set(indices.size() - 1, r->base);
1013+
}
1014+
1015+
std::vector<llvm::Value*> indices_val;
1016+
for (const auto& index : indices) {
1017+
indices_val.push_back(MakeValue(index));
10131018
}
1019+
10141020
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
1015-
{MakeValue(index)}, load->dtype);
1021+
indices_val, load->dtype);
10161022
unsigned addrspace =
10171023
llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
10181024
return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));

src/tir/ir/buffer.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,6 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
460460
begins = SimplifyArray(&ana, begins);
461461
Array<PrimExpr> elem_offset = n->ElemOffset(begins);
462462
elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
463-
ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory.";
464463

465464
Array<PrimExpr> strides = n->strides;
466465
if (strides.size() == 0) {
@@ -480,8 +479,20 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
480479
return MakeStrideView().MakeSlice(begins, extents);
481480
}
482481
}
483-
return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
484-
n->data_alignment, 0, n->buffer_type);
482+
Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
483+
n->data_alignment, 0, n->buffer_type);
484+
485+
// Buffer must be constructed with a singular element offset which means there is no
486+
// support for n-dimensional buffers where n > 1. Insert sentinel value for
487+
// ArgBinder::BindBuffer to state that any usage of element offset is invalid
488+
// in this case. This allows for construction of a Buffer with multiple element offsets
489+
// but disallows any usage of those element offsets. See PR #10816 for discussion on
490+
// supporting multiple element offsets in TIR Buffer.
491+
// TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in TIR Buffer
492+
if (elem_offset.size() != 1) {
493+
slice.CopyOnWrite()->elem_offset = PrimExpr();
494+
}
495+
return slice;
485496
}
486497

487498
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,

src/tir/transforms/arg_binder.cc

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,25 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
9696
<< " required_alignment=" << arg->data_alignment
9797
<< ", provided_alignment=" << value->data_alignment;
9898
}
99-
// bind pointer and offset.
100-
if (is_zero(arg->elem_offset)) {
101-
ICHECK(is_zero(value->elem_offset))
102-
<< "Trying to bind a Buffer with offset into one without offset "
103-
<< " required elem_offset=" << arg->elem_offset
104-
<< ", provided elem_offset=" << value->elem_offset;
105-
}
10699

107-
this->Bind(arg->data, value->data, arg_name + ".data");
108-
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
109-
if (arg->offset_factor > 1) {
110-
PrimExpr offset = value->elem_offset;
111-
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
112-
PrimExpr zero = make_zero(offset.dtype());
113-
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
114-
&asserts_);
100+
if (value->elem_offset.defined()) {
101+
// bind pointer and offset.
102+
if (is_zero(arg->elem_offset)) {
103+
ICHECK(is_zero(value->elem_offset))
104+
<< "Trying to bind a Buffer with offset into one without offset "
105+
<< " required elem_offset=" << arg->elem_offset
106+
<< ", provided elem_offset=" << value->elem_offset;
107+
}
108+
109+
this->Bind(arg->data, value->data, arg_name + ".data");
110+
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
111+
if (arg->offset_factor > 1) {
112+
PrimExpr offset = value->elem_offset;
113+
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
114+
PrimExpr zero = make_zero(offset.dtype());
115+
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
116+
&asserts_);
117+
}
115118
}
116119
}
117120

src/tir/transforms/storage_flatten.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,9 @@ class BufferBindUnwrapper : public StmtExprMutator {
887887
}
888888

889889
PrimExpr VisitExpr_(const VarNode* op) final {
890+
ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined. "
891+
<< "(e.g. use of buffer.elem_offset for a non-flat buffer)";
892+
890893
auto it = var_remap_.find(op);
891894
if (it != var_remap_.end()) {
892895
return it->second;
@@ -1110,6 +1113,11 @@ class BufferBindUnwrapper : public StmtExprMutator {
11101113
// transformations should have been handled in
11111114
// BufferShapeLegalize.
11121115
binder.BindBuffer(source, view, source->name, false);
1116+
if (auto* elem_offset_var = source->elem_offset.as<VarNode>()) {
1117+
if (!view->elem_offset.defined()) {
1118+
illegal_vars_.insert(elem_offset_var);
1119+
}
1120+
}
11131121

11141122
// Apply the remaps
11151123
Stmt body = op->body;
@@ -1162,6 +1170,8 @@ class BufferBindUnwrapper : public StmtExprMutator {
11621170
// The buffer assignment map
11631171
// Variable remap
11641172
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
1173+
// Variables that may not occur within the body.
1174+
std::unordered_set<const VarNode*> illegal_vars_;
11651175
// Buffer map
11661176
std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
11671177
// Set of vars that have occurred in an AllocateNode, but haven't

tests/python/contrib/test_hexagon/test_cache_read_write.py

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929

3030
def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
31-
assert len(shape) == 1
3231
src = te.placeholder(shape=shape, dtype=dtype, name="src")
3332
dst = te.compute(shape, lambda i: src[i], name="dst")
3433
size = shape[0] * np.dtype(dtype).itemsize
@@ -38,30 +37,72 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
3837
dtype,
3938
scope=src_scope,
4039
offset_factor=1,
40+
name="mem_copy_src_buffer",
4141
)
4242

4343
dst_buffer = tvm.tir.decl_buffer(
4444
shape,
4545
dtype,
4646
scope=dst_scope,
4747
offset_factor=1,
48+
name="mem_copy_dst_buffer",
4849
)
4950

51+
zero_indices = [0 for _ in shape]
52+
5053
def intrin_func(ins, outs):
5154
ib = tvm.tir.ir_builder.create()
5255

5356
_src = ins[0]
5457
_dst = outs[0]
58+
59+
dst_handle = ib.buffer_ptr(dst_buffer)
60+
src_handle = ib.buffer_ptr(src_buffer)
61+
5562
ib.emit(
5663
tvm.tir.call_intrin(
57-
"handle", "tir.mem_copy", _dst.access_ptr("w"), _src.access_ptr("r"), size
64+
"handle",
65+
"tir.mem_copy",
66+
tvm.tir.call_intrin("handle", "tir.address_of", dst_handle[zero_indices]),
67+
tvm.tir.call_intrin("handle", "tir.address_of", src_handle[zero_indices]),
68+
size,
5869
)
5970
)
6071
return ib.get()
6172

6273
return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})
6374

6475

76+
def verify(hexagon_session, s, x, y, z, size):
77+
print(tvm.lower(s, [x, y, z]))
78+
79+
target_hexagon = tvm.target.hexagon("v68", link_params=True)
80+
func = tvm.build(
81+
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
82+
)
83+
84+
if hexagon_session is None:
85+
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
86+
87+
mod = hexagon_session.load_module(func)
88+
xt = tvm.nd.array(
89+
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
90+
device=hexagon_session.device,
91+
)
92+
yt = tvm.nd.array(
93+
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
94+
device=hexagon_session.device,
95+
)
96+
zt = tvm.nd.array(
97+
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
98+
device=hexagon_session.device,
99+
)
100+
mod["dmacpy"](xt, yt, zt)
101+
102+
ref = xt.numpy() + yt.numpy()
103+
np.testing.assert_equal(zt.numpy(), ref)
104+
105+
65106
@requires_hexagon_toolchain
66107
def test_cache_read_write(hexagon_session):
67108
size = 128
@@ -75,52 +116,66 @@ def test_cache_read_write(hexagon_session):
75116
z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
76117
s = te.create_schedule(z.op)
77118

78-
x_global = s.cache_read(x, "global.vtcm", [z])
79-
y_global = s.cache_read(y, "global.vtcm", [z])
80-
z_global = s.cache_write(z, "global.vtcm")
119+
x_vtcm = s.cache_read(x, "global.vtcm", [z])
120+
y_vtcm = s.cache_read(y, "global.vtcm", [z])
121+
z_vtcm = s.cache_write(z, "global.vtcm")
81122

82-
zouter, zinner = s[z_global].split(z_global.op.axis[0], factor=factor)
123+
zouter, zinner = s[z_vtcm].split(z_vtcm.op.axis[0], factor=factor)
83124

84-
s[x_global].compute_at(s[z_global], zouter)
85-
s[y_global].compute_at(s[z_global], zouter)
125+
s[x_vtcm].compute_at(s[z_vtcm], zouter)
126+
s[y_vtcm].compute_at(s[z_vtcm], zouter)
86127

87128
mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
88129

89-
(cache_read_x,) = s[x_global].op.axis
90-
s[x_global].tensorize(cache_read_x, mem_copy_read)
130+
(cache_read_x,) = s[x_vtcm].op.axis
131+
s[x_vtcm].tensorize(cache_read_x, mem_copy_read)
91132

92-
(cache_read_y,) = s[y_global].op.axis
93-
s[y_global].tensorize(cache_read_y, mem_copy_read)
133+
(cache_read_y,) = s[y_vtcm].op.axis
134+
s[y_vtcm].tensorize(cache_read_y, mem_copy_read)
94135

95136
mem_copy_write = intrin_mem_copy(outer_shape, dtype, "global", "global.vtcm")
96137

97138
(cache_write_z,) = s[z].op.axis
98139
s[z].tensorize(cache_write_z, mem_copy_write)
99140

100-
print(tvm.lower(s, [x, y, z]))
141+
verify(hexagon_session, s, x, y, z, size)
101142

102-
target_hexagon = tvm.target.hexagon("v68", link_params=True)
103-
func = tvm.build(
104-
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
105-
)
106143

107-
if hexagon_session is None:
108-
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
144+
def layout_transform_2d(n):
145+
return [n // 16, te.AXIS_SEPARATOR, n % 16]
109146

110-
mod = hexagon_session.load_module(func)
111-
xt = tvm.nd.array(
112-
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
113-
device=hexagon_session.device,
114-
)
115-
yt = tvm.nd.array(
116-
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
117-
device=hexagon_session.device,
118-
)
119-
zt = tvm.nd.array(
120-
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
121-
device=hexagon_session.device,
122-
)
123-
mod["dmacpy"](xt, yt, zt)
124147

125-
ref = xt.numpy() + yt.numpy()
126-
np.testing.assert_equal(zt.numpy(), ref)
148+
@requires_hexagon_toolchain
149+
def test_cache_read_write_2d(hexagon_session):
150+
size = 128
151+
outer_shape = (size,)
152+
factor = 16
153+
inner_shape = (factor,)
154+
dtype = "int8"
155+
156+
x = te.placeholder(shape=outer_shape, dtype=dtype, name="x")
157+
y = te.placeholder(shape=outer_shape, dtype=dtype, name="y")
158+
z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
159+
s = te.create_schedule(z.op)
160+
161+
x_vtcm = s.cache_read(x, "global.vtcm", [z])
162+
y_vtcm = s.cache_read(y, "global.vtcm", [z])
163+
z_vtcm = s.cache_write(z, "global.vtcm")
164+
165+
layout_x_vtcm = s[x_vtcm].transform_layout(layout_transform_2d)
166+
layout_y_vtcm = s[y_vtcm].transform_layout(layout_transform_2d)
167+
layout_z_vtcm = s[z_vtcm].transform_layout(layout_transform_2d)
168+
169+
mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
170+
s[x_vtcm].tensorize(layout_x_vtcm[1], mem_copy_read)
171+
s[y_vtcm].tensorize(layout_y_vtcm[1], mem_copy_read)
172+
173+
# The loop schedule over `z` is not modified when calling `transform_layout`
174+
# on `z_vtcm` above therefore we must call `split` to modify the loop schedule
175+
# over `z` to match the layout of `z_vtcm` such that we can accurately write
176+
# `z_vtcm` back to `z` using memory copy intrinsic
177+
zouter, zinner = s[z].split(z.op.axis[0], factor=factor)
178+
mem_copy_write = intrin_mem_copy(inner_shape, dtype, "global", "global.vtcm")
179+
s[z].tensorize(zinner, mem_copy_write)
180+
181+
verify(hexagon_session, s, x, y, z, size)

0 commit comments

Comments (0)