Skip to content

Commit

Permalink
adding storage_scope to ProduerRealize
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 29, 2021
1 parent c586834 commit 496a215
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 7 deletions.
4 changes: 3 additions & 1 deletion include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ class Buffer : public ObjectRef {
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", Span span = Span());
String name = "buffer", String storage_scope = "", Span span = Span());

TVM_DLL String GetStorageScope(Var buffer_var);

/*!
* \brief Base node for data producers.
Expand Down
9 changes: 7 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,24 +465,29 @@ class ProducerRealizeNode : public StmtNode {
/*! \brief The body of realization. */
Stmt body;

String storage_scope;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("storage_scope", &storage_scope);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
equal(condition, other->condition) && equal(body, other->body) &&
equal(storage_scope, other->storage_scope);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(storage_scope);
}

static constexpr const char* _type_key = "tir.ProducerRealize";
Expand All @@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
Span span = Span());
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
};
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
Expand Down
12 changes: 10 additions & 2 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,20 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
return array;
}

Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, Span span) {
Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
Span span) {
DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape,
return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
}

String GetStorageScope(Var buffer_var) {
auto type = buffer_var->type_annotation;
const auto* ptr_type = type.as<PointerTypeNode>();
ICHECK(ptr_type);
return ptr_type->storage_scope;
}

// Split the given expression w.r.t the add operator
inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
using namespace tir;
Expand Down
3 changes: 2 additions & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// ProducerRealize
ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
Stmt body, Span span) {
Stmt body, String storage_scope, Span span) {
for (size_t i = 0; i < bounds.size(); ++i) {
ICHECK(bounds[i]->min.defined());
ICHECK(bounds[i]->extent.defined());
Expand All @@ -394,6 +394,7 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr
node->condition = std::move(condition);
node->body = std::move(body);
node->span = std::move(span);
node->storage_scope = std::move(storage_scope);
data_ = std::move(node);
}

Expand Down

0 comments on commit 496a215

Please sign in to comment.