Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,16 @@ class PrimFuncNode : public BaseFuncNode {
* will make program analysis much easier.
*/
Map<tir::Var, Buffer> buffer_map;
/*! \brief The resource handle to be used by the function when accessing platform resources */
tir::Var resource_handle;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like this should be a property of the call site rather than a function. how come you want to add it here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is referring to the variable used in the function to pass into further calls, similar to the params above except we don't treat it as a parameter which would get packed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, but is there just one resource_handle associated with each function? suppose you had two instances of the same accelerator and wanted to launch the same compute function twice concurrently? also, what about the AOT top-level function, which should have a struct like TVMDevices_my_model?


void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("attrs", &attrs);
v->Visit("resource_handle", &resource_handle);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand All @@ -105,7 +108,7 @@ class PrimFuncNode : public BaseFuncNode {
// visit params and buffer_map first as they contains defs.
return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
equal(ret_type, other->ret_type) && equal(body, other->body) &&
equal(attrs, other->attrs);
equal(attrs, other->attrs) && equal.DefEqual(resource_handle, other->resource_handle);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -114,6 +117,7 @@ class PrimFuncNode : public BaseFuncNode {
hash_reduce(ret_type);
hash_reduce(body);
hash_reduce(attrs);
hash_reduce.DefHash(resource_handle);
}
/*!
* \brief Return the derived function annotation of this function.
Expand Down Expand Up @@ -141,11 +145,14 @@ class PrimFunc : public BaseFunc {
* \param ret_type The return type of the function.
* \param buffer_map The buffer map for parameter buffer unpacking.
* \param attrs Additional function attributes.
* \param resource_handle Handle for passing resources to the function
* \param span The location of this object in the source code.
*/
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
DictAttrs attrs = NullValue<DictAttrs>(),
tir::Var resource_handle = tir::Var("resource_handle", DataType::Handle()),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
Expand Down
29 changes: 27 additions & 2 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,31 @@ class PrimFunc(BaseFunc):
attrs: Optional[tvm.Attrs]
Attributes of the function, can be None

resource_handle: Optional[tvm.tir.Var]
The resource handle to be used by the function when accessing platform resources,
if not passed a Var will be created for it

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None):
def __init__(
self,
params,
body,
ret_type=None,
buffer_map=None,
attrs=None,
resource_handle=None,
span=None,
):
param_list = []
buffer_map = {} if buffer_map is None else buffer_map

# This is bound later as it relies on the FFI API having defined "Var"
if resource_handle is None:
resource_handle = Var("resource_handle", dtype="handle")

for x in params:
x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
if isinstance(x, Buffer):
Expand All @@ -67,7 +85,14 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa
raise TypeError("params can only contain Var or Buffer")

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore
_ffi_api.PrimFunc, # type: ignore
param_list,
body,
ret_type,
buffer_map,
attrs,
resource_handle,
span,
)

