
Commit 094428d

[Relax] Implement R.ensure_aligned and update memory planning for R.view
1 parent 0fc047c commit 094428d

9 files changed: +185 −77 lines changed

python/tvm/relax/op/memory/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@
 """Relax memory primitives."""

 from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view
+from .view import view, ensure_aligned

python/tvm/relax/op/memory/view.py

Lines changed: 17 additions & 0 deletions
@@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
     relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

     return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: ignore
+
+
+def ensure_aligned(data: Expr) -> Expr:
+    """
+    Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor
+
+    Returns
+    -------
+    result : relax.Expr
+        The aligned tensor
+    """
+    return _ffi_api.ensure_aligned(data)  # type: ignore
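For reference, a minimal TVMScript usage sketch of the new operator, mirroring the execution test added in this commit (the module name `Module` is illustrative only):

import tvm
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([4096], "float32")):
        # View a 16x64 float32 tile starting 8192 bytes into A, then force the
        # result to have elem_offset == 0 before handing it to downstream code.
        B = R.memory.view(A, shape=R.shape([16, 64]), relative_byte_offset=32 * 64 * 4)
        C = R.memory.ensure_aligned(B)
        return C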

src/relax/backend/vm/vm_builtin_lower.cc

Lines changed: 20 additions & 0 deletions
@@ -47,6 +47,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return Reshape(call);
     } else if (call->op == shape_of_op_) {
       return ShapeOf(call);
+    } else if (call->op == view_op_) {
+      return View(call);
+    } else if (call->op == ensure_aligned_op_) {
+      return EnsureAligned(call);
     } else if (call->op == to_vdevice_op_) {
       return ToDevice(call);
     } else if (call->op == make_closure_op_) {
@@ -124,6 +128,19 @@ class VMBuiltinLowerMutator : public ExprMutator {
     }
   }

+  Expr View(const Call& view_node) {
+    StructInfoDeriveFunc infer_sinfo_env_func;
+    infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
+    auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
+    ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
+    return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo});
+  }
+
+  Expr EnsureAligned(const Call& call_node) {
+    ICHECK(call_node->args.size() == 1);
+    return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+  }
+
   Expr ShapeOf(const Call& call_node) {
     ICHECK(call_node->args.size() == 1);
     ICHECK(call_node->struct_info_.defined());
@@ -188,6 +205,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& reshape_op_ = Op::Get("relax.reshape");
   const Op& shape_of_op_ = Op::Get("relax.shape_of");
+  const Op& view_op_ = Op::Get("relax.memory.view");
+  const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -208,6 +227,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
+  const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"};
 };

 Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
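To exercise just this lowering step in isolation, a rough sketch (assuming the pass is exposed on the Python side as relax.transform.VMBuiltinLower, and Module is an IRModule using these ops, e.g. the one sketched under view.py above):

from tvm import relax

# After the pass, relax.memory.view calls target the extern "runtime.TVMArrayCreateView",
# and relax.memory.ensure_aligned calls target the VM builtin "vm.builtin.ensure_aligned".
lowered = relax.transform.VMBuiltinLower()(Module)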

src/relax/op/memory/view.cc

Lines changed: 28 additions & 6 deletions
@@ -334,13 +334,12 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
     relative_byte_offset = relax::PrimValue::Int64(0);
   }

-  StructInfoDeriveFunc infer_sinfo_env_func;
-  infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
-  auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
-
-  ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
+  if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) &&
+      relative_byte_offset.same_as(call->args[3])) {
+    return call;
+  }

-  return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset});
+  return Call(call->op, {data, shape, dtype, relative_byte_offset});
 }

 TVM_REGISTER_OP("relax.memory.view")
@@ -355,5 +354,28 @@ TVM_REGISTER_OP("relax.memory.view")
     .set_attr<FLegalize>("FLegalize", LegalizeView)
     .set_attr<Bool>("FPurity", Bool(true));

+Expr ensure_aligned(const Expr& x) {
+  static const Op& op = Op::Get("relax.memory.ensure_aligned");
+  return Call(op, {x});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned);
+
+StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << call->op << " should receive 1 argument, "
+                     << "but received " << call->args);
+  }
+  return GetStructInfo(call->args[0]);
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_aligned")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureAligned)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
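The user-visible effect, also reflected in the updated test_op_view.py expectations below, is that LegalizeOps now only normalizes the arguments of R.memory.view and keeps the op in place; the rewrite to runtime.TVMArrayCreateView is deferred to VMBuiltinLower. A minimal sketch, where Before stands in for any IRModule calling R.memory.view:

import tvm

# After legalization the call is still R.memory.view, with dtype and
# relative_byte_offset filled in with their defaults where they were omitted.
After = tvm.relax.transform.LegalizeOps()(Before)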

src/relax/op/memory/view.h

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ namespace relax {
 /*! \brief View a tensor with different properties. */
 Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);

+/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
+Expr ensure_aligned(const Expr& x);
+
 }  // namespace relax
 }  // namespace tvm

src/relax/transform/static_plan_block_memory.cc

Lines changed: 9 additions & 4 deletions
@@ -286,8 +286,13 @@ class TokenAllocator1D {
   std::vector<StorageToken> full_pool_;
 };

-/*! \brief Check if the input op is "relax.reshape". */
-bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
+/*! \brief Check if the input op is a memory op that returns the same buffer as the input buffer. */
+bool IsInplaceMemoryOp(const Expr& op) {
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& view_op = Op::Get("relax.memory.view");
+  static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
+  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
+}

 /*! \brief The base class for the storage allocation visitor. */
 class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
       // Create a storage token for builtin alloc_tensor.
       this->CreateToken(call);
       return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
       // Reuse the input's token for builtin reshape.
       SetTokens(call, GetTokens(call->args[0]));
       return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
        block_tokens.push_back(new_token.get());
      }
      return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
      Tokens tokens = GetTokens(call->args[0]);
      ICHECK(!tokens.IsNested());
      if (tokens.IsLeaf()) {
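With this change the planner forwards the input's storage token through relax.memory.view and relax.memory.ensure_aligned, exactly as it already did for relax.reshape, so a viewed or re-aligned tensor does not receive its own allocation. A minimal invocation sketch, with Before standing in for a module like the new test_view case below:

from tvm import relax

# x1 = R.memory.view(x, ...) and x2 = R.memory.ensure_aligned(x1) reuse x's storage token,
# so only the underlying alloc_tensor results are backed by alloc_storage after planning.
planned = relax.transform.StaticPlanBlockMemory()(Before)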

src/runtime/relax_vm/builtin.cc

Lines changed: 13 additions & 0 deletions
@@ -545,6 +545,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
   return ShapeTuple(out_shape);
 });

+TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
+  if (data->byte_offset == 0) {
+    return data;
+  }
+  DLManagedTensor* dl_tensor = data.ToDLPack();
+  dl_tensor->dl_tensor.data =
+      reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
+  dl_tensor->dl_tensor.byte_offset = 0;
+  // For platforms that do not support pointer arithmetic, we need to copy the data to a new
+  // buffer.
+  return NDArray::FromDLPack(dl_tensor);
+});
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
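For a quick sanity check of the new builtin outside the VM, it can be fetched as a packed function; a minimal sketch, assuming a TVM build that contains this commit:

import numpy as np
import tvm

ensure_aligned = tvm.get_global_func("vm.builtin.ensure_aligned")
arr = tvm.nd.array(np.arange(16, dtype="float32"))
aligned = ensure_aligned(arr)  # byte_offset is already 0 here, so the same buffer is returned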

tests/python/relax/test_op_view.py

Lines changed: 39 additions & 66 deletions
@@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")):
     class Expected:
         @R.function
         def main(A: R.Tensor([4096], "float32")):
-            B = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
-                A,
-                R.shape([64, 64]),
-                R.dtype("float32"),
-                R.prim_value(0),
-            )
+            B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
             return B

     After = tvm.relax.transform.LegalizeOps()(Before)
@@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")):
     class Expected:
         @R.function
         def main(A: R.Tensor(dtype="float32")):
-            B = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
-                A,
-                R.shape([64, 64]),
-                R.dtype("float32"),
-                R.prim_value(0),
-            )
+            B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
             return B

     After = tvm.relax.transform.LegalizeOps()(Before)
@@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")):
     class Expected:
         @R.function
         def main(A: R.Tensor([4096], "float32")):
-            B = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
-                A,
-                R.shape([4096]),
-                R.dtype("int32"),
-                R.prim_value(0),
+            B = R.memory.view(
+                A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0
             )
             return B

@@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")):
     class Expected:
         @R.function
         def main(A: R.Tensor([4096], "float32")):
-            B = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
-                A,
-                R.shape([4096]),
-                R.dtype("float32"),
-                R.prim_value(0),
+            B = R.memory.view(
+                A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32"
             )
             return B

@@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")):
     class Expected:
         @R.function
         def main(A: R.Tensor([4096], "uint8")):
-            B = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
+            B = R.memory.view(
                 A,
-                R.shape([512]),
-                R.dtype("int32"),
-                R.prim_value(0),
+                shape=R.shape([512]),
+                dtype=R.dtype("int32"),
+                relative_byte_offset=R.prim_value(0),
             )
-            C = R.ExternFunc(
-                "runtime.TVMArrayCreateView",
-                R.Callable(
-                    derive_func="tvm.relax.struct_info.infer_view_sinfo",
-                    purity=True,
-                ),
-            )(
+            C = R.memory.view(
                 A,
-                R.shape([16, 64]),
-                R.dtype("float16"),
-                R.prim_value(2048),
+                shape=R.shape([16, 64]),
+                dtype=R.dtype("float16"),
+                relative_byte_offset=R.prim_value(2048),
             )
             return (B, C)

@@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")):
     tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1])


+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev):
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([4096], "float32")):
+            B = R.memory.view(
+                A,
+                shape=R.shape([16, 64]),
+                relative_byte_offset=32 * 64 * 4,
+            )
+            C = R.memory.ensure_aligned(B)
+            return C
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, device=dev)
+
+    np_input = np.random.random([4096]).astype("float32")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.reshape(64, 64)[32:48, :]
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 55 additions & 0 deletions
@@ -1449,5 +1449,60 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_view():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main():
+            cls = Before
+            x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
+            x1 = R.memory.view(x, [128], "float32", 0)
+            x2 = R.memory.ensure_aligned(x1)
+            y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
+            cls.tir_exp(x2, y)
+            z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
+            cls.tir_exp(y, z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main() -> R.Tensor((128,), dtype="float32"):
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
+            )
+            x1: R.Tensor((128,), dtype="float32") = R.memory.view(
+                x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
+            )
+            x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1)
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+            )
+            cls.tir_exp(x2, y)
+            z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global")
+            )
+            cls.tir_exp(y, z)
+            return z
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
