Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 100 additions & 78 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ class BuiltinLower : public StmtExprMutator {
// - arg stack
//
// Scoping and liveness rules:
// - The function takes a root scope.
// - Every parallel for introduces a new scope.
// - Every call_packed introduces a new scope.
// - The memory allocated by tvm_stack_make_array/make_shape is
// no longer valid outside the scope (and may be reused by a
// subsequent call_packed).
// - TODO(tvm-team): we might consider a root scope so stack_make_shape
// can be called out-side call_packed.
//
// Example:
// {
Expand Down Expand Up @@ -100,6 +100,42 @@ class BuiltinLower : public StmtExprMutator {
}
};

// Context manager object to help maintain current scope's run sizes.
// It collect max size usage and rewind the stack size automatically at exit.
class ScopeRunSizesGuard {
public:
using ScopeStack = std::vector<AllocaScope>;

explicit ScopeRunSizesGuard(ScopeStack* stack, bool is_precheck)
: stack_(stack), is_precheck_(is_precheck) {}

private:
ScopeStack* stack_;
bool is_precheck_;
size_t frame_pos;
StackSizes restore;

void EnterWithScope() {
ICHECK(!stack_->empty());
frame_pos = stack_->size() - 1;
restore = stack_->back().run_sizes;
}

void ExitWithScope() {
ICHECK_EQ(frame_pos, stack_->size() - 1);
auto& scope = stack_->back();
// Verify stack size matches earlier value.
if (is_precheck_) {
scope.UpdateMax();
} else {
scope.AssertMaxIsValid();
}
scope.run_sizes = restore;
}

friend class With<ScopeRunSizesGuard>;
};

// Entry point of the lowering: forwards `stmt` to VisitBodyAndRealizeAlloca
// and returns whatever statement that call produces.
Stmt Build(Stmt stmt) {
  Stmt lowered = this->VisitBodyAndRealizeAlloca(stmt);
  return lowered;
}

StackSizes GetMaxStack(Stmt stmt) {
Expand All @@ -118,7 +154,10 @@ class BuiltinLower : public StmtExprMutator {
decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode");
}

precheck.VisitStmt(stmt);
{
With<ScopeRunSizesGuard> guard(&precheck.alloca_scope_, true);
precheck.VisitStmt(stmt);
}

ICHECK_EQ(precheck.alloca_scope_.size(), 1);
return precheck.alloca_scope_[0].max_sizes;
Expand Down Expand Up @@ -166,7 +205,10 @@ class BuiltinLower : public StmtExprMutator {
}
}

stmt = this->VisitStmt(stmt);
{
With<ScopeRunSizesGuard> guard(&alloca_scope_, is_precheck_);
stmt = this->VisitStmt(stmt);
}