def with_body(self, new_body, span=None):
Expand Down
6 changes: 5 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ class AOTExecutorCodegen : public ExprVisitor {
auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked();
if (use_unpacked_api_) {
calling_pattern = tvm::tir::builtin::call_extern();
args.push_back(resource_handle_);
}

create_func_call_stmts.push_back(
Expand Down Expand Up @@ -643,14 +644,16 @@ class AOTExecutorCodegen : public ExprVisitor {

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
DictAttrs(dict_attrs), resource_handle_);
}

protected:
/*! \brief mod */
runtime::Module* mod_;
/*! \brief list of input expressions (i.e., variable passed by the user) */
std::vector<Var> input_vars_;
/*! \brief resource handle to be passed into operator functions */
tir::Var resource_handle_;
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief target device */
Expand Down Expand Up @@ -699,6 +702,7 @@ class AOTExecutorCodegen : public ExprVisitor {
public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
: mod_(mod),
resource_handle_("resource_handle", DataType::Handle()),
targets_(targets),
target_host_(target_host),
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
Expand Down
22 changes: 6 additions & 16 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,29 +195,19 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name,
const std::string& run_func) {
code_ << "TVM_DLL int32_t " << run_func << "(";
unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs);
for (unsigned int i = 0; i < total_args; ++i) {
code_ << "void* arg" << i;
if (i + 1 != total_args) {
code_ << ",";
}
int total_args = (metadata_->inputs.size() + metadata_->num_outputs);
for (int i = 0; i < total_args; ++i) {
code_ << "void* arg" << i << ",";
}
code_ << ");\n";
code_ << "void* resource_handle);\n";
code_ << "int32_t " << entrypoint_name;
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
"out_type_code, void* resource_handle) {\n";
code_ << "return " << run_func << "(";
for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
for (int i = 0; i < total_args; ++i) {
code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
}
for (int i = 0; i < metadata_->num_outputs; ++i) {
int j = metadata_->inputs.size() + i;
code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data";
if (i + 1 != metadata_->num_outputs) {
code_ << ",";
}
}
code_ << ");\n";
code_ << "resource_handle);\n";
code_ << "}\n";
}

Expand Down
9 changes: 6 additions & 3 deletions src/tir/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) {

// Get the function type of a PrimFunc
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, tir::Var resource_handle,
Span span) {
// Assume void-return type for now
// TODO(tvm-team) consider type deduction from body.
if (!ret_type.defined()) {
Expand All @@ -50,6 +51,7 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
n->buffer_map = std::move(buffer_map);
n->attrs = std::move(attrs);
n->checked_type_ = n->func_type_annotation();
n->resource_handle = std::move(resource_handle);
n->span = std::move(span);
data_ = std::move(n);
}
Expand Down Expand Up @@ -81,8 +83,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_GLOBAL("tir.PrimFunc")
.set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, tir::Var resource_handle,
Span span) {
return PrimFunc(params, body, ret_type, buffer_map, attrs, resource_handle, span);
});

} // namespace tir
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
Expand Down
3 changes: 3 additions & 0 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
args.push_back(v_arg);
}

// Add resource handle to function parameters
args.push_back(func_ptr->resource_handle);

// Bind variables then bind buffers to them to ensure correct ordering
for (const auto& kv : var_def) {
binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
Expand Down
19 changes: 13 additions & 6 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def test_fails_if_no_target(mod_without_attrs):
def test_device_setup(mod, target, dev):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert len(f.params) == 2
assert f.params[0].name == "arg0"
assert f.params[1].name == "resource_handle"
assert f.body.node == "default"
assert f.body.attr_key == "device_id"
assert f.body.value == 0
Expand All @@ -76,15 +77,18 @@ def test_no_buffers_no_device_setup():
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)

f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert len(f.params) == 2
assert f.params[0].name == "arg0"
assert f.params[1].name == "resource_handle"
assert f.body.var.name == "A"
assert f.body.value.name == "arg0"


def test_argument_mapping(mod):
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert len(f.params) == 2
assert f.params[0].name == "arg0"
assert f.params[1].name == "resource_handle"
assert f.body.body.body.var.name == "A"
assert f.body.body.body.value.name == "arg0"

Expand All @@ -100,9 +104,10 @@ def test_argument_mapping_multiple():
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)

f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 2
assert len(f.params) == 3
assert f.params[0].name == "arg0"
assert f.params[1].name == "arg1"
assert f.params[2].name == "resource_handle"
assert f.body.body.body.var.name == "A"
assert f.body.body.body.value.name == "arg0"
assert f.body.body.body.body.var.name == "B"
Expand All @@ -119,9 +124,10 @@ def test_argument_mapping_multiple_matching():
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)

f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 2
assert len(f.params) == 3
assert f.params[0].name == "arg0"
assert f.params[1].name == "arg1"
assert f.params[2].name == "resource_handle"
assert f.body.body.body.var.name == "A"
assert f.body.body.body.value.name == "arg0"
assert f.body.body.body.body.condition.a.name == "A"
Expand All @@ -139,10 +145,11 @@ def test_body():
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 3
assert len(f.params) == 4
assert f.params[0].name == "arg0"
assert f.params[1].name == "arg1"
assert f.params[2].name == "arg2"
assert f.params[3].name == "resource_handle"
assert f.body.body.body.var.name == "A"
assert f.body.body.body.value.name == "arg2"
assert f.body.body.body.body.var.name == "B"
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/task_mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ mypy --check-untyped-defs python/tvm/tir/schedule
echo "Checking MyPy Type defs in the analysis package."
mypy --check-untyped-defs python/tvm/tir/analysis/

echo "Checking MyPy Type defs in the transofrm package."
echo "Checking MyPy Type defs in the transform package."
mypy --check-untyped-defs python/tvm/tir/transform/