Skip to content

Commit 4852ca2

Browse files
committed
Improved tvm::GetType for tvm_access_ptr and address_of
These `Call` instances can return a `PointerType(PrimType(pointee_dtype))` rather than a `PrimType(DataType::Handle())`.
1 parent 4a1ab48 commit 4852ca2

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

src/tir/op/op.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) {
7070
return ptr->type_annotation;
7171
}
7272
}
73+
74+
if (auto* access = expr.as<tir::CallNode>()) {
75+
if (access->op.same_as(builtin::tvm_access_ptr())) {
76+
ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
77+
auto type_annotation = Downcast<Call>(access->args[0]);
78+
static auto builtin_op = Op::Get("tir.type_annotation");
79+
ICHECK(type_annotation->op.same_as(builtin_op))
80+
<< "Expected the first argument of builtin tvm_access_ptr() "
81+
<< "to be a type annotation, but found " << type_annotation->op;
82+
return PointerType(PrimType(type_annotation->dtype));
83+
}
84+
}
85+
86+
if (auto* address_of = expr.as<tir::CallNode>()) {
87+
if (address_of->op.same_as(builtin::address_of())) {
88+
ICHECK_EQ(address_of->args.size(), 1)
89+
<< "Builtin address_of() expects a single argument, but received arguments "
90+
<< address_of->args;
91+
auto* address = address_of->args[0].as<BufferLoadNode>();
92+
ICHECK(address)
93+
<< "Builtin address_of() expects the argument to be a BufferLoad, but received argument "
94+
<< address_of->args[0];
95+
96+
return PointerType(PrimType(address->dtype));
97+
}
98+
}
7399
// Default: return the type indicated by the dtype.
74100
runtime::DataType dtype = expr.dtype();
75101
return GetTypeFromRuntimeDataType(dtype);

0 commit comments

Comments
 (0)