Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,6 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate);

/*!
* \brief The prefetch hint for a buffer
* \param buffer The buffer to be prefetched.
* \param bounds The bounds to be prefetched.
*/
void Prefetch(Buffer buffer, Array<Range> bounds);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
164 changes: 0 additions & 164 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -335,124 +334,6 @@ class BufferRealize : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
};

/*!
* \brief Store value into mult-dimensional array that will be read by the consumer
* of the producer.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa DataProducer
*/
class ProducerStoreNode : public StmtNode {
public:
/*! \brief The producer to store the results into. */
DataProducer producer;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The index arguments of the function. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(value, other->value) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(value);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerStoreNode.
* \sa ProducerStoreNode
*/
class ProducerStore : public Stmt {
public:
TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
};

/*!
* \brief Annotate the bounds where the data produced by the producer
* need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa DataProducer
*/
class ProducerRealizeNode : public StmtNode {
public:
/*! \brief The producer that produces the data. */
DataProducer producer;
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;
/*! \brief The storage scope associated with this realization. */
String storage_scope;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("storage_scope", &storage_scope);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body) &&
equal(storage_scope, other->storage_scope);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(storage_scope);
}

static constexpr const char* _type_key = "tir.ProducerRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerRealizeNode.
* \sa ProducerRealizeNode
*/
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
Expand Down Expand Up @@ -1090,51 +971,6 @@ class While : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
};

/*!
* \brief A prefetch hint for a buffer
*/
class PrefetchNode : public StmtNode {
public:
/*! \brief The function to be prefetched. */
Buffer buffer;
/*! \brief Bounds to be prefetched. */
Array<Range> bounds;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
v->Visit("span", &span);
}

bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(bounds, other->bounds);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(bounds);
}

PrefetchNode() = default;
PrefetchNode(Buffer buffer, Array<Range> bounds, Span span = Span())
: StmtNode(span), buffer(buffer), bounds(bounds) {}

static constexpr const char* _type_key = "tir.Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};

/*!
* \brief Managed reference to PrefetchNode.
* \sa PrefetchNode
*/
class Prefetch : public Stmt {
public:
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
};

/*!
* \brief Representing the region of multi-dimensional buffer access.
*/
Expand Down
12 changes: 0 additions & 12 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -118,9 +115,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
Expand Down Expand Up @@ -164,9 +158,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerStoreNode* op) override;
void VisitStmt_(const ProducerRealizeNode* op) override;
void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const BlockNode* op) override;
Expand Down Expand Up @@ -265,9 +256,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerStoreNode* op) override;
Stmt VisitStmt_(const ProducerRealizeNode* op) override;
Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Expand Down
17 changes: 0 additions & 17 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,22 +1316,6 @@ def buffer_store(
)


def prefetch(
buffer: Buffer, # pylint: disable=redefined-outer-name
bounds: List[ir.Range],
) -> None:
"""The prefetch hint for a buffer.

Parameters
----------
buffer : Buffer
The buffer to be prefetched.
bounds : List[Range]
The range to be prefetched.
"""
return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.

Expand Down Expand Up @@ -2144,7 +2128,6 @@ def wrapped(*args, **kwargs):
"launch_thread",
"env_thread",
"buffer_store",
"prefetch",
"evaluate",
"boolean",
"handle",
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@
from .stmt import (
BufferStore,
BufferRealize,
ProducerStore,
Allocate,
AllocateConst,
AttrStmt,
DeclBuffer,
)

from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import SeqStmt
from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc, TensorIntrin, IndexMap
Expand Down
Loading
Loading