Skip to content

Commit c56a646

Browse files
tqchen authored and trevor-m committed
[TIR] Enforce buffer pointer var type to be consistent with dtype. (apache#6317)
Now that we have type_annotation in tir::Var, we should make sure that the type annotation is consistent with the dtype in Buffer declaration and Allocation. This change allows future passes to directly use the content type information via type_annotation. This PR turns on the enforcement on Buffer and also fixes a few cases for Allocate. A follow-up PR needs to fix a few more cases in the hybrid script parsing before everything can be made consistent.
1 parent 3894ab5 commit c56a646

File tree

8 files changed

+115
-83
lines changed

8 files changed

+115
-83
lines changed

include/tvm/tir/op.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot);
617617
TVM_DECLARE_INTRIN_BINARY(ldexp);
618618

619619
namespace tir {
620+
621+
/*!
622+
* \brief Check if type is a pointer to a runtime element type.
623+
* \param type The type to be checked.
624+
* \param element_type The corresponding element type.
625+
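* \return True if type is a pointer type whose element type is a primitive type with dtype equal to element_type, false otherwise.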
626+
*/
627+
inline bool IsPointerType(const Type& type, const DataType& element_type) {
628+
if (!type.defined()) return false;
629+
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
630+
if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
631+
return prim_type->dtype == element_type;
632+
}
633+
}
634+
return false;
635+
}
636+
620637
/*!
621638
* \brief Make a const value with certain data type.
622639
* \param t The target type.

python/tvm/tir/buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from tvm._ffi.base import string_types
2222
from tvm.runtime import Object, convert
23-
from tvm.ir import PrimExpr
23+
from tvm.ir import PrimExpr, PointerType, PrimType
2424
from . import _ffi_api
2525

2626

@@ -241,7 +241,7 @@ def decl_buffer(shape,
241241
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
242242
elem_offset = Var('%s_elem_offset' % name, shape_dtype)
243243
if data is None:
244-
data = Var(name, "handle")
244+
data = Var(name, PointerType(PrimType(dtype)))
245245
return _ffi_api.Buffer(
246246
data, dtype, shape, strides, elem_offset, name, scope,
247247
data_alignment, offset_factor, buffer_type)

python/tvm/tir/ir_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Developer API of IR node builder make function."""
1818
from tvm._ffi.base import string_types
1919
from tvm.runtime import ObjectGeneric, DataType, convert, const
20-
from tvm.ir import container as _container
20+
from tvm.ir import container as _container, PointerType, PrimType
2121

2222
from . import stmt as _stmt
2323
from . import expr as _expr
@@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
325325
buffer : BufferVar
326326
The buffer var representing the buffer.
327327
"""
328-
buffer_var = _expr.Var(name, dtype="handle")
328+
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
329329
if not isinstance(shape, (list, tuple, _container.Array)):
330330
shape = [shape]
331331
if scope:

src/driver/driver_api.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) {
6969

7070
tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
7171
int data_alignment, int offset_factor, bool compact) {
72-
auto data = tir::Var(name, DataType::Handle());
72+
auto data = tir::Var(name, PointerType(PrimType(dtype)));
7373
bool has_any = false;
7474
if (!compact) {
7575
for (const auto& it : shape) {

src/tir/ir/buffer.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
383383
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
384384
PrimExpr elem_offset, String name, String scope, int data_alignment,
385385
int offset_factor, BufferType buffer_type) {
386+
CHECK(IsPointerType(data->type_annotation, dtype))
387+
<< "Buffer data field expect to have the right pointer type annotation"
388+
<< " annotation=" << data->type_annotation << ", dtype=" << dtype;
389+
386390
auto n = make_object<BufferNode>();
387391
n->data = std::move(data);
388392
n->dtype = dtype;
393+
389394
n->shape = std::move(shape);
390395
n->strides = std::move(strides);
391396
n->name = std::move(name);

src/tir/ir/stmt.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
263263
// Allocate
264264
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
265265
Stmt body) {
266+
// TODO(tvm-team): Add invariant check to make sure
267+
// IsPointerType(buffer_var->type_annotation, dtype)
268+
// once we fix the allocate hybrid script printing.
266269
for (size_t i = 0; i < extents.size(); ++i) {
267270
CHECK(extents[i].defined());
268271
CHECK(extents[i].dtype().is_scalar());

src/tir/transforms/bf16_legalize.cc

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) {
172172
* Lower cast between bf16 and fp32
173173
* Lower bf16 FloatImm to int16
174174
*/
175-
class BF16LowerRewriter : StmtExprMutator {
175+
class BF16LowerRewriter : public StmtExprMutator {
176176
public:
177177
BF16LowerRewriter() {}
178178

179-
std::unordered_map<const BufferNode*, Buffer> buffer_remap;
180-
std::unordered_map<const VarNode*, Var> var_remap;
181-
182-
Stmt operator()(Stmt s) { return VisitStmt(s); }
179+
using StmtExprMutator::operator();
183180

184181
PrimExpr VisitExpr_(const CastNode* op) final {
185182
auto op_val = StmtExprMutator::VisitExpr(op->value);
@@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator {
190187
auto uint32_v = Cast(uint32_dtype, op_val);
191188
// to be endian invariant.
192189
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
193-
194190
} else if (op->dtype.is_bfloat16()) {
195191
// if is cast_to_bf16, check if op->value is fp32
196192
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
@@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator {
209205
}
210206

211207
PrimExpr VisitExpr_(const VarNode* op) final {
212-
auto itr = var_remap.find(op);
213-
if (itr != var_remap.end()) {
208+
Var var = GetRef<Var>(op);
209+
210+
auto itr = var_remap_.find(var);
211+
if (itr != var_remap_.end()) {
214212
return itr->second;
213+
} else {
214+
return std::move(var);
215215
}
216-
if (op->dtype.is_bfloat16()) {
217-
CHECK(!op->type_annotation.defined());
218-
auto ret = Var(op->name_hint, op->dtype);
219-
var_remap[op] = ret;
220-
return std::move(ret);
221-
}
222-
return StmtExprMutator::VisitExpr_(op);
223216
}
224217

225218
Stmt VisitStmt_(const AllocateNode* op) final {
226-
Stmt node_holder;
227-
const AllocateNode* newop;
228219
if (op->dtype.is_bfloat16()) {
229-
auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
230-
op->condition, op->body);
231-
node_holder = v;
232-
newop = static_cast<const AllocateNode*>(v.operator->());
220+
DataType dtype = DataType::UInt(16, op->dtype.lanes());
221+
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
222+
var_remap_[op->buffer_var] = buffer_var;
223+
return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
233224
} else {
234-
newop = op;
225+
return StmtExprMutator::VisitStmt_(op);
235226
}
236-
return StmtExprMutator::VisitStmt_(newop);
237227
}
238228

239229
Stmt VisitStmt_(const BufferStoreNode* op) final {
240-
auto itr = buffer_remap.find(op->buffer.operator->());
241-
const BufferStoreNode* newop;
242-
BufferStore newop_holder;
243-
if (itr != buffer_remap.end()) {
244-
newop_holder = BufferStore(itr->second, op->value, op->indices);
245-
newop = newop_holder.operator->();
230+
Stmt ret = StmtExprMutator::VisitStmt_(op);
231+
op = ret.as<BufferStoreNode>();
232+
233+
auto it = buffer_remap_.find(op->buffer);
234+
if (it != buffer_remap_.end()) {
235+
return BufferStore(it->second, op->value, op->indices);
246236
} else {
247-
newop = op;
237+
return ret;
248238
}
249-
return StmtExprMutator::VisitStmt_(newop);
250239
}
251240

252241
Stmt VisitStmt_(const AttrStmtNode* op) final {
253-
const AttrStmtNode* newop = op;
254-
Stmt newop_holder;
255-
if (auto buffer = op->node.as<BufferNode>()) {
256-
auto itr = buffer_remap.find(buffer);
257-
if (itr != buffer_remap.end()) {
258-
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
259-
newop = newop_holder.as<AttrStmtNode>();
242+
Stmt ret = StmtExprMutator::VisitStmt_(op);
243+
op = ret.as<AttrStmtNode>();
244+
245+
if (auto* buffer = op->node.as<BufferNode>()) {
246+
auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
247+
if (it != buffer_remap_.end()) {
248+
return AttrStmt(it->second, op->attr_key, op->value, op->body);
260249
}
261-
} else if (auto buffer = op->node.as<VarNode>()) {
262-
auto itr = var_remap.find(buffer);
263-
if (itr != var_remap.end()) {
264-
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
265-
newop = newop_holder.as<AttrStmtNode>();
250+
} else if (auto* var = op->node.as<VarNode>()) {
251+
auto it = var_remap_.find(GetRef<Var>(var));
252+
if (it != var_remap_.end()) {
253+
return AttrStmt(it->second, op->attr_key, op->value, op->body);
266254
}
267255
}
268-
return StmtExprMutator::VisitStmt_(newop);
256+
return ret;
269257
}
270258

271259
Stmt VisitStmt_(const BufferRealizeNode* op) final {
272-
auto itr = buffer_remap.find(op->buffer.operator->());
273-
const BufferRealizeNode* newop;
274-
Stmt newop_holder;
275-
if (itr != buffer_remap.end()) {
276-
auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
277-
newop_holder = v;
278-
newop = v.operator->();
260+
Stmt ret = StmtExprMutator::VisitStmt_(op);
261+
op = ret.as<BufferRealizeNode>();
262+
263+
auto it = buffer_remap_.find(op->buffer);
264+
if (it != buffer_remap_.end()) {
265+
return BufferRealize(it->second, op->bounds, op->condition, op->body);
279266
} else {
280-
newop = op;
267+
return ret;
268+
}
269+
}
270+
271+
Stmt VisitStmt_(const StoreNode* op) final {
272+
// NOTE: we do not explicit recursivly mutate op->buffer_var
273+
Stmt ret = StmtExprMutator::VisitStmt_(op);
274+
op = ret.as<StoreNode>();
275+
276+
auto it = var_remap_.find(op->buffer_var);
277+
if (it != var_remap_.end()) {
278+
return Store(it->second, op->value, op->index, op->predicate);
279+
} else {
280+
return ret;
281281
}
282-
return StmtExprMutator::VisitStmt_(newop);
283282
}
284283

285284
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
286-
auto itr = buffer_remap.find(op->buffer.operator->());
287-
const BufferLoadNode* newop;
288-
BufferLoad newop_holder;
289-
if (itr != buffer_remap.end()) {
290-
newop_holder = BufferLoad(itr->second, op->indices);
291-
newop = newop_holder.operator->();
285+
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
286+
op = ret.as<BufferLoadNode>();
287+
288+
auto it = buffer_remap_.find(op->buffer);
289+
if (it != buffer_remap_.end()) {
290+
return BufferLoad(it->second, op->indices);
292291
} else {
293-
newop = op;
292+
return ret;
294293
}
295-
return StmtExprMutator::VisitExpr_(newop);
296294
}
297295

298296
PrimExpr VisitExpr_(const LoadNode* op) final {
299-
bool is_bf16 = false;
297+
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
298+
op = ret.as<LoadNode>();
299+
300300
if (op->dtype.is_bfloat16()) {
301-
is_bf16 = true;
302-
}
303-
PrimExpr index = this->VisitExpr(op->index);
304-
PrimExpr predicate = this->VisitExpr(op->predicate);
305-
if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
306-
return GetRef<PrimExpr>(op);
301+
auto it = var_remap_.find(op->buffer_var);
302+
CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";
303+
return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);
307304
} else {
308-
return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
309-
index, predicate);
305+
return ret;
310306
}
311307
}
312308

@@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator {
320316

321317
void AlterBuffers(PrimFuncNode* op) {
322318
std::vector<std::pair<Var, Buffer>> changes;
319+
323320
for (auto& itr : op->buffer_map) {
324321
auto oldbuf = itr.second;
325322
if (oldbuf->dtype.is_bfloat16()) {
326-
auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
327-
oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
328-
oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
329-
buffer_remap[oldbuf.operator->()] = newbuf;
323+
DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
324+
Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
325+
auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
326+
oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
327+
oldbuf->offset_factor, oldbuf->buffer_type);
328+
buffer_remap_[oldbuf] = newbuf;
329+
var_remap_[oldbuf->data] = buffer_var;
330330
changes.emplace_back(itr.first, newbuf);
331+
} else {
332+
changes.emplace_back(itr);
331333
}
332334
}
333-
if (buffer_remap.size() != 0) {
335+
336+
if (buffer_remap_.size() != 0) {
334337
op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
335338
}
336339
}
340+
341+
private:
342+
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
343+
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
337344
};
338345

339346
namespace transform {

src/tir/transforms/storage_flatten.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator {
200200
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
201201
}
202202

203-
e.buffer =
204-
Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape,
205-
strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault);
203+
e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
204+
op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
205+
skey.to_string(), align, 0, kDefault);
206206

207207
buf_map_[key] = e;
208208
Stmt body = this->VisitStmt(op->body);

0 commit comments

Comments
 (0)