Skip to content

Commit 75f710f

Browse files
author
Siyuan Feng
authored
[TIR] Phase out ProducerStore, ProducerRealize and Prefetch (#18057)
1 parent 1cf31bc commit 75f710f

File tree

20 files changed

+4
-596
lines changed

20 files changed

+4
-596
lines changed

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,13 +417,6 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
417417
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
418418
Optional<PrimExpr> predicate);
419419

420-
/*!
421-
* \brief The prefetch hint for a buffer
422-
* \param buffer The buffer to be prefetched.
423-
* \param bounds The bounds to be prefetched.
424-
*/
425-
void Prefetch(Buffer buffer, Array<Range> bounds);
426-
427420
/*!
428421
* \brief Evaluate the input expression.
429422
* \param value The input expression to evaluate.

include/tvm/tir/stmt.h

Lines changed: 0 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include <string>
3030
#include <type_traits>
3131
#include <utility>
32-
#include <vector>
3332

3433
namespace tvm {
3534
namespace tir {
@@ -335,124 +334,6 @@ class BufferRealize : public Stmt {
335334
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
336335
};
337336

338-
/*!
339-
* \brief Store value into mult-dimensional array that will be read by the consumer
340-
* of the producer.
341-
*
342-
* \note This node only appears in high-level DSLs that are built on top of the TIR.
343-
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
344-
* this node before TIR transformations.
345-
*
346-
* \sa DataProducer
347-
*/
348-
class ProducerStoreNode : public StmtNode {
349-
public:
350-
/*! \brief The producer to store the results into. */
351-
DataProducer producer;
352-
/*! \brief The value to be stored. */
353-
PrimExpr value;
354-
/*! \brief The index arguments of the function. */
355-
Array<PrimExpr> indices;
356-
357-
void VisitAttrs(AttrVisitor* v) {
358-
v->Visit("producer", &producer);
359-
v->Visit("value", &value);
360-
v->Visit("indices", &indices);
361-
v->Visit("span", &span);
362-
}
363-
364-
bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
365-
return equal(producer, other->producer) && equal(value, other->value) &&
366-
equal(indices, other->indices);
367-
}
368-
369-
void SHashReduce(SHashReducer hash_reduce) const {
370-
hash_reduce(producer);
371-
hash_reduce(value);
372-
hash_reduce(indices);
373-
}
374-
375-
static constexpr const char* _type_key = "tir.ProducerStore";
376-
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
377-
};
378-
379-
/*!
380-
* \brief Managed reference to ProducerStoreNode.
381-
* \sa ProducerStoreNode
382-
*/
383-
class ProducerStore : public Stmt {
384-
public:
385-
TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
386-
Span span = Span());
387-
388-
TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
389-
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
390-
};
391-
392-
/*!
393-
* \brief Annotate the bounds where the data produced by the producer
394-
* need to be written and read in body.
395-
* We will need to allocate space for the corresponding regions.
396-
*
397-
* \note This node only appears in high-level DSLs that are built on top of the TIR.
398-
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
399-
* this node before TIR transformations.
400-
*
401-
* \sa DataProducer
402-
*/
403-
class ProducerRealizeNode : public StmtNode {
404-
public:
405-
/*! \brief The producer that produces the data. */
406-
DataProducer producer;
407-
/*! \brief Bounds to be realized. */
408-
Region bounds;
409-
/*! \brief Only realize if condition holds. */
410-
PrimExpr condition;
411-
/*! \brief The body of realization. */
412-
Stmt body;
413-
/*! \brief The storage scope associated with this realization. */
414-
String storage_scope;
415-
416-
void VisitAttrs(AttrVisitor* v) {
417-
v->Visit("producer", &producer);
418-
v->Visit("bounds", &bounds);
419-
v->Visit("condition", &condition);
420-
v->Visit("body", &body);
421-
v->Visit("storage_scope", &storage_scope);
422-
v->Visit("span", &span);
423-
}
424-
425-
bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
426-
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
427-
equal(condition, other->condition) && equal(body, other->body) &&
428-
equal(storage_scope, other->storage_scope);
429-
}
430-
431-
void SHashReduce(SHashReducer hash_reduce) const {
432-
hash_reduce(producer);
433-
hash_reduce(bounds);
434-
hash_reduce(condition);
435-
hash_reduce(body);
436-
hash_reduce(storage_scope);
437-
}
438-
439-
static constexpr const char* _type_key = "tir.ProducerRealize";
440-
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
441-
};
442-
443-
/*!
444-
* \brief Managed reference to ProducerRealizeNode.
445-
* \sa ProducerRealizeNode
446-
*/
447-
class ProducerRealize : public Stmt {
448-
public:
449-
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
450-
String storage_scope = "", Span span = Span());
451-
452-
TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
453-
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
454-
};
455-
456337
/*!
457338
* \brief Allocate a buffer that can be used in body.
458339
*/
@@ -1090,51 +971,6 @@ class While : public Stmt {
1090971
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
1091972
};
1092973