ICHECK(!alloca_scope_.empty());
alloca_scope_.pop_back();
Expand All @@ -180,20 +222,9 @@ class BuiltinLower : public StmtExprMutator {

auto scope_size = alloca_scope_.size();
auto stmt = StmtExprMutator::VisitStmt(s);
{
// NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_.
auto& scope = alloca_scope_.back();
// This invariant asserts the assumption that
// make_stack_shape only happens within a call_packed.
// We could relax this in the future if we want to
// introduce root scope as a separate scope
ICHECK_EQ(alloca_scope_.size(), scope_size)
<< "alloca_scope_ length is different before and after recursion";
ICHECK_EQ(scope.run_sizes.shape_stack, -1)
<< "Expect no tvm_stack_make_shape outside of CallNodes";
ICHECK_EQ(scope.run_sizes.array_stack, 0)
<< "Expect no tvm_stack_make_array outside of CallNodes";
}

ICHECK_EQ(alloca_scope_.size(), scope_size)
<< "alloca_scope_ length is different before and after recursion";

auto prep_seq = std::move(prep_seq_stack_.back());
prep_seq_stack_.pop_back();
Expand Down Expand Up @@ -465,49 +496,43 @@ class BuiltinLower : public StmtExprMutator {
PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) {
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_sizes.shape_stack;
size_t restore_array_stack = scope.run_sizes.array_stack;
size_t arg_stack_begin = scope.run_sizes.arg_stack;

size_t arg_count = op->args.size();

// cpacked expects a resource_handle parameter
if (!use_string_lookup) {
arg_count--;
}

scope.run_sizes.arg_stack += arg_count;
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < arg_count; ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
{
With<ScopeRunSizesGuard> guard(&alloca_scope_, is_precheck_);

scope.run_sizes.arg_stack += arg_count;

// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();

for (size_t i = 1; i < arg_count; ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
}
// Verify stack size matches earlier value.
if (is_precheck_) {
scope.UpdateMax();
} else {
scope.AssertMaxIsValid();
}
scope.run_sizes.shape_stack = restore_shape_stack;
scope.run_sizes.array_stack = restore_array_stack;
scope.run_sizes.arg_stack = arg_stack_begin;

Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
Expand All @@ -533,38 +558,35 @@ class BuiltinLower : public StmtExprMutator {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_sizes.shape_stack;
size_t restore_array_stack = scope.run_sizes.array_stack;
size_t arg_stack_begin = scope.run_sizes.arg_stack;
scope.run_sizes.arg_stack += op->args.size();

size_t args_size = op->args.size();
ICHECK_GT(args_size, 0);
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast(api_type, arg);

scope.run_sizes.arg_stack += op->args.size();

{
With<ScopeRunSizesGuard> guard(&alloca_scope_, is_precheck_);

PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
}
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
}
// Verify stack size matches earlier value.
if (is_precheck_) {
scope.UpdateMax();
} else {
scope.AssertMaxIsValid();
}
scope.run_sizes.shape_stack = restore_shape_stack;
scope.run_sizes.array_stack = restore_array_stack;

// Update the top of the stack, so we can use more than one
// packed function's arguments with the one stack.
scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1;
Expand Down
83 changes: 83 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,88 @@ def before():
expected = tvm.TVMError


# Before/after comparison for LowerTVMBuiltin when tvm_stack_make_array and
# tvm_stack_make_shape are used directly as arguments of a call_extern,
# i.e. without an enclosing packed call.
class TestMakeArrayInRootScope(tvm.testing.CompareBeforeAfter):

transform = tvm.tir.transform.LowerTVMBuiltin()

# Input: one call_extern receiving two stack-made arrays, each built from its
# own 4-element stack-made shape.
@T.prim_func
def before(var_placeholder: T.handle, var_extern: T.handle):
T.func_attr({"target": T.target("llvm")})
placeholder = T.match_buffer(var_placeholder, (1, 3, 4, 4), offset_factor=1)
extern = T.match_buffer(var_extern, (1, 3, 8, 8), offset_factor=1)
# T.attr("dummy", "device_type", 1)
# T.attr("dummy", "device_id", 0)
T.call_extern(
"void",
"resize2d_with_dltensor",  # assume a C function that accepts DLTensor arguments
T.tvm_stack_make_array(
placeholder.data,
T.tvm_stack_make_shape(1, 3, 4, 4),
0,
4,
T.float32(0),
placeholder.elem_offset,
),
T.tvm_stack_make_array(
extern.data,
T.tvm_stack_make_shape(1, 3, 8, 8),
0,
4,
T.float32(0),
extern.elem_offset,
),
)

# Expected output: a single "array" alloca with 2 entries and a single "shape"
# alloca with 8 entries; both are filled in before the call_extern, which then
# receives the two array entries as handles.
@T.prim_func
def expected(var_placeholder: T.handle, var_extern: T.handle):
T.func_attr(
{
"global_symbol": "main",
"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
}
)
placeholder = T.match_buffer(var_placeholder, (1, 3, 4, 4), offset_factor=1)
extern = T.match_buffer(var_extern, (1, 3, 8, 8), offset_factor=1)
stack_array: T.handle = T.tvm_stack_alloca("array", 2)
stack_shape: T.handle("int64") = T.tvm_stack_alloca("shape", 8)
stack_shape_1 = T.decl_buffer((T.int64(8),), "int64", data=stack_shape)
# First shape (1, 3, 4, 4) occupies shape-stack slots 0..3.
stack_shape_1[0] = T.int64(1)
stack_shape_1[1] = T.int64(3)
stack_shape_1[2] = T.int64(4)
stack_shape_1[3] = T.int64(4)
# Fields of array entry 0, which wraps placeholder.data; its shape pointer is
# the address of shape-stack slot 0.
T.tvm_struct_set(stack_array, 0, 1, placeholder.data)
stack_shape_2 = T.Buffer((1,), "int64", data=stack_shape)
T.tvm_struct_set(stack_array, 0, 2, T.address_of(stack_shape_2[0]))
T.tvm_struct_set(stack_array, 0, 3, T.reinterpret("handle", T.uint64(0)))
T.tvm_struct_set(stack_array, 0, 4, 4)
T.tvm_struct_set(stack_array, 0, 5, T.uint8(2))
T.tvm_struct_set(stack_array, 0, 6, T.uint8(32))
T.tvm_struct_set(stack_array, 0, 7, T.uint16(1))
T.tvm_struct_set(stack_array, 0, 8, T.Cast("uint64", placeholder.elem_offset * 4))
T.tvm_struct_set(stack_array, 0, 9, 0)
T.tvm_struct_set(stack_array, 0, 10, 1)
# Second shape (1, 3, 8, 8) is stored after the first one, in slots 4..7.
stack_shape_1[4] = T.int64(1)
stack_shape_1[5] = T.int64(3)
stack_shape_1[6] = T.int64(8)
stack_shape_1[7] = T.int64(8)
# Fields of array entry 1, which wraps extern.data; its shape pointer is the
# address of shape-stack slot 4.
T.tvm_struct_set(stack_array, 1, 1, extern.data)
stack_shape_3 = T.Buffer((5,), "int64", data=stack_shape)
T.tvm_struct_set(stack_array, 1, 2, T.address_of(stack_shape_3[4]))
T.tvm_struct_set(stack_array, 1, 3, T.reinterpret("handle", T.uint64(0)))
T.tvm_struct_set(stack_array, 1, 4, 4)
T.tvm_struct_set(stack_array, 1, 5, T.uint8(2))
T.tvm_struct_set(stack_array, 1, 6, T.uint8(32))
T.tvm_struct_set(stack_array, 1, 7, T.uint16(1))
T.tvm_struct_set(stack_array, 1, 8, T.Cast("uint64", extern.elem_offset * 4))
T.tvm_struct_set(stack_array, 1, 9, 0)
T.tvm_struct_set(stack_array, 1, 10, 1)
# The extern call now receives handles to the two populated array entries.
T.call_extern(
"void",
"resize2d_with_dltensor",
T.tvm_struct_get(stack_array, 0, 0, "handle"),
T.tvm_struct_get(stack_array, 1, 0, "handle"),
)


if __name__ == "__main__":
tvm.testing.main()