1093-
/*!
1094-
* \brief A prefetch hint for a buffer
1095-
*/
1096-
class PrefetchNode : public StmtNode {
1097-
public:
1098-
/*! \brief The function to be prefetched. */
1099-
Buffer buffer;
1100-
/*! \brief Bounds to be prefetched. */
1101-
Array<Range> bounds;
1102-
1103-
void VisitAttrs(AttrVisitor* v) {
1104-
v->Visit("buffer", &buffer);
1105-
v->Visit("bounds", &bounds);
1106-
v->Visit("span", &span);
1107-
}
1108-
1109-
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
1110-
return equal(buffer, other->buffer) && equal(bounds, other->bounds);
1111-
}
1112-
1113-
void SHashReduce(SHashReducer hash_reduce) const {
1114-
hash_reduce(buffer);
1115-
hash_reduce(bounds);
1116-
}
1117-
1118-
PrefetchNode() = default;
1119-
PrefetchNode(Buffer buffer, Array<Range> bounds, Span span = Span())
1120-
: StmtNode(span), buffer(buffer), bounds(bounds) {}
1121-
1122-
static constexpr const char* _type_key = "tir.Prefetch";
1123-
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
1124-
};
1125-
1126-
/*!
1127-
* \brief Managed reference to PrefetchNode.
1128-
* \sa PrefetchNode
1129-
*/
1130-
class Prefetch : public Stmt {
1131-
public:
1132-
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
1133-
1134-
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
1135-
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
1136-
};
1137-
1138974
/*!
1139975
* \brief Representing the region of multi-dimensional buffer access.
1140976
*/

include/tvm/tir/stmt_functor.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
9393
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9494
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9595
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
96-
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
97-
virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
98-
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9996
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
10097
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
10198
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -118,9 +115,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
118115
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
119116
IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
120117
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
121-
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
122-
IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
123-
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
124118
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
125119
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
126120
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
@@ -164,9 +158,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
164158
void VisitStmt_(const BufferStoreNode* op) override;
165159
void VisitStmt_(const BufferRealizeNode* op) override;
166160
void VisitStmt_(const AssertStmtNode* op) override;
167-
void VisitStmt_(const ProducerStoreNode* op) override;
168-
void VisitStmt_(const ProducerRealizeNode* op) override;
169-
void VisitStmt_(const PrefetchNode* op) override;
170161
void VisitStmt_(const SeqStmtNode* op) override;
171162
void VisitStmt_(const EvaluateNode* op) override;
172163
void VisitStmt_(const BlockNode* op) override;
@@ -265,9 +256,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
265256
Stmt VisitStmt_(const BufferStoreNode* op) override;
266257
Stmt VisitStmt_(const BufferRealizeNode* op) override;
267258
Stmt VisitStmt_(const AssertStmtNode* op) override;
268-
Stmt VisitStmt_(const ProducerStoreNode* op) override;
269-
Stmt VisitStmt_(const ProducerRealizeNode* op) override;
270-
Stmt VisitStmt_(const PrefetchNode* op) override;
271259
Stmt VisitStmt_(const SeqStmtNode* op) override;
272260
Stmt VisitStmt_(const EvaluateNode* op) override;
273261
Stmt VisitStmt_(const BlockNode* op) override;

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,22 +1316,6 @@ def buffer_store(
13161316
)
13171317

13181318

1319-
def prefetch(
1320-
buffer: Buffer, # pylint: disable=redefined-outer-name
1321-
bounds: List[ir.Range],
1322-
) -> None:
1323-
"""The prefetch hint for a buffer.
1324-
1325-
Parameters
1326-
----------
1327-
buffer : Buffer
1328-
The buffer to be prefetched.
1329-
bounds : List[Range]
1330-
The range to be prefetched.
1331-
"""
1332-
return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member
1333-
1334-
13351319
def evaluate(value: PrimExpr) -> None:
13361320
"""Evaluate the input expression.
13371321
@@ -2144,7 +2128,6 @@ def wrapped(*args, **kwargs):
21442128
"launch_thread",
21452129
"env_thread",
21462130
"buffer_store",
2147-
"prefetch",
21482131
"evaluate",
21492132
"boolean",
21502133
"handle",

python/tvm/tir/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@
3232
from .stmt import (
3333
BufferStore,
3434
BufferRealize,
35-
ProducerStore,
3635
Allocate,
3736
AllocateConst,
3837
AttrStmt,
3938
DeclBuffer,
4039
)
4140

42-
from .stmt import ProducerRealize, SeqStmt
43-
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
41+
from .stmt import SeqStmt
42+
from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list
4443
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
4544

4645
from .function import PrimFunc, TensorIntrin, IndexMap

0 commit comments

Comments
 (0)