diff --git a/include/tvm/ir/memory_pools.h b/include/tvm/ir/memory_pools.h index d4eefdc910b0..3422c1fe719b 100644 --- a/include/tvm/ir/memory_pools.h +++ b/include/tvm/ir/memory_pools.h @@ -27,12 +27,14 @@ #include #include +struct TVMConstantInfo; namespace tvm { /*! * \brief Describes a pool of memory accessible by one or more targets. */ struct PoolInfoNode : public Object { + public: /*! \brief The name of the memory pool */ String pool_name; /*! \brief The expected size hint to be used by the allocator. @@ -40,8 +42,6 @@ struct PoolInfoNode : public Object { * to indicate the pool is not size restricted. */ Integer size_hint_bytes; - /*! \brief The accessibility from each Target */ - Map target_access; // 'rw' or 'ro' /*! \brief The clock frequency of the memory in Hz */ Integer clock_frequency_hz; /*! \brief The read bandwidth in bytes/cycle */ @@ -60,10 +60,12 @@ struct PoolInfoNode : public Object { */ bool is_internal = false; + /*! \brief The targets linked to the pool */ + Array targets; + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pool_name", &pool_name); v->Visit("size_hint_bytes", &size_hint_bytes); - v->Visit("target_access", &target_access); v->Visit("clock_frequency_hz", &clock_frequency_hz); v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle); v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle); @@ -75,8 +77,6 @@ struct PoolInfoNode : public Object { bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const { return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) && - equal(target_access, other->target_access) && - equal(target_access, other->target_access) && equal(clock_frequency_hz, other->clock_frequency_hz) && equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) && equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) && @@ -89,7 +89,6 @@ struct PoolInfoNode : public Object { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pool_name); hash_reduce(size_hint_bytes); - hash_reduce(target_access); hash_reduce(clock_frequency_hz); hash_reduce(read_bandwidth_bytes_per_cycle); hash_reduce(write_bandwidth_bytes_per_cycle); @@ -100,7 +99,7 @@ struct PoolInfoNode : public Object { } static constexpr const char* _type_key = "ir.PoolInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(PoolInfoNode, Object); }; /*! @@ -129,18 +128,166 @@ static const int kUnknownReadBandwidth = -1; /*! \brief The write bandwidth is not known */ static const int kUnknownWriteBandwidth = -1; +/*! \brief Base class for WorkspacePoolInfo and ConstantPoolInfo */ class PoolInfo : public ObjectRef { - public: - TVM_DLL PoolInfo(String pool_name, Map target_access, - Integer size_hint_bytes = kUnrestrictedPoolSizeHint, + protected: + TVM_DLL PoolInfo(String pool_name, Integer size_hint_bytes = kUnrestrictedPoolSizeHint, Integer clock_frequency_hz = kUnknownClockFrequency, Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth, Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth, Integer read_latency_cycles = 0, Integer write_latency_cycles = 0, Map target_burst_bytes = {}, Bool is_internal = Bool(false)); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); + + public: + TVM_DEFINE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); +}; + +/*! + * \brief Describes a pool of memory properties + */ +struct PoolInfoPropertiesNode : public Object { + /*! 
\brief The expected size hint to be used by the allocator. + * The size_hint_bytes is set to kUnrestrictedPoolSizeHint + * to indicate the pool is not size restricted. + */ + Integer size_hint_bytes = kUnrestrictedPoolSizeHint; + /*! \brief The clock frequency of the memory in Hz */ + Integer clock_frequency_hz = kUnknownClockFrequency; + /*! \brief The read bandwidth in bytes/cycle */ + Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth; + /*! \brief The write bandwidth in bytes/cycle */ + Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth; + /*! \brief The read latency in cycles */ + Integer read_latency_cycles = 0; + /*! \brief The write latency in cycles */ + Integer write_latency_cycles = 0; + /*! \brief The burst length in bytes for each Target */ + Map target_burst_bytes{}; + /*! \brief Whether pool is internally generated. + * The internal pools will be generated as part of + * the entry point code generation of the executor + */ + bool is_internal = false; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("size_hint_bytes", &size_hint_bytes); + v->Visit("clock_frequency_hz", &clock_frequency_hz); + v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle); + v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle); + v->Visit("read_latency_cycles", &read_latency_cycles); + v->Visit("write_latency_cycles", &write_latency_cycles); + v->Visit("target_burst_bytes", &target_burst_bytes); + v->Visit("is_internal", &is_internal); + } + + bool SEqualReduce(const PoolInfoPropertiesNode* other, SEqualReducer equal) const { + return equal(size_hint_bytes, other->size_hint_bytes) && + equal(clock_frequency_hz, other->clock_frequency_hz) && + equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) && + equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) && + equal(read_latency_cycles, other->read_latency_cycles) && + equal(write_latency_cycles, other->write_latency_cycles) && + equal(target_burst_bytes, other->target_burst_bytes) && + equal(is_internal, other->is_internal); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(size_hint_bytes); + hash_reduce(clock_frequency_hz); + hash_reduce(read_bandwidth_bytes_per_cycle); + hash_reduce(write_bandwidth_bytes_per_cycle); + hash_reduce(read_latency_cycles); + hash_reduce(write_latency_cycles); + hash_reduce(target_burst_bytes); + hash_reduce(is_internal); + } + + static constexpr const char* _type_key = "ir.PoolInfoProperties"; + TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoPropertiesNode, Object); }; +class PoolInfoProperties : public ObjectRef { + public: + TVM_DLL PoolInfoProperties(Integer size_hint_bytes, + Integer clock_frequency_hz = kUnknownClockFrequency, + Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth, + Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth, + Integer read_latency_cycles = 0, Integer write_latency_cycles = 0, + Map target_burst_bytes = {}, + Bool is_internal = Bool(false)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfoProperties, ObjectRef, PoolInfoPropertiesNode); +}; + +/* \brief Represents RW memory area */ +struct WorkspacePoolInfoNode : public PoolInfoNode { + static constexpr const char* _type_key = "ir.WorkspacePoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(WorkspacePoolInfoNode, PoolInfoNode); +}; + +class WorkspacePoolInfo : public PoolInfo { + public: + TVM_DLL WorkspacePoolInfo( + String pool_name, Array targets, + PoolInfoProperties properties = 
PoolInfoProperties(kUnrestrictedPoolSizeHint)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspacePoolInfo, PoolInfo, WorkspacePoolInfoNode); +}; + +/* + * \brief The ConstantInfoNode contains numeric literal in RO pool + * Used to initialise RO memory in ConstantPoolInfo + */ +struct ConstantInfoNode : public Object { + String name_hint; + Integer byte_offset; + runtime::NDArray data; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("byte_offset", &byte_offset); + v->Visit("data", &data); + } + + bool SEqualReduce(const ConstantInfoNode* other, SEqualReducer equal) const { + return equal(name_hint, other->name_hint) && equal(byte_offset, other->byte_offset) && + equal(data, other->data); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name_hint); + hash_reduce(byte_offset); + hash_reduce(data); + } + + static constexpr const char* _type_key = "ir.ConstantInfo"; + static constexpr bool _type_has_method_sequal_reduce = true; + static constexpr bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoNode, Object); +}; + +class ConstantInfo : public ObjectRef { + public: + TVM_DLL ConstantInfo(const struct ::TVMConstantInfo* data); + ConstantInfo(String name, Integer byte_offset, runtime::NDArray data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantInfo, ObjectRef, ConstantInfoNode); +}; + +/* \brief ConstantPoolInfoNode represents an RO memory area initialized with + * data from constant_info_array */ +struct ConstantPoolInfoNode : public PoolInfoNode { + Array constant_info_array; + static constexpr const char* _type_key = "ir.ConstantPoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPoolInfoNode, PoolInfoNode); +}; + +class ConstantPoolInfo : public PoolInfo { + public: + TVM_DLL ConstantPoolInfo( + String pool_name, Array targets, Array constant_info_array, + PoolInfoProperties properties = PoolInfoProperties(kUnrestrictedPoolSizeHint)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantPoolInfo, PoolInfo, ConstantPoolInfoNode); +}; + +/* \brief A container for WorkspacePoolInfo objects */ struct WorkspaceMemoryPoolsNode : public Object { Array pools; @@ -162,6 +309,28 @@ class WorkspaceMemoryPools : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspaceMemoryPools, ObjectRef, WorkspaceMemoryPoolsNode); }; +/* \brief A container for ConstantPoolInfo objects */ +struct ConstantMemoryPoolsNode : public Object { + Array pools; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pools", &pools); } + + bool SEqualReduce(const ConstantMemoryPoolsNode* other, SEqualReducer equal) const { + return equal(pools, other->pools); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pools); } + + static constexpr const char* _type_key = "ir.ConstantMemoryPools"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantMemoryPoolsNode, Object); +}; + +class ConstantMemoryPools : public ObjectRef { + public: + TVM_DLL ConstantMemoryPools(Array pools); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantMemoryPools, ObjectRef, ConstantMemoryPoolsNode); +}; + } // namespace tvm #endif // TVM_IR_MEMORY_POOLS_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index e32ddb716bd5..b78f16a84f02 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -506,6 +506,15 @@ constexpr const char* kRuntime = "runtime"; */ constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools"; +/*! 
+ * \brief constant memory pools of the module + * + * Type: ConstantMemoryPools + * + * \sa tvm::ConstantMemoryPools + */ +constexpr const char* kConstantMemoryPools = "constant_memory_pools"; + /* * \brief Module attribute for tir constants */ diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index 640d52ff80e7..f921f3e39c60 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_METADATA_H_ #define TVM_RUNTIME_METADATA_H_ +#include #include #include #include @@ -45,16 +46,10 @@ namespace metadata { * Should be populated into the `version` field of all TVMMetadata. */ static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; -} // namespace metadata -} // namespace runtime -} // namespace tvm - -namespace tvm { -namespace runtime { -namespace metadata { class Metadata; class TensorInfo; +class ConstantInfoMetadata; class MetadataNode : public MetadataBaseNode { public: @@ -66,10 +61,12 @@ class MetadataNode : public MetadataBaseNode { ArrayAccessor inputs(); inline int64_t num_outputs() const { return data_->num_outputs; } ArrayAccessor outputs(); - inline int64_t num_pools() const { return data_->num_pools; } - ArrayAccessor pools(); + inline int64_t num_workspace_pools() const { return data_->num_workspace_pools; } + ArrayAccessor workspace_pools(); inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } const struct ::TVMMetadata* data() const { return data_; } + ArrayAccessor constant_pools(); + inline int64_t num_constant_pools() const { return data_->num_constant_pools; } TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); private: @@ -107,6 +104,37 @@ class TensorInfo : public MetadataBase { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); }; +class ConstantInfoMetadataNode : public MetadataBaseNode { + public: + explicit ConstantInfoMetadataNode(const struct ::TVMConstantInfo* data) : data_{data} {} + // This name should match TVMConstantInfo after processing + static constexpr const char* _type_key = "metadata.ConstantInfoNode"; + const char* get_c_struct_name() const override; + inline ::tvm::runtime::String name_hint() const { + return ::tvm::runtime::String(data_->name_hint); + } + inline size_t byte_offset() const { return data_->byte_offset; } + inline ::tvm::runtime::NDArray data() const { + ::tvm::runtime::NDArray ndarray; + if (data_->data_len) { + dmlc::MemoryFixedSizeStream bytes(const_cast(data_->data_bytes), data_->data_len); + ndarray.Load(&bytes); + } + return ndarray; + } + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoMetadataNode, MetadataBaseNode); + + protected: + const struct ::TVMConstantInfo* data_; +}; + +class ConstantInfoMetadata : public MetadataBase { + public: + explicit ConstantInfoMetadata(const struct ::TVMConstantInfo* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantInfoMetadata, MetadataBase, + ConstantInfoMetadataNode); +}; + } // namespace metadata } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/metadata_types.h b/include/tvm/runtime/metadata_types.h index 36d690cf34bc..5d828843e2b8 100644 --- a/include/tvm/runtime/metadata_types.h +++ b/include/tvm/runtime/metadata_types.h @@ -54,13 +54,18 @@ struct TVMMetadata { const struct TVMTensorInfo* outputs; /*! \brief Number of elements in `outputs` array. */ int64_t num_outputs; - /*! \brief Memory Pools needed by the AOT main function. + /*! 
\brief Workspace Memory Pools needed by the AOT main function. * The order of the elements is the same as in the arguments to run_model. That is to say, - * this array specifies the last `num_pools` arguments to run_model. + * this array specifies the last `num_workspace_pools` arguments to run_model. */ - const struct TVMTensorInfo* pools; - /*! \brief Number of elements in `pools` array. */ - int64_t num_pools; + const struct TVMTensorInfo* workspace_pools; + /*! \brief Number of elements in `workspace_pools` array. */ + int64_t num_workspace_pools; + /*! \brief Constant pools needed by the AOT main function. + */ + const struct TVMConstantInfo* constant_pools; + /*! \brief Number of elements in `constant_pools` array. */ + int64_t num_constant_pools; /*! \brief Name of the model, as passed to tvm.relay.build. */ const char* mod_name; }; @@ -82,6 +87,21 @@ struct TVMTensorInfo { DLDataType dtype; }; +/*! + * \brief Describes one constant argument to `run_model`. + * + */ +struct TVMConstantInfo { + /*! \brief Name of the constant */ + const char* name_hint; + /*! \brief Offset in bytes of the constant */ + int64_t byte_offset; + /*! \brief length of the data_bytes field */ + int64_t data_len; + /*! \brief data bytes of serialized NDArray */ + const void* data_bytes; +}; + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 288ed9d609ab..fc02550c7e25 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -675,7 +675,9 @@ class AllocateConst : public Stmt { * create AllocateConstNode with irmod_storage_idx or data */ TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Span span = Span()); + ObjectRef data_or_idx, Stmt body, + Map annotations = Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); }; diff --git a/include/tvm/tir/usmp/algorithms.h b/include/tvm/tir/usmp/algorithms.h index e2f2b6fb73f3..54431b59d21c 100644 --- a/include/tvm/tir/usmp/algorithms.h +++ b/include/tvm/tir/usmp/algorithms.h @@ -53,6 +53,17 @@ Map GreedyBySize(const Array& buffer_inf */ Map GreedyByConflicts(const Array& buffer_info_arr, const Integer& memory_pressure); +/*! + *\brief The Hill-Climb algoritm to plan memory + * + * This will perform an attempt to utilize probabalistic approach to memory + * allocation. Typically better than greedy family, but quite slow due to large + * number of iterations. + * + * \return A Map of BufferInfo objects and their associated PoolAllocation + */ +Map HillClimb(const Array& buffer_info_arr, + const Integer& memory_pressure); /*! * \brief The Hill-Climb algorithm to plan memory diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index f7858acb1779..5b3b44ff7e04 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -48,7 +48,6 @@ constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io"; namespace tir { namespace usmp { - /*! * \brief A special kind to distinguish between I/O tensors to the model * and intermediate tensors of the model @@ -163,9 +162,9 @@ class BufferInfoAnalysis : public ObjectRef { * \brief The pool allocation produced after the USMP algorithm */ struct PoolAllocationNode : public Object { - /*! \brief The assigned PoolInfo object */ + /*! \brief The assigned WorkspacePoolInfo or ConstantPoolInfo object */ PoolInfo pool_info; - /*! \brief The byte offset where the tensor is supposed to be placed within the pool*/ + /*! 
\brief The byte offset within the pool*/ Integer byte_offset; void VisitAttrs(tvm::AttrVisitor* v) { @@ -236,7 +235,7 @@ class AllocatedPoolInfo : public ObjectRef { * * \param buffer_info_map IR-bound BufferInfo map */ -Array CreateArrayBufferInfo(const Map& buffer_info_map); +Array ConvertToArrayOfBufferInfo(const Map& buffer_info_map); /*! * \brief Calculate workspace required to execute a IRModule with main expressed in TIR @@ -271,6 +270,13 @@ static constexpr const char* kOutputTensorAllocate = "output_tensor"; */ Integer CalculateExtentsSize(const AllocateNode* op); +/*! + * \brief Calculate the size of the extents in bytes + * + * \param op the allocate const node + */ +Integer CalculateExtentsSize(const AllocateConstNode* op); + /*! * \brief Joins the Stmt nodes with PoolAllocation objects * diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ac3acdde3088..5b6fbe7b2546 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -44,7 +44,11 @@ from .ir import instrument from .ir import container from .ir import PoolInfo +from .ir import WorkspacePoolInfo +from .ir import ConstantPoolInfo +from .ir import PoolInfoProperties from .ir import WorkspaceMemoryPools +from .ir import ConstantMemoryPools from . import ir # tvm.tir diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 928631ce10de..4e847c0310a4 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -30,7 +30,14 @@ from .module import IRModule from .attrs import Attrs, DictAttrs, make_node from .container import Array, Map -from .memory_pools import PoolInfo, WorkspaceMemoryPools +from .memory_pools import ( + PoolInfo, + WorkspacePoolInfo, + ConstantPoolInfo, + WorkspaceMemoryPools, + ConstantMemoryPools, + PoolInfoProperties, +) from . import transform from . import instrument diff --git a/python/tvm/ir/memory_pools.py b/python/tvm/ir/memory_pools.py index 6fa6bb41280e..0186a89f8413 100644 --- a/python/tvm/ir/memory_pools.py +++ b/python/tvm/ir/memory_pools.py @@ -27,18 +27,20 @@ class PoolInfo(Object): """PoolInfo object holds information related to memory pools where the statically sized allocate nodes will pooled into. + This is a base class for WorkspacePoolInfo and ConstantPoolInfo. + """ - Parameters - ---------- - pool_name : str - The name of the memory pool + def __init__(self): + pass - target_access : Dict[Target, str] - A dictionary where keys describe which targets could - access the pool where value could take the values : - a) "rw" : read-write access - b) "ro" : write-only acesss +@register_object("ir.PoolInfoProperties") +class PoolInfoProperties(Object): + """PoolInfo object holds information related to memory pools + where the statically sized allocate nodes will pooled into. + + Parameters + ---------- size_hint_bytes : Optional[int] The expected size hint to be used by the allocator. 
The default value would be -1 which means the pool @@ -74,34 +76,21 @@ class PoolInfo(Object): """ - # The string parameter to indicate read and write access to a pool - # This needs to be kept in sync with kTargetPoolReadWriteAccess in - # include/tvm/ir/memory_pools.h - READ_WRITE_ACCESS = "rw" - # The string parameter to indicate read only access to a pool - # This needs to be kept in sync with kTargetPoolReadOnlyAccess in - # include/tvm/ir/memory_pools.h - READ_ONLY_ACCESS = "ro" - def __init__( self, - pool_name: str, - target_access, # Dict[Target, str] size_hint_bytes: Optional[int] = -1, clock_frequency_hz: Optional[int] = -1, read_bandwidth_bytes_per_cycle: Optional[int] = -1, write_bandwidth_bytes_per_cycle: Optional[int] = -1, read_latency_cycles: Optional[int] = 0, write_latency_cycles: Optional[int] = 0, - target_burst_bytes=None, # Optional[Union[Dict[target.Target, int], None]] + target_burst_bytes=None, ): if not target_burst_bytes: target_burst_bytes = dict() self.__init_handle_by_constructor__( - _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member - pool_name, - target_access, + _ffi_api.PoolInfoProperties, # type: ignore # pylint: disable=no-member size_hint_bytes, clock_frequency_hz, read_bandwidth_bytes_per_cycle, @@ -112,15 +101,90 @@ def __init__( ) +@register_object("ir.WorkspacePoolInfo") +class WorkspacePoolInfo(PoolInfo): + """WorkspacePoolInfo object holds information related to RW memory pools + where the statically sized allocate nodes will pooled into. + + Parameters + ---------- + pool_name : str + The name of the memory pool + + targets : list[Target] + A list of targets which could access the pool + + pool_info_properties : PoolInfoProperties + The properties of the pool. + """ + + def __init__( + self, + pool_name: str, + targets, + pool_info_properties=None, + ): + super().__init__() + + if pool_info_properties is None: + pool_info_properties = PoolInfoProperties() + + self.__init_handle_by_constructor__( + _ffi_api.WorkspacePoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + targets, + pool_info_properties, + ) + + +@register_object("ir.ConstantPoolInfo") +class ConstantPoolInfo(PoolInfo): + """ConstantPoolInfo object holds information related to RO memory pools + where the statically sized allocate nodes are pooled into. + + Parameters + ---------- + pool_name : str + The name of the memory pool + + targets : list[Target] + describes which targets could access the pool + + pool_info_properties : PoolInfoProperties + The properties of the pool. 
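+    constant_info_arr : list[ConstantInfo]
+        The constants (name hint, byte offset and NDArray data) to be placed
+        into the read-only pool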
+ """ + + def __init__( + self, + pool_name: str, + targets, # list[Target] + constant_info_arr=None, # list[ConstantInfo] + pool_info_properties=None, + ): + super().__init__() + + if constant_info_arr is None: + constant_info_arr = [] + if pool_info_properties is None: + pool_info_properties = PoolInfoProperties() + self.__init_handle_by_constructor__( + _ffi_api.ConstantPoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + targets, + constant_info_arr, + pool_info_properties, + ) + + @register_object("ir.WorkspaceMemoryPools") class WorkspaceMemoryPools(Object): - """This object contains a list of PoolInfo objects to be used as + """This object contains a list of WorkspacePoolInfo objects to be used as workspace memory in the compilation Parameters ---------- - pools : List[PoolInfo] - The list of PoolInfo objects to be used with the compilation + pools : List[WorkspacePoolInfo] + The list of ConstantPoolInfo objects to be used with the compilation """ def __init__( @@ -130,3 +194,23 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.WorkspaceMemoryPools, pools # type: ignore # pylint: disable=no-member ) + + +@register_object("ir.ConstantMemoryPools") +class ConstantMemoryPools(Object): + """This object contains a list of ConstantPoolInfo objects to be used as + read-only memory in the compilation + + Parameters + ---------- + pools : List[ConstantPoolInfo] + The list of ConstantPoolInfo objects to be used with the compilation + """ + + def __init__( + self, + pools: List[ConstantPoolInfo], + ): + self.__init_handle_by_constructor__( + _ffi_api.ConstantMemoryPools, pools # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 9eeb20f5f1ce..23892554cf12 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -80,6 +80,7 @@ def build( executor=Executor("graph"), runtime=Runtime("cpp"), workspace_memory_pools=None, + constant_memory_pools=None, params=None, mod_name=None, ): @@ -111,8 +112,13 @@ def build( Defaults to "cpp" if no runtime specified. workspace_memory_pools : Optional[WorkspaceMemoryPools] - The object that contains an Array of PoolInfo objects - that hold properties of workspace pools that could be + The object that contains an Array of WorkspacePoolInfo objects + that hold properties of read-write workspace pools that could be + used by the inference. + + constant_memory_pools : Optional[ConstantMemoryPools] + The object that contains an Array of ConstantPoolInfo objects + that hold properties of read-only memory pools that could be used by the inference. params : dict of str to NDArray @@ -133,7 +139,6 @@ def build( params : dict The parameters of the final graph. """ - raw_targets = Target.canon_multi_target_and_host(target, target_host) # Setup the params. if params: @@ -151,7 +156,16 @@ def build( mod_name = mangle_module_name(mod_name) - self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name) + self._build( + mod, + target, + target_host, + executor, + runtime, + workspace_memory_pools, + constant_memory_pools, + mod_name, + ) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent # Get artifacts @@ -328,6 +342,7 @@ def build( executor=Executor("graph"), runtime=Runtime("cpp"), workspace_memory_pools=None, + constant_memory_pools=None, params=None, mod_name="default", ): @@ -357,8 +372,13 @@ def build( Defaults to "cpp" if no runtime specified. 
workspace_memory_pools : Optional[WorkspaceMemoryPools] - The object that contains an Array of PoolInfo objects - that hold properties of workspace pools that could be + The object that contains an Array of WorkspacePoolInfo objects + that hold properties of read-write workspace pools that could be + used by the inference. + + constant_memory_pools : Optional[ConstantMemoryPools] + The object that contains an Array of ConstantPoolInfo objects + that hold properties of read-only pools that could be used by the inference. params : dict of str to NDArray @@ -420,6 +440,7 @@ def build( executor=executor, runtime=runtime, workspace_memory_pools=workspace_memory_pools, + constant_memory_pools=constant_memory_pools, mod_name=mod_name, ) func_metadata = bld_mod.get_function_metadata() diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 85882055d02f..76fbf26eea31 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -176,12 +176,20 @@ class AllocateConst(WithScopeHandler): """ def __init__(self): - def allocate_const(raw_data, dtype, shape, span=None): + def allocate_const(raw_data, dtype, shape, annotations=None, span=None): list_data = [] for i in raw_data: list_data.append(i.value) nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) - n = tvm.tir.AllocateConst(self.buffer.data, dtype, shape, nd_data, self.body, span=span) + n = tvm.tir.AllocateConst( + self.buffer.data, + dtype, + shape, + nd_data, + self.body, + annotations=annotations, + span=span, + ) return n super().__init__(allocate_const, concise_scope=True, def_symbol=True) @@ -209,7 +217,7 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer(data, dtype, shape, span: Span = None): + def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None): """Setup buffer var for a given type.""" self.buffer = tvm.tir.decl_buffer( shape=shape, diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index f8f170366ac5..583286bf273a 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -587,11 +587,13 @@ def compile_models( interface_api: str, use_unpacked_api: bool, workspace_byte_alignment: int = 8, + constant_byte_alignment: int = 8, enable_op_fusion: bool = True, pass_config: Dict[str, Any] = None, use_runtime_executor: bool = True, target: tvm.target.Target = tvm.target.Target("c"), workspace_memory_pools=None, + constant_memory_pools=None, schedule_name: str = None, ) -> List[AOTCompiledTestModel]: """ @@ -605,6 +607,7 @@ def compile_models( "aot", { "workspace-byte-alignment": workspace_byte_alignment, + "constant-byte-alignment": constant_byte_alignment, "interface-api": interface_api, "unpacked-api": use_unpacked_api, }, @@ -632,6 +635,7 @@ def compile_models( executor=executor, runtime=runtime, workspace_memory_pools=workspace_memory_pools, + constant_memory_pools=constant_memory_pools, params=model.params, mod_name=model.name, ) @@ -658,6 +662,7 @@ def compile_models( executor=executor, runtime=runtime, workspace_memory_pools=workspace_memory_pools, + constant_memory_pools=constant_memory_pools, params=model.params, mod_name=model.name, ) @@ -683,6 +688,7 @@ def run_and_check( interface_api: str, debug_calculated_workspaces=False, workspace_byte_alignment=8, + constant_byte_alignment=8, data_linkage: AOTDataLinkage = None, test_dir: str = None, verbose: bool = False, @@ -694,7 +700,10 @@ def run_and_check( """ def run_and_check_body(base_path): - cflags = 
f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " + cflags = ( + f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " + f" -DTVM_RUNTIME_CONST_ALLOC_ALIGNMENT_BYTES={constant_byte_alignment} " + ) # The calculated workspaces will not account for stack allocator tags used for debugging if debug_calculated_workspaces: cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " @@ -830,6 +839,7 @@ def compile_and_run( use_unpacked_api: bool, debug_calculated_workspaces: bool = False, workspace_byte_alignment: int = 8, + constant_byte_alignment: int = 8, enable_op_fusion: bool = True, data_linkage: AOTDataLinkage = None, use_runtime_executor: bool = True, @@ -858,6 +868,7 @@ def compile_and_run( interface_api=interface_api, use_unpacked_api=use_unpacked_api, workspace_byte_alignment=workspace_byte_alignment, + constant_byte_alignment=constant_byte_alignment, enable_op_fusion=enable_op_fusion, pass_config=runner.pass_config, use_runtime_executor=use_runtime_executor, @@ -871,6 +882,7 @@ def compile_and_run( interface_api=interface_api, debug_calculated_workspaces=debug_calculated_workspaces, workspace_byte_alignment=workspace_byte_alignment, + constant_byte_alignment=constant_byte_alignment, data_linkage=data_linkage, test_dir=test_dir, verbose=verbose, @@ -897,7 +909,9 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): else: main = mod["main"] if main.attrs is None or main.attrs["output_tensor_names"] is None: - output_tensor_names = ["output" if i == 0 else f"output{i+1}" for i in range(output_count)] + output_tensor_names = ( + ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)] + ) else: output_tensor_names = main.attrs["output_tensor_names"] diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 0e91f8841313..7fc73ef4c436 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -182,6 +182,25 @@ def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> return _ffi_api.calculate_workspace_bytes(func, workspace_byte_alignment) # type: ignore +def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> int: + """Calculate the constant size in bytes needed by the TIR allocates inside the TIR + PrimFunc. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to be detected. + constant_byte_alignment : int + The byte alignment required for each tensor + + Returns + ------- + result : int + Workspace size in bytes. + """ + return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore + + def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: """Detect the lowest common ancestor(LCA) of buffer access, including both high-level access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access). diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 9734f7ae2bc9..301bfa73c818 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -364,13 +364,16 @@ class AllocateConst(Stmt): body : Stmt The body statement. + annotations : Optional[Map] + Additional annotations about the allocation. + span : Optional[Span] The location of this itervar in the source code. 
""" - def __init__(self, buffer_var, dtype, extents, data_or_idx, body, span=None): + def __init__(self, buffer_var, dtype, extents, data_or_idx, body, annotations=None, span=None): self.__init_handle_by_constructor__( - _ffi_api.AllocateConst, buffer_var, dtype, extents, data_or_idx, body, span + _ffi_api.AllocateConst, buffer_var, dtype, extents, data_or_idx, body, annotations, span ) diff --git a/src/ir/memory_pools.cc b/src/ir/memory_pools.cc index 5cf0035c90b2..f5064af207cc 100644 --- a/src/ir/memory_pools.cc +++ b/src/ir/memory_pools.cc @@ -27,15 +27,13 @@ namespace tvm { -PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, - Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, - Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, - Integer write_latency_cycles, Map target_burst_bytes, - Bool is_internal) { +PoolInfo::PoolInfo(String pool_name, Integer size_hint_bytes, Integer clock_frequency_hz, + Integer read_bandwidth_bytes_per_cycle, Integer write_bandwidth_bytes_per_cycle, + Integer read_latency_cycles, Integer write_latency_cycles, + Map target_burst_bytes, Bool is_internal) { auto poolinfo_node = make_object(); poolinfo_node->pool_name = pool_name; poolinfo_node->size_hint_bytes = size_hint_bytes; - poolinfo_node->target_access = target_access; poolinfo_node->clock_frequency_hz = clock_frequency_hz; poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle; poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle; @@ -47,21 +45,12 @@ PoolInfo::PoolInfo(String pool_name, Map target_access, Integer } TVM_REGISTER_NODE_TYPE(PoolInfoNode); -TVM_REGISTER_GLOBAL("ir.PoolInfo") - .set_body_typed([](String pool_name, Map target_access, Integer size_hint_bytes, - Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, - Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, - Integer write_latency_cycles, Map target_burst_bytes) { - return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz, - read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle, - read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false)); - }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "PoolInfoNode(\n" - << " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access + << " pool_name=" << node->pool_name << ",\n size_hint_bytes=" << node->size_hint_bytes << ",\n clock_frequency_hz=" << node->clock_frequency_hz << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle @@ -71,6 +60,151 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n target_burst_bytes=" << node->target_burst_bytes << ")"; }); +PoolInfoProperties::PoolInfoProperties(Integer size_hint_bytes, Integer clock_frequency_hz, + Integer read_bandwidth_bytes_per_cycle, + Integer write_bandwidth_bytes_per_cycle, + Integer read_latency_cycles, Integer write_latency_cycles, + Map target_burst_bytes, Bool is_internal) { + auto poolinfo_properties_node = make_object(); + poolinfo_properties_node->size_hint_bytes = size_hint_bytes; + poolinfo_properties_node->clock_frequency_hz = clock_frequency_hz; + poolinfo_properties_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle; + poolinfo_properties_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle; + poolinfo_properties_node->read_latency_cycles = 
read_latency_cycles; + poolinfo_properties_node->write_latency_cycles = write_latency_cycles; + poolinfo_properties_node->target_burst_bytes = target_burst_bytes; + poolinfo_properties_node->is_internal = is_internal; + data_ = std::move(poolinfo_properties_node); +} + +TVM_REGISTER_NODE_TYPE(PoolInfoPropertiesNode); +TVM_REGISTER_GLOBAL("ir.PoolInfoProperties") + .set_body_typed([](Integer size_hint_bytes, Integer clock_frequency_hz, + Integer read_bandwidth_bytes_per_cycle, + Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, + Integer write_latency_cycles, Map target_burst_bytes) { + return PoolInfoProperties(size_hint_bytes, clock_frequency_hz, read_bandwidth_bytes_per_cycle, + write_bandwidth_bytes_per_cycle, read_latency_cycles, + write_latency_cycles, target_burst_bytes, Bool(false)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PoolInfoPropertiesNode(\n" + << " size_hint_bytes=" << node->size_hint_bytes + << ",\n clock_frequency_hz=" << node->clock_frequency_hz + << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle + << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle + << ",\n read_latency_cycles=" << node->read_latency_cycles + << ",\n write_latency_cycles=" << node->write_latency_cycles + << ",\n target_burst_bytes=" << node->target_burst_bytes << ")"; + }); + +WorkspacePoolInfo::WorkspacePoolInfo(String pool_name, Array targets, + PoolInfoProperties properties) { + auto poolinfo_node = make_object(); + poolinfo_node->pool_name = pool_name; + poolinfo_node->size_hint_bytes = properties->size_hint_bytes; + poolinfo_node->targets = targets; + poolinfo_node->clock_frequency_hz = properties->clock_frequency_hz; + poolinfo_node->read_bandwidth_bytes_per_cycle = properties->read_bandwidth_bytes_per_cycle; + poolinfo_node->write_bandwidth_bytes_per_cycle = properties->write_bandwidth_bytes_per_cycle; + poolinfo_node->read_latency_cycles = properties->read_latency_cycles; + poolinfo_node->write_latency_cycles = properties->write_latency_cycles; + poolinfo_node->target_burst_bytes = properties->target_burst_bytes; + poolinfo_node->is_internal = properties->is_internal; + data_ = std::move(poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(WorkspacePoolInfoNode); +TVM_REGISTER_GLOBAL("ir.WorkspacePoolInfo") + .set_body_typed([](String pool_name, Array targets, PoolInfoProperties properties) { + return WorkspacePoolInfo(pool_name, targets, properties); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "WorkspacePoolInfoNode(\n" + << " pool_name=" << node->pool_name << ",\n targets=" << node->targets + << ",\n size_hint_bytes=" << node->size_hint_bytes + << ",\n clock_frequency_hz=" << node->clock_frequency_hz + << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle + << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle + << ",\n read_latency_cycles=" << node->read_latency_cycles + << ",\n write_latency_cycles=" << node->write_latency_cycles + << ",\n target_burst_bytes=" << node->target_burst_bytes + << ",\n is_internal=" << node->is_internal << ")" + << "\n"; + }); + +ConstantInfo::ConstantInfo(String name_hint, Integer byte_offset, runtime::NDArray data) { + auto constant_info_node = make_object(); + constant_info_node->name_hint 
= name_hint; + constant_info_node->byte_offset = byte_offset; + constant_info_node->data = data; + data_ = std::move(constant_info_node); +} + +TVM_REGISTER_NODE_TYPE(ConstantInfoNode); +TVM_REGISTER_GLOBAL("ir.ConstantInfo") + .set_body_typed([](String name_hint, Integer byte_offset, runtime::NDArray data) { + return ConstantInfo(name_hint, byte_offset, data); + }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstantInfoNode(\n" + << "name_hint=" << node->name_hint << ",\n byte_offset=" << node->byte_offset + << ",\n data=" << node->data << ")"; + }); + +ConstantPoolInfo::ConstantPoolInfo(String pool_name, Array targets, + Array constant_info_array, + PoolInfoProperties properties) { + auto constant_poolinfo_node = make_object(); + constant_poolinfo_node->pool_name = pool_name; + constant_poolinfo_node->constant_info_array = constant_info_array; + constant_poolinfo_node->targets = targets; + + constant_poolinfo_node->size_hint_bytes = properties->size_hint_bytes; + constant_poolinfo_node->clock_frequency_hz = properties->clock_frequency_hz; + constant_poolinfo_node->read_bandwidth_bytes_per_cycle = + properties->read_bandwidth_bytes_per_cycle; + constant_poolinfo_node->write_bandwidth_bytes_per_cycle = + properties->write_bandwidth_bytes_per_cycle; + constant_poolinfo_node->read_latency_cycles = properties->read_latency_cycles; + constant_poolinfo_node->write_latency_cycles = properties->write_latency_cycles; + constant_poolinfo_node->target_burst_bytes = properties->target_burst_bytes; + constant_poolinfo_node->is_internal = properties->is_internal; + data_ = std::move(constant_poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(ConstantPoolInfoNode); +TVM_REGISTER_GLOBAL("ir.ConstantPoolInfo") + .set_body_typed([](String pool_name, Array targets, + Array constant_info_array, PoolInfoProperties properties) { + return ConstantPoolInfo(pool_name, targets, constant_info_array, properties); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstantPoolInfoNode(\n" + << " pool_name=" << node->pool_name << ",\n targets=" << node->targets + << ",\n constant_info_array=" << node->constant_info_array + << ",\n size_hint_bytes=" << node->size_hint_bytes + << ",\n clock_frequency_hz=" << node->clock_frequency_hz + << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle + << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle + << ",\n read_latency_cycles=" << node->read_latency_cycles + << ",\n write_latency_cycles=" << node->write_latency_cycles + << ",\n target_burst_bytes=" << node->target_burst_bytes + << ",\n is_internal=" << node->is_internal << ")" + << "\n"; + }); + WorkspaceMemoryPools::WorkspaceMemoryPools(Array pools) { auto workspace_memory_pools_node = make_object(); workspace_memory_pools_node->pools = pools; @@ -89,4 +223,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "pools=" << node->pools << ")"; }); +ConstantMemoryPools::ConstantMemoryPools(Array pools) { + auto constant_memory_pools_node = make_object(); + constant_memory_pools_node->pools = pools; + data_ = std::move(constant_memory_pools_node); +} + +TVM_REGISTER_NODE_TYPE(ConstantMemoryPoolsNode); +TVM_REGISTER_GLOBAL("ir.ConstantMemoryPools").set_body_typed([](Array pools) { + return ConstantMemoryPools(pools); +}); + 
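For reference, a minimal usage sketch of the Python surface registered above (WorkspacePoolInfo, ConstantPoolInfo, PoolInfoProperties, WorkspaceMemoryPools, ConstantMemoryPools, and the new constant_memory_pools argument to relay.build). The target, pool names, toy model, alignment value, and the tir.usmp.enable / tir.disable_vectorize pass-context flags are illustrative assumptions drawn from typical AOT + USMP usage, not requirements imposed by this patch:

import tvm
from tvm import relay
from tvm.relay.backend import Executor, Runtime

target = tvm.target.Target("c")

# Read-write pool restricted to 256 KiB; the remaining properties keep their defaults.
sram = tvm.ir.WorkspacePoolInfo(
    "sram", [target], tvm.ir.PoolInfoProperties(size_hint_bytes=256 * 1024)
)

# Read-only pool; the compiler fills in the ConstantInfo entries during planning,
# so constant_info_arr is left at its default here.
flash = tvm.ir.ConstantPoolInfo("flash", [target])

# Toy single-operator model, purely for illustration.
x = relay.var("x", shape=(1, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

aot_executor = Executor(
    "aot", {"interface-api": "c", "unpacked-api": True, "constant-byte-alignment": 8}
)

with tvm.transform.PassContext(
    opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": True}
):
    lib = relay.build(
        mod,
        target=target,
        executor=aot_executor,
        runtime=Runtime("crt"),
        workspace_memory_pools=tvm.ir.WorkspaceMemoryPools([sram]),
        constant_memory_pools=tvm.ir.ConstantMemoryPools([flash]),
    )

With USMP enabled, the planner assigns buffers to these pools and the generated metadata exposes them through the workspace_pools / constant_pools fields that the AOT executors consume.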
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstantMemoryPoolsNode(\n" + << "pools=" << node->pools << ")"; + }); } // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 381cfa0c9d1c..5938417128e0 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -876,9 +876,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ Map CalculateWorkspaceSizes( const IRModule& lowered_mod, const Map& function_metadata) { - Executor executor_config = lowered_mod->GetAttr(tvm::attr::kExecutor).value(); - Integer workspace_byte_alignment = - executor_config->GetAttr("workspace-byte-alignment").value_or(16); + Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(lowered_mod); Map updated_function_metadata; for (const auto& kv : lowered_mod->functions) { GlobalVar global_var = kv.first; @@ -905,9 +903,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ IRModule PlanMemoryWithUSMP(const IRModule& mod) { VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - Integer workspace_byte_alignment = - executor_config->GetAttr("workspace-byte-alignment").value_or(16); + Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); IRModule lowered_mod = mod->ShallowCopy(); lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod); function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); @@ -918,16 +914,22 @@ class AOTExecutorCodegen : public MixedModeVisitor { main_func_info->workspace_sizes.clear(); if (allocated_pool_infos) { for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { - for (const auto& kv : allocated_pool_info->pool_info->target_access) { - Target tgt = kv.first; + for (const auto& tgt : allocated_pool_info->pool_info->targets) { VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size " << allocated_pool_info->allocated_size->value; - if (main_func_info->workspace_sizes.find(tgt) == main_func_info->workspace_sizes.end()) { - main_func_info->workspace_sizes.Set(tgt, allocated_pool_info->allocated_size); + size_t size = allocated_pool_info->allocated_size->value; + if (allocated_pool_info->pool_info->IsInstance()) { + size += main_func_info->constant_sizes.count(tgt) + ? main_func_info->constant_sizes[tgt]->value + : 0; + main_func_info->constant_sizes.Set(tgt, size); + } else if (allocated_pool_info->pool_info->IsInstance()) { + size += main_func_info->workspace_sizes.count(tgt) + ? main_func_info->workspace_sizes[tgt]->value + : 0; + main_func_info->workspace_sizes.Set(tgt, size); } else { - main_func_info->workspace_sizes.Set(tgt, - main_func_info->workspace_sizes[tgt]->value + - allocated_pool_info->allocated_size->value); + LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey(); } } } @@ -940,9 +942,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { * \brief Run StorageRewrite to plan memory for lowered IRModule. 
*/ IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { - Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - Integer workspace_byte_alignment = - executor_config->GetAttr("workspace-byte-alignment").value_or(16); + Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); IRModule lowered_mod = mod->ShallowCopy(); function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); // Running StorageRewrite just on the main function @@ -966,6 +966,22 @@ class AOTExecutorCodegen : public MixedModeVisitor { return lowered_mod; } + /*! + * \brief Gets module workspace alignment from supplied executor or defaults to 16 + */ + Integer GetModuleWorkspaceByteAlignment(const IRModule& mod) { + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); + return executor_config->GetAttr("workspace-byte-alignment").value_or(16); + } + + /*! + * \brief Gets module constant alignment from supplied executor or defaults to 16 + */ + Integer GetModuleConstantByteAlignment(const IRModule& mod) { + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); + return executor_config->GetAttr("constant-byte-alignment").value_or(16); + } + protected: /*! \brief mod */ runtime::Module* mod_; @@ -1026,14 +1042,14 @@ class AOTExecutorCodegen : public MixedModeVisitor { VLOG_CONTEXT << "AOT"; Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); + Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); std::string interface_api = executor_config->GetAttr("interface-api").value_or("packed"); - Integer workspace_byte_alignment = - executor_config->GetAttr("workspace-byte-alignment").value_or(16); bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - // Validate choice of use_unpacked_api_ and use_call_cpacked_ + // Validate choice of unpacked_api and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { if (unpacked_api == true) { call_type_ = CallType::kUnpacked; @@ -1173,12 +1189,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } Array outputs = Array(outputs_begin_iterator, main_func_params_end_iterator - devices.size()); - std::vector output_var_names; - for (const tir::Var& output : outputs) { - output_var_names.push_back(output->name_hint); - } - Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func); // Parallel for loops are not supported in AoT codegen. lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); @@ -1193,7 +1204,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.function_metadata = std::move(function_metadata_); // Legalize AOT if needed. 
This means that all the packed calls - // need to be wrapped in TVMValues (unless use_unpacked_api is set) + // need to be wrapped in TVMValues (unless unpacked_api is set) if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { auto pack_calls = tir::transform::LegalizePackedCalls(); lowered_mod = pack_calls(lowered_mod); @@ -1235,11 +1246,32 @@ class AOTExecutorCodegen : public MixedModeVisitor { ->GetAttr>(tvm::attr::kIOTensorPoolAllocations) .value_or({}); - ret.metadata = - ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types, - pool_vars, devices, runtime::kTvmExecutorAot, mod_name, - interface_api, unpacked_api, pool_var_info, io_pool_allocations); + std::vector output_var_names; + if (auto opt = func->GetAttr>("output_tensor_names")) { + Array output_tensor_names = opt.value(); + for (size_t i = 0; i < output_tensor_names.size(); ++i) { + output_var_names.push_back(output_tensor_names[i]); + } + } + + // If output names have not been specified then generate default output names + if (output_var_names.size() == 0) { + if (return_sid_.size() == 1) { + output_var_names.push_back(String("output")); + } else { + for (size_t i = 0; i < return_sid_.size(); ++i) { + output_var_names.push_back(String("output" + std::to_string(i))); + } + } + } + + Array output_tensor_types{final_aot_allocator.GetReturnTtypes()}; + ret.metadata = ExecutorCodegenMetadata( + inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, + runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, + GetModuleWorkspaceByteAlignment(mod), GetModuleConstantByteAlignment(mod), pool_var_info, + io_pool_allocations); return ret; } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8c1d83d39b09..578a62ca0259 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -192,8 +192,8 @@ class RelayBuildModule : public runtime::ModuleNode { [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 6); - this->Build(args[0], args[1], args[2], args[3], args[4], args[5]); + ICHECK_EQ(args.num_args, 8); + this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); }); } else if (name == "list_params") { return PackedFunc( @@ -303,13 +303,15 @@ class RelayBuildModule : public runtime::ModuleNode { * \param runtime Runtime to codegen for * \param mod_name Name of the module */ - void Build(IRModule mod, const Array& raw_targets, const Executor& executor, - const Runtime& runtime, const WorkspaceMemoryPools& workspace_memory_pools, - const String& mod_name) { + void Build(IRModule mod, const Array& raw_targets, const tvm::Target& target_host, + const Executor& executor, const Runtime& runtime, + const WorkspaceMemoryPools& workspace_memory_pools, + const ConstantMemoryPools& constant_memory_pools, const String mod_name) { VLOG_CONTEXT << "Build"; executor_ = executor; runtime_ = runtime; workspace_memory_pools_ = workspace_memory_pools; + constant_memory_pools_ = constant_memory_pools; config_ = CompilationConfig(PassContext::Current(), raw_targets); VLOG(1) << "Using compilation config:" << std::endl << config_; BuildRelay(std::move(mod), mod_name); @@ -414,7 +416,8 @@ class RelayBuildModule : public runtime::ModuleNode { IRModule func_module = WithAttrs(IRModule::FromExpr(func), 
{{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}, - {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_}}); + {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_}, + {tvm::attr::kConstantMemoryPools, constant_memory_pools_}}); // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); @@ -476,6 +479,8 @@ class RelayBuildModule : public runtime::ModuleNode { Runtime runtime_; /*! \brief Workspace memory pools to codegen for */ WorkspaceMemoryPools workspace_memory_pools_; + /*! \brief Constant memory pools to codegen for */ + ConstantMemoryPools constant_memory_pools_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 581fbdf2cdf1..bb9706ba86f9 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -91,7 +91,8 @@ TVM_REGISTER_EXECUTOR("aot") .add_attr_option("link-params", Bool(true)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("workspace-byte-alignment"); + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constant-byte-alignment"); TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 3c6e642e846e..bd3047e2862c 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -183,6 +183,7 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( Array inputs, Array input_tensor_types, Array outputs, Array output_tensor_types, Array pools, Array devices, String executor, String mod_name, String interface_api, bool unpacked_api, + Integer workspace_alignment, Integer constant_alignment, Map pool_inputs, Map io_pool_allocations) { auto n = make_object(); @@ -196,6 +197,8 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( n->interface_api = interface_api; n->unpacked_api = unpacked_api; n->mod_name = mod_name; + n->workspace_alignment = workspace_alignment; + n->constant_alignment = constant_alignment; n->pool_inputs = pool_inputs; n->io_pool_allocations = io_pool_allocations; data_ = std::move(n); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 70080254c414..67924a7835fb 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -81,6 +82,10 @@ class ExecutorCodegenMetadataNode : public Object { String interface_api; /*! \brief The internal API (packed or unpacked) in use */ bool unpacked_api; + /*! \brief Alginment of the workspace in bytes */ + Integer workspace_alignment; + /*! \brief Alginment of the constants in bytes */ + Integer constant_alignment; /*! \brief the input var names that correspond to pool_inputs */ Optional> pool_inputs; /*! 
\brief the I/O tensor to PoolAllocations if any*/ @@ -97,6 +102,8 @@ class ExecutorCodegenMetadataNode : public Object { v->Visit("devices", &devices); v->Visit("executor", &executor); v->Visit("unpacked_api", &unpacked_api); + v->Visit("workspace_alignment", &workspace_alignment); + v->Visit("constant_alignment", &constant_alignment); v->Visit("pool_inputs", &pool_inputs); v->Visit("io_pool_allocations", &io_pool_allocations); } @@ -110,14 +117,15 @@ class ExecutorCodegenMetadataNode : public Object { */ class ExecutorCodegenMetadata : public ObjectRef { public: - TVM_DLL ExecutorCodegenMetadata( - Array inputs, Array input_tensor_types, Array outputs, - Array output_tensor_types, Array pools, Array devices, - String executor, String mod_name, String interface_api = "packed", bool unpacked_api = false, - Map pool_inputs = - Map(), - Map io_pool_allocations = {{}}); - + TVM_DLL ExecutorCodegenMetadata(Array inputs, Array input_tensor_types, + Array outputs, Array output_tensor_types, + Array pools, Array devices, String executor, + String mod_name, String interface_api = "packed", + bool unpacked_api = false, Integer workspace_alignment = 16, + Integer constant_alignment = 16, + Map pool_inputs = + Map(), + Map io_pool_allocations = {}); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef, ExecutorCodegenMetadataNode); }; diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 732da14695eb..985c857ed55f 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -26,7 +26,9 @@ #include "aot_executor.h" #include +#include +#include #include #include "../meta_data.h" @@ -62,9 +64,31 @@ AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& output->dtype(), devices_[0])); } - for (auto pool : metadata_->pools()) { - args_.emplace_back(NDArray::Empty(ShapeTuple(pool->shape().begin(), pool->shape().end()), - pool->dtype(), devices_[0])); + // USMP is used + if (metadata_->num_workspace_pools()) { + // merge all constants into one ndarray + int64_t blob_len = 0; + for (const auto& c : metadata_->constant_pools()) { + auto data = c->data(); + int64_t byte_size = GetDataSize(*data.operator->()) + c->byte_offset(); + blob_len = blob_len > byte_size ? 
blob_len : byte_size; + } + ICHECK(blob_len < std::numeric_limits::max()); + NDArray ci = NDArray::Empty({blob_len}, DataType::UInt(8), devices_[0]); + for (const auto& c : metadata_->constant_pools()) { + auto data = c->data(); + data.CopyToBytes(static_cast(ci->data) + c->byte_offset(), + GetDataSize(*data.operator->())); + } + // Emplace constant node pool only if workspace pools supplied + args_.emplace_back(ci); + + int32_t pool_len = 0; + for (auto pool : metadata_->workspace_pools()) { + pool_len = + GetDataSize(*NDArray::Empty({pool->shape()}, pool->dtype(), devices_[0]).operator->()); + args_.emplace_back(NDArray::Empty({pool_len}, DataType::UInt(8), devices_[0])); + } } } diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c index 1724fabec4a0..a40c1d530fa9 100644 --- a/src/runtime/crt/aot_executor/aot_executor.c +++ b/src/runtime/crt/aot_executor/aot_executor.c @@ -39,7 +39,8 @@ static void DumpMetadata(const TVMMetadata* md) { LOG_DEBUG("\tversion=%" PRId64 "\n", md->version); LOG_DEBUG("\tnum_inputs=%" PRId64 "\n", md->num_inputs); LOG_DEBUG("\tnum_outputs=%" PRId64 "\n", md->num_outputs); - LOG_DEBUG("\tnum_pools=%" PRId64 "\n", md->num_pools); + LOG_DEBUG("\tnum_workspace_pools=%" PRId64 "\n", md->num_workspace_pools); + LOG_DEBUG("\tnum_constant_pools=%" PRId64 "\n", md->num_constant_pools); int i; @@ -51,8 +52,12 @@ static void DumpMetadata(const TVMMetadata* md) { LOG_DEBUG("\toutput[%d]: %s\n", i, md->outputs[i].name); } - for (i = 0; i < md->num_pools; ++i) { - LOG_DEBUG("\tpools[%d]: %s\n", i, md->pools[i].name); + for (i = 0; i < md->num_workspace_pools; ++i) { + LOG_DEBUG("\tworkspace_pools[%d]: %s\n", i, md->workspace_pools[i].name); + } + + for (i = 0; i < md->num_constant_pools; ++i) { + LOG_DEBUG("\tconstant_pools[%d]: %s\n", i, md->constant_pools[i].name_hint); } } @@ -160,7 +165,7 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, DumpMetadata(md); - executor->num_args = md->num_inputs + md->num_outputs + md->num_pools; + executor->num_args = md->num_inputs + md->num_outputs + md->num_workspace_pools; tvm_crt_error_t err = TVMPlatformMemoryAllocate(executor->num_args * sizeof(*executor->args), executor->device, (void**)(&executor->args)); @@ -198,16 +203,17 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, TVMNDArray_IncrementReference(array); } - for (i = 0; i < md->num_pools; ++i) { - LOG_DEBUG("pools allocate[%d]: %s\n", i, md->pools[i].name); + for (i = 0; i < md->num_workspace_pools; ++i) { + LOG_DEBUG("pools allocate[%d]: %s\n", i, md->workspace_pools[i].name); - status = TVMNDArray_Empty(md->pools[i].num_shape, md->pools[i].shape, md->pools[i].dtype, - executor->device, &executor->args[arg_idx++]); + status = TVMNDArray_Empty(md->workspace_pools[i].num_shape, md->workspace_pools[i].shape, + md->workspace_pools[i].dtype, executor->device, + &executor->args[arg_idx++]); if (status != 0) { return status; } } - + CHECK_EQ(0, md->num_constant_pools, "Constant pools not supported"); return status; } diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 8e034cc94d3a..2120ffe40d67 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -41,8 +41,13 @@ ArrayAccessor MetadataNode::inputs() { ArrayAccessor MetadataNode::outputs() { return ArrayAccessor(data_->outputs, data_->num_outputs); } -ArrayAccessor MetadataNode::pools() { - return ArrayAccessor(data_->pools, data_->num_pools); +ArrayAccessor 
MetadataNode::workspace_pools() { + return ArrayAccessor(data_->workspace_pools, + data_->num_workspace_pools); +} +ArrayAccessor MetadataNode::constant_pools() { + return ArrayAccessor(data_->constant_pools, + data_->num_constant_pools); } TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); @@ -68,6 +73,12 @@ TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo"; } +ConstantInfoMetadata::ConstantInfoMetadata(const struct ::TVMConstantInfo* data) + : MetadataBase{make_object(data)} {} +TVM_REGISTER_OBJECT_TYPE(ConstantInfoMetadataNode); + +const char* ConstantInfoMetadataNode::get_c_struct_name() const { return "TVMConstantInfo"; } + } // namespace metadata class MetadataModuleNode : public ::tvm::runtime::ModuleNode { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 2a66ff37c949..07e1881c9840 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -25,6 +25,7 @@ #include "codegen_cpu.h" #include +#include #include #include @@ -34,7 +35,6 @@ #include "../func_registry_generator.h" #include "../metadata_utils.h" - namespace tvm { namespace codegen { @@ -989,7 +989,8 @@ class MetadataTypeDefiner : public AttrVisitor { elements_.emplace_back(llvm_types_->t_data_type); } void Visit(const char* key, runtime::NDArray* value) final { - CHECK(false) << "Do not support serializing NDArray"; + elements_.emplace_back(llvm_types_->t_int64); + elements_.emplace_back(llvm_types_->t_void_p); } private: @@ -1025,8 +1026,10 @@ class MetadataTypeDefiner : public AttrVisitor { CHECK(false) << "Do not support handle"; break; case MetadataKind::kMetadata: - elements_.emplace_back( - llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key])); + if (llvm_types_->structs_by_type_key.count(arr->type_key)) { + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key])); + } break; default: CHECK(false) << "Unsupported metadata kind " << arr->kind; @@ -1046,12 +1049,8 @@ class MetadataTypeDefiner : public AttrVisitor { } void DefineType(runtime::metadata::MetadataBase metadata) { + ICHECK(elements_.empty()); ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - for (auto e : elements_) { - std::string value; - llvm::raw_string_ostream os(value); - e->print(os, true); - } llvm_types_->structs_by_type_key[metadata->GetTypeKey()] = llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name()); elements_.clear(); @@ -1104,8 +1103,14 @@ class MetadataSerializerLLVM : public AttrVisitor { llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); } + // Serializing NDArray as tuple of len, data void Visit(const char* key, runtime::NDArray* value) final { - CHECK(false) << "Do not support serializing NDArray"; + std::string bytes; + dmlc::MemoryStringStream stream(&bytes); + value->Save(&stream); + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, bytes.length(), true /* isSigned */)); + elements_.back().emplace_back(codegen_->GetConstString(bytes)); } void VisitMetadata(runtime::metadata::MetadataBase metadata) { @@ -1219,7 +1224,17 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, }; + // create sample ConstantInfoMetadata instance for MetadataTypeDefiner + std::string bytes; + runtime::NDArray ci = runtime::NDArray::Empty({0}, 
DataType::UInt(8), Device{kDLCPU}); + dmlc::MemoryStringStream stream(&bytes); + ci.Save(&stream); + TVMConstantInfo di = + TVMConstantInfo{"default-none", 0, static_cast(bytes.size()), bytes.c_str()}; + std::vector queue; + queue.push_back(runtime::metadata::ConstantInfoMetadata(&di)); + metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; discover_complex.Discover(metadata); @@ -1235,9 +1250,8 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { function_ = llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, - "get_c_metadata", module_.get()); + runtime::symbol::tvm_get_c_metadata, module_.get()); SetTargetAttributes(function_); - function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 981aef2f6c06..dce75e01aa60 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1492,7 +1492,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { auto data = op->data.value(); - auto array = NDArrayToLLVMArray(ctx_, data); + auto array = codegen::NDArrayToLLVMArray(ctx_, data); std::string symbol_name = op->buffer_var->name_hint; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 0f259d6a6cf9..e3a6a3b954fa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -47,6 +47,7 @@ #include "../../runtime/thread_storage_scope.h" #include "../../tir/transforms/ir_utils.h" +#include "codegen_params.h" #include "llvm_common.h" namespace tvm { @@ -185,6 +186,9 @@ class CodeGenLLVM : public ExprFunctor, llvm::Constant* GetGlobalConstant( llvm::Constant* const_data, const std::string& name = "", llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); + inline llvm::ConstantArray* NDArrayToLLVMArray(::tvm::runtime::NDArray arr) { + return codegen::NDArrayToLLVMArray(ctx_, arr); + } protected: /*! 
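For reference, both metadata serializers in this change (MetadataSerializerLLVM above and the C-source MetadataSerializer later in the patch) flatten each constant NDArray with NDArray::Save into a (data_len, data_bytes) pair carried by TVMConstantInfo. A consumer can rebuild the array with the matching dmlc stream; the snippet below is only an illustrative sketch, not code from this patch, and assumes TVMConstantInfo exposes the name_hint/data_len/data_bytes fields initialized elsewhere in this change.

// Sketch only: decode the (data_len, data_bytes) blob produced by NDArray::Save.
#include <dmlc/memory_io.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/metadata_types.h>  // assumed home of the TVMConstantInfo C struct
#include <tvm/runtime/ndarray.h>
#include <string>

tvm::runtime::NDArray DecodeConstantPoolEntry(const ::TVMConstantInfo& info) {
  // Copy into a mutable buffer because MemoryStringStream operates on a std::string*.
  std::string bytes(static_cast<const char*>(info.data_bytes), info.data_len);
  dmlc::MemoryStringStream stream(&bytes);
  tvm::runtime::NDArray arr;
  ICHECK(arr.Load(&stream)) << "malformed constant blob for " << info.name_hint;
  return arr;
}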
diff --git a/src/target/metadata.cc b/src/target/metadata.cc index adf4cba3e610..35df3ada0000 100644 --- a/src/target/metadata.cc +++ b/src/target/metadata.cc @@ -42,6 +42,12 @@ TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, return ::tvm::runtime::make_object(); }); +TVM_REGISTER_REFLECTION_VTABLE(VisitableConstantInfoMetadataNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + } // namespace metadata } // namespace target } // namespace tvm diff --git a/src/target/metadata.h b/src/target/metadata.h index 426e8616070a..7551592ac5ab 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -23,7 +23,7 @@ */ #ifndef TVM_TARGET_METADATA_H_ #define TVM_TARGET_METADATA_H_ - +#include #include #include @@ -74,17 +74,31 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { int64_t num_outputs_cpp = num_outputs(); v->Visit("num_outputs", &num_outputs_cpp); auto pools_array = Array(); - auto pools_accessor = pools(); - pools_array.reserve(num_pools()); - for (int64_t i = 0; i < num_pools(); ++i) { + auto pools_accessor = workspace_pools(); + pools_array.reserve(num_workspace_pools()); + for (int64_t i = 0; i < num_workspace_pools(); ++i) { pools_array.push_back(::tvm::runtime::metadata::TensorInfo{pools_accessor[i]}); } - ::tvm::runtime::metadata::MetadataArray pools_metadata_array{ + ::tvm::runtime::metadata::MetadataArray workspace_pools_metadata_array{ pools_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, ::tvm::runtime::metadata::TensorInfoNode::_type_key}; - v->Visit("pools", &pools_metadata_array); - int64_t num_pools_cpp = num_pools(); - v->Visit("num_pools", &num_pools_cpp); + v->Visit("workspace_pools", &workspace_pools_metadata_array); + int64_t num_workspace_pools_cpp = num_workspace_pools(); + v->Visit("num_workspace_pools", &num_workspace_pools_cpp); + + auto consts_array = Array(); + auto consts_accessor = constant_pools(); + consts_array.reserve(num_constant_pools()); + for (int64_t i = 0; i < num_constant_pools(); ++i) { + consts_array.push_back(::tvm::runtime::metadata::ConstantInfoMetadata{consts_accessor[i]}); + } + + int64_t num_const_pools_cpp = num_constant_pools(); + ::tvm::runtime::metadata::MetadataArray constant_pools_metadata_array{ + consts_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::ConstantInfoMetadataNode::_type_key}; + v->Visit("constant_pools", &constant_pools_metadata_array); + v->Visit("num_constant_pools", &num_const_pools_cpp); ::std::string mod_name_cpp{data()->mod_name}; v->Visit("mod_name", &mod_name_cpp); } @@ -100,22 +114,27 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { public: InMemoryMetadataNode() - : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, {} /* pools */, - "" /* mod_name */) {} + : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, + {} /* workspace_pools */, {} /* constant_pools */, "" /* mod_name */) { + } InMemoryMetadataNode(int64_t version, const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, - const ::std::vector<::tvm::runtime::metadata::TensorInfo>& pools, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& workspace_pools, + const ::std::vector<::tvm::ConstantInfo>& constant_pools, const 
::tvm::runtime::String mod_name) : VisitableMetadataNode{&storage_}, inputs_{new struct TVMTensorInfo[inputs.size()]}, inputs_objs_{inputs}, outputs_{new struct TVMTensorInfo[outputs.size()]}, outputs_objs_{outputs}, - pools_{new struct TVMTensorInfo[pools.size()]}, - pools_objs_{pools}, + workspace_pools_{new struct TVMTensorInfo[workspace_pools.size()]}, + workspace_pools_objs_{workspace_pools}, + constant_pools_{new struct TVMConstantInfo[constant_pools.size()]}, + constant_pools_objs_{constant_pools}, mod_name_{mod_name}, - storage_{version, nullptr, 0, nullptr, 0, nullptr, 0, mod_name_.c_str()} { + storage_{version, nullptr, 0ull, nullptr, 0ull, + nullptr, 0ull, nullptr, 0ull, mod_name_.c_str()} { storage_.inputs = inputs_.get(); storage_.num_inputs = inputs.size(); for (unsigned int i = 0; i < inputs.size(); ++i) { @@ -126,10 +145,33 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo for (unsigned int i = 0; i < outputs.size(); ++i) { outputs_.get()[i] = *outputs[i]->data(); } - storage_.pools = pools_.get(); - storage_.num_pools = pools.size(); - for (unsigned int i = 0; i < pools.size(); ++i) { - pools_.get()[i] = *pools[i]->data(); + storage_.workspace_pools = workspace_pools_.get(); + storage_.num_workspace_pools = workspace_pools.size(); + for (unsigned int i = 0; i < workspace_pools.size(); ++i) { + workspace_pools_.get()[i] = *workspace_pools[i]->data(); + } + storage_.constant_pools = constant_pools_.get(); + storage_.num_constant_pools = constant_pools.size(); + for (size_t i = 0; i < constant_pools.size(); ++i) { + constant_pools_.get()[i].name_hint = constant_pools[i]->name_hint.c_str(); + constant_pools_.get()[i].byte_offset = constant_pools[i]->byte_offset; + + std::string bytes; + dmlc::MemoryStringStream stream(&bytes); + auto data = constant_pools[i]->data; + data.Save(&stream); + // Allocated mem freed in destructor + constant_pools_.get()[i].data_len = bytes.size(); + char* a = reinterpret_cast(malloc(bytes.size())); + constant_pools_.get()[i].data_bytes = a; + memcpy(a, bytes.c_str(), bytes.size()); + } + } + + ~InMemoryMetadataNode() { + // frees allocated mem for const_objs_ + for (int i = 0; i < storage_.num_constant_pools; ++i) { + free(const_cast(constant_pools_.get()[i].data_bytes)); } } @@ -138,8 +180,10 @@ class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNo std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_; ::std::unique_ptr outputs_; std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; - ::std::unique_ptr pools_; - std::vector<::tvm::runtime::metadata::TensorInfo> pools_objs_; + ::std::unique_ptr workspace_pools_; + std::vector<::tvm::runtime::metadata::TensorInfo> workspace_pools_objs_; + ::std::unique_ptr constant_pools_; + std::vector<::tvm::ConstantInfo> constant_pools_objs_; ::std::string mod_name_; struct ::TVMMetadata storage_; }; @@ -190,6 +234,25 @@ class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorIn struct ::TVMTensorInfo storage_; }; +class VisitableConstantInfoMetadataNode + : public ::tvm::runtime::metadata::ConstantInfoMetadataNode { + public: + explicit VisitableConstantInfoMetadataNode(const struct ::TVMConstantInfo* data) + : ConstantInfoMetadataNode{data} {} + VisitableConstantInfoMetadataNode() : ConstantInfoMetadataNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + ::std::string name_cpp{name_hint()}; + v->Visit("name_hint", &name_cpp); + + uint64_t byte_offset_cpp{byte_offset()}; + v->Visit("byte_offset", 
&byte_offset_cpp); + + ::tvm::runtime::NDArray data_cpp = data(); + v->Visit("data", &data_cpp); + } +}; + } // namespace metadata } // namespace target } // namespace tvm diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 97299c63752d..e5ca82d5c099 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -114,15 +114,28 @@ static runtime::metadata::Metadata ConvertMetaData( std::vector pools; for (size_t i = 0; i < metadata->pools.size(); ++i) { auto var = metadata->pools[i]; - pools.push_back( - runtime::metadata::TensorInfo(make_object( - var->name_hint, - std::vector{metadata->pool_inputs.value()[var]->allocated_size}, - tvm::runtime::DataType{kDLUInt, 8, 1}))); + auto api = metadata->pool_inputs.value()[var]; + if (api->pool_info.as()) { + pools.push_back( + runtime::metadata::TensorInfo(make_object( + var->name_hint, std::vector{api->allocated_size}, + tvm::runtime::DataType{kDLUInt, 8, 1}))); + } } + std::vector consts; + for (const auto& kv : metadata->pool_inputs.value()) { + const auto& api = kv.second; + if (const auto* pi = api->pool_info.as()) { + if (pi->is_internal) { + for (const auto ci : pi->constant_info_array) { + consts.emplace_back(ci->name_hint, ci->byte_offset, ci->data); + } + } + } + } auto n = make_object( - runtime::metadata::kMetadataVersion, inputs, outputs, pools, metadata->mod_name); + runtime::metadata::kMetadataVersion, inputs, outputs, pools, consts, metadata->mod_name); return runtime::metadata::Metadata(std::move(n)); } diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h index 977a0f412bb5..f21de2986e33 100644 --- a/src/target/metadata_utils.h +++ b/src/target/metadata_utils.h @@ -107,7 +107,12 @@ class DiscoverComplexTypesVisitor : public AttrVisitor { * \param queue An ordered map which holds the */ explicit DiscoverComplexTypesVisitor(std::vector* queue) - : queue_{queue} {} + : queue_{queue} { + int i = 0; + for (auto q : *queue) { + type_key_to_position_[q->GetTypeKey()] = i++; + } + } void Visit(const char* key, double* value) final; void Visit(const char* key, int64_t* value) final; diff --git a/src/target/source/codegen_params.cc b/src/target/source/codegen_params.cc index 798ef73f0fa8..b052727e5d2e 100644 --- a/src/target/source/codegen_params.cc +++ b/src/target/source/codegen_params.cc @@ -238,7 +238,7 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& } default: - CHECK(false) << "Data type not supported"; + CHECK(false) << "Data type '" << arr_type << "' not supported"; } os.flags(old_fmtflags); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 2c4993419f58..41269cab64de 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,12 +23,16 @@ */ #include "source_module.h" +#include #include #include #include #include #include +#include +#include +#include #include #include #include @@ -41,7 +45,9 @@ #include "../func_registry_generator.h" #include "../metadata.h" #include "../metadata_utils.h" +#include "codegen_params.h" #include "codegen_source_base.h" +#include "tvm/relay/executor.h" namespace tvm { namespace codegen { @@ -249,15 +255,17 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { return reference_arg + "_tvm_value"; } - void GenerateInternalWorkspaceBuffers() { + void GenerateInternalBuffers() { if (metadata_->pool_inputs.defined()) { for (const auto& kv : metadata_->pool_inputs.value()) { tir::usmp::AllocatedPoolInfo 
allocated_pool_info = kv.second; if (allocated_pool_info->pool_info->is_internal) { - code_ << "__attribute__((section(\".bss.noinit.tvm\"), "; - code_ << "aligned(" << 16 << ")))\n"; - code_ << "static uint8_t " << allocated_pool_info->pool_info->pool_name << "[" - << allocated_pool_info->allocated_size->value << "];\n"; + if (const auto* pool_info = allocated_pool_info->pool_info.as()) { + GenerateConstantBuffer(pool_info, allocated_pool_info->allocated_size->value); + } else { + GenerateWorkspaceBuffer(allocated_pool_info->pool_info.as(), + allocated_pool_info->allocated_size->value); + } } } } @@ -283,6 +291,55 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n\n"; } + void GenerateConstantBuffer(const ConstantPoolInfoNode* pool_info, size_t allocated_size) { + size_t offset = 0; + if (pool_info->constant_info_array.size() > 0) { + // Pool is RO, form an initialized struct + code_ << "__attribute__((section(\".rodata.tvm\"), "; + code_ << "))\n"; + code_ << "static struct " << pool_info->pool_name << " {\n"; + // emit struct field names + std::vector const_info_vec(pool_info->constant_info_array.begin(), + pool_info->constant_info_array.end()); + std::sort(const_info_vec.begin(), const_info_vec.end(), + [](const ConstantInfo& a, const ConstantInfo& b) { + return a->byte_offset->value < b->byte_offset->value; + }); + for (const auto& const_info : const_info_vec) { + const auto& data = const_info->data; + const auto& offs = const_info->byte_offset; + int64_t num_elements = std::accumulate(data.Shape().begin(), data.Shape().end(), 1, + std::multiplies()); + code_ << " "; + codegen_c_base_.PrintType(data.DataType(), code_); + code_ << " " << const_info->name_hint << "[" << num_elements + << "] __attribute__((packed, aligned(" << metadata_->constant_alignment << ")));"; + code_ << " // " << num_elements * data.DataType().bytes() + << " bytes, aligned offset: " << offs << "\n"; + } + code_ << "} " << pool_info->pool_name << " = {\n"; + + // emit struct field initialization data + for (const auto& const_info : const_info_vec) { + code_ << " ." << const_info->name_hint << " = {\n"; + codegen::NDArrayDataToC(const_info->data, 4, code_); + code_ << " },\n"; + } + code_ << "};"; code_ << "// of total size " << allocated_size << " bytes, aligned: " << offset << " bytes\n"; + } else { + LOG(FATAL) << "No constant data in constant pool found " + << PrettyPrint(GetRef(pool_info)); + } + } + + void GenerateWorkspaceBuffer(const WorkspacePoolInfoNode* pool_info, size_t allocated_size) { + code_ << "__attribute__((section(\".bss.noinit.tvm\"), "; + code_ << "aligned(" << metadata_->workspace_alignment << ")))\n"; + code_ << "static uint8_t " << pool_info->pool_name << "["; + code_ << allocated_size << "];\n"; + } + bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) { if (metadata_->pool_inputs.defined()) { Map allocated_pool_infos = @@ -549,7 +606,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - GenerateInternalWorkspaceBuffers(); + GenerateInternalBuffers(); if (metadata_->unpacked_api) { if (metadata_->interface_api == "c") { @@ -646,9 +703,26 @@ class MetadataSerializer : public AttrVisitor { WriteKey(key); } + // Serializing NDArray as tuple of len, data void Visit(const char* key, runtime::NDArray* value) final { - // TODO(areusch): probably we could consolidate --link-params here, tho...
- ICHECK(false) << "do not support serializing NDArray as metadata"; + WriteComma(); + std::string bytes; + dmlc::MemoryStringStream stream(&bytes); + value->Save(&stream); + // Serializing length of the data of NDArray + code_ << stream.Tell(); + WriteComma(); + // Serializing NDArray as bytestream + code_ << "\""; + std::stringstream ss; + char buf[6] = {0}; + for (uint8_t c : bytes) { + snprintf(buf, sizeof(buf), "\\x%02x", c); + ss << buf; + } + std::string as_bytes(ss.str()); + code_ << as_bytes; + code_ << "\"\n"; } void VisitArray(runtime::metadata::MetadataArray array) { @@ -722,7 +796,11 @@ class MetadataSerializer : public AttrVisitor { if (key != nullptr) { // NOTE: outermost call passes nullptr key address_.push_back(key); } + WriteComma(); + code_ << "{\n"; + is_first_item_ = true; ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + code_ << "}\n"; if (key != nullptr) { // NOTE: outermost call passes nullptr key address_.pop_back(); } @@ -790,7 +868,7 @@ class MetadataSerializer : public AttrVisitor { // Finally, emit overall struct. address_.push_back(metadata::kMetadataGlobalSymbol); - code_ << "static const struct TVMMetadata " << metadata::AddressFromParts(address_) << " = {" + code_ << "static const struct TVMMetadata " << metadata::AddressFromParts(address_) << "[1] = {" << std::endl; Visit(nullptr, &metadata); code_ << "};" << std::endl; diff --git a/src/tir/analysis/calculate_workspace.cc b/src/tir/analysis/calculate_workspace.cc index 49ddaf613c6d..11593bb443a7 100644 --- a/src/tir/analysis/calculate_workspace.cc +++ b/src/tir/analysis/calculate_workspace.cc @@ -26,10 +26,12 @@ #include #include #include +#include namespace tvm { namespace tir { +template class WorkspaceCalculator : public StmtExprVisitor { public: WorkspaceCalculator() = default; @@ -37,38 +39,29 @@ class WorkspaceCalculator : public StmtExprVisitor { size_t byte_alignment = tvm::runtime::kDefaultWorkspaceAlignment; private: - void VisitStmt_(const AllocateNode* op) override; - size_t CalculateExtentsSize(const AllocateNode* op); - size_t GetByteAlignedSize(size_t non_aligned_size); + void VisitStmt_(const T* op) override; + size_t GetByteAlignedSize(Integer non_aligned_size); + size_t CalculateExtentsSize(const DataType& dtype, const Array& extents); size_t current_size = 0; size_t max_size = 0; }; -size_t WorkspaceCalculator::operator()(const PrimFunc& func) { +template +size_t WorkspaceCalculator::operator()(const PrimFunc& func) { this->VisitStmt(func->body); return this->max_size; } -size_t WorkspaceCalculator::GetByteAlignedSize(size_t non_aligned_size) { - return ((non_aligned_size + byte_alignment - 1) / byte_alignment) * byte_alignment; +template +size_t WorkspaceCalculator::GetByteAlignedSize(Integer non_aligned_size) { + return non_aligned_size.defined() + ? 
((non_aligned_size + byte_alignment - 1) / byte_alignment) * byte_alignment + : 0; } -size_t WorkspaceCalculator::CalculateExtentsSize(const AllocateNode* op) { - size_t element_size_bytes = op->dtype.bytes(); - size_t num_elements = 1; - for (const auto& ext : op->extents) { - if (ext->IsInstance()) { - num_elements *= Downcast(ext)->value; - } else { - // We cant statically calculate workspace for dynamic shapes - num_elements = 0; - } - } - return GetByteAlignedSize(num_elements * element_size_bytes); -} - -void WorkspaceCalculator::VisitStmt_(const AllocateNode* op) { - auto size = CalculateExtentsSize(op); +template +void WorkspaceCalculator::VisitStmt_(const T* op) { + auto size = GetByteAlignedSize(usmp::CalculateExtentsSize(op)); current_size += size; if (current_size > max_size) { max_size = current_size; @@ -77,12 +70,23 @@ void WorkspaceCalculator::VisitStmt_(const AllocateNode* op) { current_size -= size; } -size_t CalculateWorkspaceBytes(const PrimFunc& func, const Integer& workspace_byte_alignment) { - WorkspaceCalculator wc; - wc.byte_alignment = workspace_byte_alignment->value; +size_t CalculateConstantBytes(const PrimFunc& func, const Integer& byte_alignment) { + WorkspaceCalculator wc; + wc.byte_alignment = byte_alignment->value; + return wc(func); +} + +size_t CalculateWorkspaceBytes(const PrimFunc& func, const Integer& byte_alignment) { + WorkspaceCalculator wc; + wc.byte_alignment = byte_alignment->value; return wc(func); } +TVM_REGISTER_GLOBAL("tir.analysis.calculate_constant_bytes") + .set_body_typed([](PrimFunc func, Integer constant_byte_alignment) { + return static_cast(CalculateConstantBytes(func, constant_byte_alignment)); + }); + TVM_REGISTER_GLOBAL("tir.analysis.calculate_workspace_bytes") .set_body_typed([](PrimFunc func, Integer workspace_byte_alignment) { return static_cast(CalculateWorkspaceBytes(func, workspace_byte_alignment)); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 43c2d3745964..2b337520a249 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -437,7 +437,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // depending on the type of ObjectRef, it will either // create AllocateConstNode with irmod_storage_idx or data AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Span span) { + ObjectRef data_or_idx, Stmt body, Map annotations, + Span span) { ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" @@ -456,6 +457,7 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext node->dtype = dtype; node->extents = std::move(extents); node->body = std::move(body); + node->annotations = annotations; node->span = std::move(span); if (data_or_idx->IsInstance()) { node->data = Optional(Downcast(data_or_idx)); @@ -485,8 +487,9 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents } TVM_REGISTER_GLOBAL("tir.AllocateConst") .set_body_typed([](Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Span span) { - return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, span); + ObjectRef data_or_idx, Stmt body, Map annotations, + Span span) { + return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations, span); }); TVM_REGISTER_NODE_TYPE(AllocateConstNode); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 
f97f91a1e501..67100ebd334c 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1433,7 +1433,7 @@ class StorageFlattener : public StmtExprMutator { << op->buffer_var->name_hint; } return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx, - stmt->body, stmt->span); + stmt->body, stmt->annotations, stmt->span); } Stmt VisitStmt_(const LetStmtNode* op) final { diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index 8246ffc219c6..cae01ee85969 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ -61,11 +61,20 @@ size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_off */ bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, const size_t& size_bytes) { - if (candidate_pool->size_hint_bytes == kUnrestrictedPoolSizeHint) { + Integer size_hint_bytes = -1; + if (const auto* p = candidate_pool.as()) { + size_hint_bytes = p->size_hint_bytes; + } else if (const auto* p = candidate_pool.as()) { + size_hint_bytes = p->size_hint_bytes; + } else { + LOG(FATAL) << "Pool '" << candidate_pool->GetTypeKey() << "' is not supported"; + } + + if (size_hint_bytes == kUnrestrictedPoolSizeHint) { // this means pool is not bounded return true; } - auto pool_size = static_cast(candidate_pool->size_hint_bytes->value); + auto pool_size = static_cast(size_hint_bytes); auto max_address = next_offset + size_bytes; if (max_address <= pool_size) { return true; diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index b90cfddb7153..4e98116f8a17 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -73,6 +73,7 @@ class BufferInfoExtractor : public StmtExprVisitor { private: void VisitStmt(const Stmt& n) override; void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; @@ -81,6 +82,7 @@ class BufferInfoExtractor : public StmtExprVisitor { void UpdateAliases(const Array& args, const PrimFunc& func); void RecordAllocateNodeInfo(const AllocateNode* op); + void RecordAllocateConstNodeInfo(const AllocateConstNode* op); void VisitPrimFunc(const PrimFunc& func, const Call& call); /*! @@ -148,6 +150,12 @@ class BufferInfoExtractor : public StmtExprVisitor { * loops structure. */ std::unordered_set allocate_nodes; + /* + * \brief We record the live allocate_const_nodes because once in loops + * the liveness range has to be extended to the whole of the nested + * loops structure. + */ + std::unordered_set allocate_const_nodes; /*! * \brief This is recorded to extend the liveness of all allocates within * nested loop structure. 
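The hunk below sizes an AllocateConstNode the same way workspace allocates are sized: the raw footprint is prod(extents) * dtype.bytes() (CalculateExtentsSize), and the rounding to the executor's constant-byte-alignment is applied where the buffer is placed (GetByteAlignedSize and round_up_to_byte_alignment use the same formula). A standalone sketch of that arithmetic, with hypothetical names not taken from the patch:

// Illustration only: raw constant size plus the round-up used by the USMP passes.
#include <cstdint>
#include <vector>

int64_t AlignedConstantBytes(const std::vector<int64_t>& extents, int64_t dtype_bytes,
                             int64_t byte_alignment) {
  int64_t num_elements = 1;
  for (int64_t ext : extents) num_elements *= ext;  // dynamic extents are reported as 0 upstream
  int64_t raw_bytes = num_elements * dtype_bytes;
  return ((raw_bytes + byte_alignment - 1) / byte_alignment) * byte_alignment;
}
// e.g. a {3, 8, 8} int8 constant is 192 bytes and stays 192 with 16-byte alignment,
// while a {5, 5} int8 constant (25 bytes) rounds up to 32.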
@@ -292,9 +300,57 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { current_scope_info.allocate_nodes.erase(GetRef(op)); } +void BufferInfoExtractor::VisitStmt_(const AllocateConstNode* op) { + ScopeInfo& current_scope_info = scope_stack_.top(); + RecordAllocateConstNodeInfo(op); + StmtExprVisitor::VisitStmt(op->body); + current_scope_info.allocate_const_nodes.erase(GetRef(op)); +} + +void BufferInfoExtractor::RecordAllocateConstNodeInfo(const AllocateConstNode* op) { + if (!op->annotations.count(kPoolCandidatesAllocateAttr)) { + return; + } + Integer size_bytes = CalculateExtentsSize(op); + ICHECK(size_bytes.defined()) << "constant node size should be defined"; + const auto& buffer_var = op->buffer_var; + if (allocate_infos.find(buffer_var) == allocate_infos.end()) { + // By default, the core compiler is assumed to attach a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs a pool candidate attribute"; + auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. If there were no " "user-given arguments for memory pools, the default behaviour is a single size " "un-restricted pool is assigned"; + PrimFunc func = scope_stack_.top().func; + Optional executor_config = + module_->GetAttr(tvm::attr::kExecutor); + Integer alignment = 16; + if (executor_config) { + alignment = + executor_config.value()->GetAttr("constant-byte-alignment").value_or(alignment); + } + auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_var->name_hint), size_bytes, + pool_candidates, alignment); + auto allocate = GetRef(op); + allocate_infos[buffer_var] = + AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; + buffer_info_map_.Set(buffer_info, allocate); + } else { + // Update the allocate info with the latest call + AllocateInfo ai = allocate_infos[buffer_var]; + ai.call = scope_stack_.top().call; + allocate_infos[buffer_var] = ai; + } +} + void BufferInfoExtractor::VisitStmt_(const ForNode* op) { - ScopeInfo si{scope_stack_.top().call, scope_stack_.top().func, GetRef(op), + ScopeInfo si{scope_stack_.top().call, + scope_stack_.top().func, + GetRef(op), scope_stack_.top().allocate_nodes, + scope_stack_.top().allocate_const_nodes, scope_stack_.top().initial_stmt_of_the_nested_loops}; if (!scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) { si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_); @@ -355,7 +411,13 @@ void BufferInfoExtractor::VisitExpr_(const VarNode* op) { ScopeInfo& currect_scope_info = scope_stack_.top(); if (currect_scope_info.for_loop.defined()) { - currect_scope_info.allocate_nodes.insert(Downcast(allocate)); + if (allocate->IsInstance()) { + currect_scope_info.allocate_nodes.insert(Downcast(allocate)); + } else if (allocate->IsInstance()) { + currect_scope_info.allocate_const_nodes.insert(Downcast(allocate)); + } else { + LOG(FATAL) << "Handling of " << allocate->GetTypeKey() << " is not implemented"; + } } } StmtExprVisitor::VisitExpr_(op); @@ -401,7 +463,11 @@ void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimF } void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call) { - ScopeInfo si{call, func, scope_stack_.top().for_loop, scope_stack_.top().allocate_nodes, + ScopeInfo si{call, + func, + scope_stack_.top().for_loop, + scope_stack_.top().allocate_nodes, +
scope_stack_.top().allocate_const_nodes, scope_stack_.top().initial_stmt_of_the_nested_loops}; call_order_.insert(call); scope_stack_.push(si); @@ -436,10 +502,11 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) { // associated with each BufferNodes. std::vector le_events_timeline; for (const auto& kv1 : buffer_info_map_) { - if (!kv1.second->IsInstance()) { + if (!kv1.second->IsInstance() && !kv1.second->IsInstance()) { continue; } - auto allocate = Downcast(kv1.second); + + auto allocate = Downcast(kv1.second); auto buffer_info = Downcast(kv1.first); ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size()); @@ -505,6 +572,40 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) { open_set.erase(le_event.buffer_info); } } + + // All ConstantPoolInfo items should have conflicts with each other + // as they will be placed in RO segment and pre-initialized. To achieve this + // first, split buffers to vars (WorkspacePoolInfo items) and constants (ConstantPoolInfo items): + Array buffer_info_vars; + Array buffer_info_constants; + for (const auto& kv : this->buffer_info_map_) { + const auto& stmt = kv.second; + if (stmt->IsInstance()) { + buffer_info_constants.push_back(kv.first); + } else { + buffer_info_vars.push_back(kv.first); + } + } + ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_constants.size()) + << "missing value"; + + Map srch; + // Then intersect constants with each other, as all constants should exist at the same time: + for (const auto& buf : buffer_info_constants) { + srch.Set(buf, buf); + Array conflicts; + std::copy_if(buffer_info_constants.begin(), buffer_info_constants.end(), + std::back_inserter(conflicts), [buf](const auto& b) { return b != buf; }); + buf->conflicts.Assign(conflicts.begin(), conflicts.end()); + } + + // And third, remove all conflicts between constants and vars: + for (const auto& buf : buffer_info_vars) { + Array conflicts; + std::copy_if(buf->conflicts.begin(), buf->conflicts.end(), std::back_inserter(conflicts), + [&srch](const auto& c) { return srch.end() == srch.find(c); }); + buf->conflicts.Assign(conflicts.begin(), conflicts.end()); + } return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size); } diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index e291eaa0519e..52d2f0ef541e 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -48,19 +48,32 @@ class PoolInfoAssigner : public StmtExprMutator { ICHECK(target_host) << "main function does not have a target attr"; WorkspaceMemoryPools workspace_pools = module->GetAttr(tvm::attr::kWorkspaceMemoryPools) - .value_or(WorkspaceMemoryPools({CreateDefaultMemoryPool(module)})); - Array pool_infos = workspace_pools->pools; - for (const PoolInfo& pool_info : pool_infos) { - for (const auto& kv : pool_info->target_access) { - Target target = kv.first; - String target_str = target->str(); - if (target_pool_infos_.find(target_str) == target_pool_infos_.end()) { - target_pool_infos_.Set(target_str, Array()); + .value_or(WorkspaceMemoryPools({CreateDefaultWorkspaceMemoryPool(module)})); + // make default ConstantPoolInfo if no constant and no workspace pool infos supplied + ConstantMemoryPools constant_pools = + module->GetAttr(tvm::attr::kConstantMemoryPools) + .value_or( + module->GetAttr(tvm::attr::kWorkspaceMemoryPools).defined() + ? 
ConstantMemoryPools() + : ConstantMemoryPools({CreateDefaultConstantMemoryPool(module)})); + auto to_map = [](auto pool_infos) { + Map> pool_map; + for (const PoolInfo& pool_info : pool_infos) { + for (const auto& tgt : pool_info->targets) { + if (pool_map.find(tgt->str()) == pool_map.end()) { + pool_map.Set(tgt->str(), Array()); + } + Array pool_info_arr = pool_map[tgt->str()]; + pool_info_arr.push_back(pool_info); + pool_map.Set(tgt->str(), pool_info_arr); } - Array pool_info_arr = target_pool_infos_[target_str]; - pool_info_arr.push_back(pool_info); - target_pool_infos_.Set(target_str, pool_info_arr); } + return pool_map; + }; + + target_pool_infos_ = to_map(workspace_pools->pools); + if (constant_pools.defined()) { + target_const_pool_infos_ = to_map(constant_pools->pools); } mod_ = module->ShallowCopy(); } @@ -69,14 +82,23 @@ class PoolInfoAssigner : public StmtExprMutator { private: Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const AllocateConstNode* op) override; IRModule mod_; Map> target_pool_infos_; + Map> target_const_pool_infos_; PrimFunc func_; - PoolInfo CreateDefaultMemoryPool(const IRModule& module); + WorkspacePoolInfo CreateDefaultWorkspaceMemoryPool(const IRModule& module); + ConstantPoolInfo CreateDefaultConstantMemoryPool(const IRModule& module) { + auto p = CreateDefaultWorkspaceMemoryPool(module); + return ConstantPoolInfo( + "global_const_workspace", {p->targets}, {}, + PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, + kUnknownWriteBandwidth, 0, 0, {p->target_burst_bytes}, Bool(true))); + } }; -PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) { +WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) { VLOG(1) << "Creating default memory pool for:" << std::endl << PrettyPrint(module); Map target_access; tir::PrimFunc tir_main_func = @@ -87,9 +109,23 @@ PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) Optional target = func->GetAttr(tvm::attr::kTarget); target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess); } - return PoolInfo("global_workspace", target_access, kUnrestrictedPoolSizeHint, - kUnknownClockFrequency, kUnknownReadBandwidth, kUnknownWriteBandwidth, 0, 0, {}, - Bool(true)); + Array targets; + for (const auto& kv : target_access) { + bool exist = false; + // Exclude targets with the same string representation + for (const auto& t : targets) { + if (t->str() == kv.first->str()) { + exist = true; + } + } + if (!exist) { + targets.push_back(kv.first); + } + } + return WorkspacePoolInfo( + "global_workspace", targets, + PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, + kUnknownWriteBandwidth, 0, 0, {{target_host, 1}}, Bool(true))); } Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { @@ -97,6 +133,8 @@ Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; Map annotations = Map(op->annotations); if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { + ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0) + << "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_); annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]); } Stmt body = VisitStmt(op->body); @@ -105,6 +143,23 @@ Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { return 
allocate; } +Stmt PoolInfoAssigner::VisitStmt_(const AllocateConstNode* op) { + if (!target_const_pool_infos_.size()) { + return StmtExprMutator::VisitStmt_(op); + } + Optional tgt = func_->GetAttr(tvm::attr::kTarget).value(); + ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; + Map annotations = Map(op->annotations); + if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { + annotations.Set(kPoolCandidatesAllocateAttr, target_const_pool_infos_[tgt.value()->str()]); + annotations.Set(kTargetPoolReadOnlyAccess, Integer(1)); + } + Stmt body = VisitStmt(op->body); + auto allocate_const = + AllocateConst(op->buffer_var, op->dtype, op->extents, op->data, body, annotations); + return allocate_const; +} + IRModule PoolInfoAssigner::operator()() { for (const auto& kv : mod_->functions) { GlobalVar gv = kv.first; diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 1161962f1287..24a55190d326 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -53,14 +54,20 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) { module_ = module->ShallowCopy(); for (const auto& kv : pool_allocations) { - // TODO(@manupa-arm): add AllocateConstNode when it is available - ICHECK(kv.first->IsInstance()); - Allocate allocate_node = Downcast(kv.first); + size_t extent_size = -1; + if (kv.first->IsInstance()) { + Allocate allocate_node = Downcast(kv.first); + extent_size = CalculateExtentsSize(allocate_node.operator->()); + } else if (kv.first->IsInstance()) { + AllocateConst allocate_const_node = Downcast(kv.first); + extent_size = CalculateExtentsSize(allocate_const_node.operator->()); + } else { + ICHECK(false) << "Not supported node type " << kv.first->GetTypeKey(); + } PoolAllocation pool_allocation = kv.second; PoolInfo pool_info = pool_allocation->pool_info; int byte_pool_offset = pool_allocation->byte_offset->value; - int required_pool_size_for_allocation = - byte_pool_offset + static_cast(CalculateExtentsSize(allocate_node.operator->())); + int required_pool_size_for_allocation = byte_pool_offset + extent_size; if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { all_pools_sizes_[pool_info] = required_pool_size_for_allocation; } else { @@ -92,6 +99,8 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const AllocateConstNode* op) override; + LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body); /*! \brief This is a structure where the modified function * signature is kept while body of the function is mutated */ @@ -121,7 +130,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief This is a helper to append the pool args to * the callsite of the function. */ - Array AppendPoolParamsToArgs(Array args, const PrimFunc& func); + Array AppendPoolParamsToArgs(Array args, bool has_device_context); /*! \brief Some arguments that used to be Allocate nodes * should be replaced by Let nodes in the pass that loads * the space from a pool variable. 
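When the module carries no WorkspaceMemoryPools attribute, AssignPoolInfo above falls back to the default global_workspace/global_const_workspace pools; otherwise the user-supplied pools are used, so a user-facing flow is expected to attach its own pools before running USMP. The sketch below is illustrative only: pool names, the size hint and the use of WithAttr are assumptions, while the constructor argument order mirrors the calls in assign_pool_info.cc.

// Sketch under assumptions: attach one bounded workspace pool and one constant pool.
#include <tvm/ir/attrs.h>
#include <tvm/ir/memory_pools.h>
#include <tvm/ir/module.h>
#include <tvm/target/target.h>

tvm::IRModule AttachMemoryPools(tvm::IRModule mod, tvm::Target target) {
  using namespace tvm;
  // Argument order follows PoolInfoProperties(size_hint, clock, read_bw, write_bw,
  // read_latency, write_latency, target_burst_bytes, is_internal) as used above.
  PoolInfoProperties bounded(Integer(256 * 1024), kUnknownClockFrequency, kUnknownReadBandwidth,
                             kUnknownWriteBandwidth, 0, 0, {}, Bool(false));
  PoolInfoProperties unbounded(kUnrestrictedPoolSizeHint, kUnknownClockFrequency,
                               kUnknownReadBandwidth, kUnknownWriteBandwidth, 0, 0, {},
                               Bool(false));
  WorkspacePoolInfo sram("sram", {target}, bounded);
  ConstantPoolInfo flash("flash", {target}, /*constant_info_array=*/{}, unbounded);
  mod = WithAttr(mod, tvm::attr::kWorkspaceMemoryPools, WorkspaceMemoryPools({sram}));
  mod = WithAttr(mod, tvm::attr::kConstantMemoryPools, ConstantMemoryPools({flash}));
  return mod;
}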
@@ -159,12 +168,17 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { * be tracked separately. */ Map original_buf_to_let_buf_; + + Map signature_has_device_context_; /*! \brief A counter to give references to pools a reproducible unique set of names */ int pool_var_count_ = 0; /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */ bool emit_tvmscript_printable_ = false; /*! \brief A counter to give references to pools a reproducible unique set of names */ std::unordered_set visited_primfuncs; + + Map> pool_initializations_; + void AppdendConstInitializationData(ScopeInfo si); }; Optional PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) { @@ -239,10 +253,11 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( } Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(Array args, - const PrimFunc& func) { + bool has_device_context) { Array new_args; PrimExpr resource_handle_arg; - if (args.size() == func->params.size() + 1) { + // name, params...params[, context] + if (has_device_context) { resource_handle_arg = args.back(); args.pop_back(); } @@ -283,9 +298,18 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { module_->Lookup(func_name)->IsInstance()) { GlobalVar gv = module_->GetGlobalVar(func_name); PrimFunc func = Downcast(module_->Lookup(gv)); + + if (!signature_has_device_context_.count(func_name)) { + if (op->args.size() == func->params.size() + 2) { + signature_has_device_context_.Set(func_name, Bool(true)); + } else { + signature_has_device_context_.Set(func_name, Bool(false)); + } + } + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); module_->Update(gv, prim_func); - new_args = AppendPoolParamsToArgs(op->args, prim_func); + new_args = AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]); new_args = ReplaceAllocateArgsWithLetArgs(new_args); } else { new_args = ReplaceAllocateArgsWithLetArgs(op->args); @@ -293,36 +317,60 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { return Call(op->dtype, op->op, new_args); } if (op->op->IsInstance()) { + String func_name = Downcast(op->args[0])->value; PrimFunc func = Downcast(op->op); PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); - Array new_args = AppendPoolParamsToArgs(op->args, prim_func); + Array new_args = + AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]); new_args = ReplaceAllocateArgsWithLetArgs(new_args); return Call(op->dtype, prim_func, new_args); } return StmtExprMutator::VisitExpr_(op); } +LetStmt PoolAllocationToOffsetConverter::ToLetStmt(const PoolAllocation& pool_allocation, + const Var& buffer_var, const Stmt& body) { + ScopeInfo scope_info = scope_stack.top(); + Var param = scope_info.pools_to_params[pool_allocation->pool_info]; + BufferLoad load_node = BufferLoad(scope_info.buffer_map[param], {pool_allocation->byte_offset}); + Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node}); + + Type let_var_type = buffer_var->type_annotation; + if (emit_tvmscript_printable_) { + // Strip the storage_scope from the variable type, as TVMScript + // doesn't parse the scoped pointers (e.g. ``T.Ptr[global T.int32]``) + // correctly.
+ let_var_type = PointerType(Downcast(let_var_type)->element_type); + } + Var let_var(buffer_var->name_hint + "_let", let_var_type); + allocate_var_to_let_var_.Set(buffer_var, let_var); + Stmt new_body = VisitStmt(body); + allocate_var_to_let_var_.erase(buffer_var); + return LetStmt(let_var, address_of_load, new_body); +} + Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { if (pool_allocations_.count(GetRef(op))) { - ScopeInfo scope_info = scope_stack.top(); - PoolAllocation pool_allocation = pool_allocations_[GetRef(op)]; - Var param = scope_info.pools_to_params[pool_allocation->pool_info]; - Buffer buffer_var = scope_info.buffer_map[param]; - BufferLoad load_node = BufferLoad(buffer_var, {pool_allocation->byte_offset}); - Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node}); - - Type let_var_type = op->buffer_var->type_annotation; - if (emit_tvmscript_printable_) { - // Strip the storage_scope from the variable type, as TVMScript - // doesn't parsethe scoped pointers (e.g. ``T.Ptr[global T.int32]``) - // correctly. - let_var_type = PointerType(Downcast(let_var_type)->element_type); + return ToLetStmt(pool_allocations_[GetRef(op)], op->buffer_var, op->body); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateConstNode* op) { + if (pool_allocations_.count(GetRef(op))) { + const auto& result = ToLetStmt(pool_allocations_[GetRef(op)], op->buffer_var, op->body); + + PoolInfo pool_info = pool_allocations_[GetRef(op)]->pool_info; + if (pool_initializations_.find(pool_info) == pool_initializations_.end()) { + pool_initializations_.Set(pool_info, {}); } - Var let_var(op->buffer_var->name_hint + "_let", let_var_type); - allocate_var_to_let_var_.Set(op->buffer_var, let_var); - Stmt new_body = VisitStmt(op->body); - allocate_var_to_let_var_.erase(op->buffer_var); - return LetStmt(let_var, address_of_load, new_body); + + auto consts = pool_initializations_[pool_info]; + consts.push_back({result->var->name_hint, pool_allocations_[GetRef(op)]->byte_offset, + op->data.value()}); + + pool_initializations_.Set(pool_info, consts); + return result; } return StmtExprMutator::VisitStmt_(op); } @@ -369,6 +417,17 @@ Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) { return remapped; } +void PoolAllocationToOffsetConverter::AppdendConstInitializationData( + PoolAllocationToOffsetConverter::ScopeInfo si) { + for (AllocatedPoolInfo api : si.allocated_pool_params) { + const auto& it = pool_initializations_.find(api->pool_info); + if (it != pool_initializations_.end()) { + auto* pi = const_cast(api->pool_info.as()); + pi->constant_info_array = (*it).second; + } + } +} + IRModule PoolAllocationToOffsetConverter::operator()() { GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); PrimFunc main_func = Downcast(module_->Lookup(gv)); @@ -376,6 +435,7 @@ IRModule PoolAllocationToOffsetConverter::operator()() { this->scope_stack.push(si); Stmt main_func_body = this->VisitStmt(main_func->body); this->scope_stack.pop(); + AppdendConstInitializationData(si); // We dont need attrs of PrimFunc that might include non printable attrs such as target // for unit tests where emit_tvmscript_printable_ is to be used. 
if (!emit_tvmscript_printable_) { diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc index ae915473906b..d7eb0f3a7e64 100644 --- a/src/tir/usmp/unified_static_memory_planner.cc +++ b/src/tir/usmp/unified_static_memory_planner.cc @@ -33,6 +33,7 @@ #include #include +#include #include namespace tvm { @@ -62,13 +63,15 @@ IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) { PrimFunc main_func = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module); Array buffer_info_arr = - CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts); + ConvertToArrayOfBufferInfo(buffer_info_analysis->buffer_info_stmts); CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo << " is not defined. Please define it in the above algorithms map."; Map buffer_info_pool_allocations = algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure); + Map stmt_pool_allocations = AssignStmtPoolAllocations( buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations); + module = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(module); if (use_workspace_io) { Map io_pool_allocations = diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index d02f0d8d33b3..6f95c7cbaf66 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -74,7 +74,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "name_hint=" << node->name_hint << ",\n size_bytes=" << node->size_bytes << ",\n pool_candidates=" << node->pool_candidates << ",\n alignment=" << node->alignment << ",\n kind=" << toString[node->kind] - << ")"; + << ",\n conflicts=" << node->conflicts.size() << ")"; }); BufferInfoAnalysis::BufferInfoAnalysis(Map buffer_info_stmts, @@ -145,7 +145,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -Array CreateArrayBufferInfo(const Map& buffer_info_map) { +Array ConvertToArrayOfBufferInfo(const Map& buffer_info_map) { Array ret; for (const auto& kv : buffer_info_map) { auto buffer_info = kv.first; @@ -180,10 +180,10 @@ Map GetIOPoolAllocations( return io_tensor_name_to_pool_allocation; } -Integer CalculateExtentsSize(const AllocateNode* op) { - size_t element_size_bytes = op->dtype.bytes(); +static Integer CalculateExtentsSize(const DataType& dtype, const Array& extents) { + size_t element_size_bytes = dtype.bytes(); size_t num_elements = 1; - for (const auto& ext : op->extents) { + for (const auto& ext : extents) { if (ext->IsInstance()) { num_elements *= Downcast(ext)->value; } else { @@ -194,11 +194,21 @@ Integer CalculateExtentsSize(const AllocateNode* op) { return Integer(num_elements * element_size_bytes); } +Integer CalculateExtentsSize(const AllocateNode* op) { + return CalculateExtentsSize(op->dtype, op->extents); +} + +Integer CalculateExtentsSize(const AllocateConstNode* op) { + return CalculateExtentsSize(op->dtype, op->extents); +} + class ModuleWorkspaceSizeCalculator : public StmtExprVisitor { public: explicit ModuleWorkspaceSizeCalculator(const IRModule& module) : mod_(module) { for (const auto& gv_func : mod_->functions) { - functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + if ((gv_func.second)->IsInstance()) { + functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + } } main_func_ = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); ICHECK(main_func_.defined()) << "main function is not in the module"; @@ -256,7 +266,7 @@ 
Integer CalculateModuleWorkspaceSize(const IRModule& mod) { TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") .set_body_typed([](Map buffer_info_map) { - return (CreateArrayBufferInfo(buffer_info_map)); + return (ConvertToArrayOfBufferInfo(buffer_info_map)); }); TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations").set_body_typed(AssignStmtPoolAllocations); diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index f8ce614b24cf..5c4ba9a528ca 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -38,10 +39,21 @@ const struct TVMTensorInfo kNormalOutputs[1] = { {"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}}; const int64_t kNormalPool1Shape[3] = {3, 8, 8}; -const struct TVMTensorInfo kNormalPools[1] = {{"pool1", kNormalPool1Shape, 3, DLDataType{3, 4, 7}}}; +const struct TVMTensorInfo kNormalWorkspacePools[1] = { + {"workspace_pool1", kNormalPool1Shape, 3, DLDataType{3, 4, 7}}}; +const struct TVMConstantInfo kNormalConstantPools[1] = {{"constant_pool1", 0, 0, {}}}; const struct TVMMetadata kNormal = { - TVM_METADATA_VERSION, kNormalInputs, 2, kNormalOutputs, 1, kNormalPools, 1, "default", + TVM_METADATA_VERSION, + kNormalInputs, + 2, + kNormalOutputs, + 1, + kNormalWorkspacePools, + 1, + kNormalConstantPools, + 1, + "default", }; } // namespace @@ -61,6 +73,7 @@ using ::tvm::runtime::Array; using ::tvm::runtime::Downcast; using ::tvm::runtime::ObjectRef; +using ::tvm::runtime::metadata::ConstantInfoMetadata; using ::tvm::runtime::metadata::Metadata; using ::tvm::runtime::metadata::MetadataArray; using ::tvm::runtime::metadata::MetadataKind; @@ -93,13 +106,13 @@ TEST(Metadata, ParseStruct) { EXPECT_THAT(output1->shape(), ElementsAre(3, 8, 8)); EXPECT_THAT(output1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5}))); - auto pools = md->pools(); + auto pools = md->workspace_pools(); EXPECT_THAT(pools.size(), Eq(1)); - auto pool1 = pools[0]; - EXPECT_THAT(pool1->name(), Eq("pool1")); - EXPECT_THAT(pool1->shape(), ElementsAre(3, 8, 8)); - EXPECT_THAT(pool1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 7}))); + auto workspace_pool1 = pools[0]; + EXPECT_THAT(workspace_pool1->name(), Eq("workspace_pool1")); + EXPECT_THAT(workspace_pool1->shape(), ElementsAre(3, 8, 8)); + EXPECT_THAT(workspace_pool1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 7}))); EXPECT_THAT(md->mod_name(), Eq("default")); } @@ -158,8 +171,9 @@ TEST(Metadata, Visitor) { ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); EXPECT_THAT(v.keys, ElementsAre(StrEq("version"), StrEq("inputs"), StrEq("num_inputs"), - StrEq("outputs"), StrEq("num_outputs"), StrEq("pools"), - StrEq("num_pools"), StrEq("mod_name"))); + StrEq("outputs"), StrEq("num_outputs"), StrEq("workspace_pools"), + StrEq("num_workspace_pools"), StrEq("constant_pools"), + StrEq("num_constant_pools"), StrEq("mod_name"))); EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); @@ -196,14 +210,24 @@ TEST(Metadata, Visitor) { auto pool_array = Downcast(v.values[5]); EXPECT_THAT(pool_array->kind, Eq(MetadataKind::kMetadata)); EXPECT_THAT(pool_array->type_key, StrEq("metadata.TensorInfoNode")); - auto pool1 = Downcast(pool_array->array[0]); + auto workspace_pool1 = Downcast(pool_array->array[0]); - EXPECT_THAT(pool1->name(), Eq("pool1")); + EXPECT_THAT(workspace_pool1->name(), Eq("workspace_pool1")); - auto num_pools 
= Downcast(v.values[6]); - EXPECT_THAT(num_pools->value, Eq(1)); + auto num_workspace_pools = Downcast(v.values[6]); + EXPECT_THAT(num_workspace_pools->value, Eq(1)); - auto mod_name = Downcast(v.values[7]); + auto consts_array = Downcast(v.values[7]); + EXPECT_THAT(consts_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(consts_array->type_key, StrEq("metadata.ConstantInfoNode")); + auto consts1 = Downcast(consts_array->array[0]); + + EXPECT_THAT(consts1->name_hint(), Eq("constant_pool1")); + + auto num_consts = Downcast(v.values[8]); + EXPECT_THAT(num_consts->value, Eq(1)); + + auto mod_name = Downcast(v.values[9]); EXPECT_THAT(mod_name, Eq("default")); } @@ -224,8 +248,11 @@ TEST(Metadata, InMemory) { tvm::runtime::DataType(DLDataType{3, 4, 5})))}), std::vector( {TensorInfo(make_object( - tvm::String("Pool1"), std::vector{5, 10, 10}, + tvm::String("Workspace_Pool1"), std::vector{5, 10, 10}, tvm::runtime::DataType(DLDataType{3, 4, 7})))}), + std::vector({tvm::ConstantInfo( + "Constant_Pool1", 64, + tvm::runtime::NDArray::Empty({64}, tvm::runtime::DataType::Int(64), {kDLCPU}))}), "default")); auto md_data = md->data(); @@ -253,13 +280,17 @@ TEST(Metadata, InMemory) { EXPECT_THAT(tvm::runtime::DataType(output0->dtype), Eq(tvm::runtime::DataType(DLDataType({3, 4, 5})))); - auto pool0 = &md_data->pools[0]; - EXPECT_THAT(pool0->name, StrEq("Pool1")); - EXPECT_THAT(std::vector(pool0->shape, pool0->shape + pool0->num_shape), + auto workspace_pool0 = &md_data->workspace_pools[0]; + EXPECT_THAT(workspace_pool0->name, StrEq("Workspace_Pool1")); + EXPECT_THAT(std::vector(workspace_pool0->shape, + workspace_pool0->shape + workspace_pool0->num_shape), ElementsAre(5, 10, 10)); - EXPECT_THAT(tvm::runtime::DataType(pool0->dtype), + EXPECT_THAT(tvm::runtime::DataType(workspace_pool0->dtype), Eq(tvm::runtime::DataType(DLDataType({3, 4, 7})))); + auto constant_pool0 = &md_data->constant_pools[0]; + EXPECT_THAT(constant_pool0->name_hint, StrEq("Constant_Pool1")); + EXPECT_THAT(md_data->mod_name, StrEq("default")); } @@ -270,7 +301,7 @@ TEST(Metadata, ZeroElementLists) { {TensorInfo(make_object( tvm::String("Output1"), std::vector{}, tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({}), "default")); + std::vector({}), std::vector({}), "default")); EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -282,9 +313,9 @@ TEST(Metadata, ZeroElementLists) { EXPECT_THAT(md->outputs()[0]->shape().size(), Eq(0)); EXPECT_THAT(md->outputs()[0]->shape(), ElementsAre()); - EXPECT_THAT(md->pools().size(), Eq(0)); - EXPECT_THAT(md->num_pools(), Eq(0)); - EXPECT_THAT(md->pools(), ElementsAre()); + EXPECT_THAT(md->workspace_pools().size(), Eq(0)); + EXPECT_THAT(md->num_workspace_pools(), Eq(0)); + EXPECT_THAT(md->workspace_pools(), ElementsAre()); } TEST(MetadataArray, GetElementCStructName) { @@ -323,8 +354,9 @@ TEST(DiscoverArraysVisitor, DiscoverArrays) { DiscoveredNameEq("kTvmgenMetadata_inputs"), DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), DiscoveredNameEq("kTvmgenMetadata_outputs"), - DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), - DiscoveredNameEq("kTvmgenMetadata_pools")})); + DiscoveredNameEq("kTvmgenMetadata_workspace_pools_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_workspace_pools"), + DiscoveredNameEq("kTvmgenMetadata_constant_pools")})); } // In Debug builds the _type_key is no longer inlined but also has no @@ -381,5 +413,13 @@ TEST(DiscoverComplexTypesVisitor, DiscoverComplexTypes) { Metadata md = Metadata(&kNormal); visitor.Discover(md); 
- EXPECT_THAT(q, ElementsAre(TVMObjectIsInstance(), TVMObjectIsInstance())); + EXPECT_THAT( + q, ElementsAre(TVMObjectIsInstance(), TVMObjectIsInstance(), + TVMObjectIsInstance())); +} + +TEST(Metadata, TVMConstantInfo) { + std::vector q; + std::unique_ptr ci{new struct TVMConstantInfo[10]}; + EXPECT_TRUE(ci.get() != nullptr); } diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 4814a1c7e7db..103989fc779d 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -128,8 +128,8 @@ TEST(Relay, BuildModule) { Array targets = {llvm_tgt}; auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, Executor::Create("graph"), Runtime::Create("cpp"), - WorkspaceMemoryPools(), ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), + WorkspaceMemoryPools(), ConstantMemoryPools(), ""); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc index 33f44f4f3e54..be81ded5d78b 100644 --- a/tests/cpp/runtime_test.cc +++ b/tests/cpp/runtime_test.cc @@ -114,8 +114,8 @@ TEST(Runtime, ZeroCopy) { Array targets = {llvm_tgt}; auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, Executor::Create("graph"), Runtime::Create("cpp"), - WorkspaceMemoryPools(), ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), + WorkspaceMemoryPools(), ConstantMemoryPools(), ""); // create graph executor std::string json = json_f(); tvm::runtime::Module mod = mod_f(); diff --git a/tests/cpp/target/source/interface_c_test.cc b/tests/cpp/target/source/interface_c_test.cc index d578c79255e6..4fb9df3d0557 100644 --- a/tests/cpp/target/source/interface_c_test.cc +++ b/tests/cpp/target/source/interface_c_test.cc @@ -116,7 +116,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) { << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" << ");\n"; - PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, @@ -143,7 +143,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePoolsAndDevices) { << " struct tvmgen_ultimate_cat_spotter_devices* devices\n" << ");\n"; - PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, @@ -183,7 +183,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) { << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" << ");\n"; - PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); tir::usmp::PoolAllocation pool_allocation_input{pool_info, 1000}; @@ -384,7 +384,7 @@ TEST(InterfaceAPI, ContainsWorkspaceSize) { } TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) { - PoolInfo pool_info = 
PoolInfo("my_memory_pool", {}); + PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); @@ -413,10 +413,10 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) { } TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) { - PoolInfo pool_info1 = PoolInfo("my_memory_pool_1", {}); + PoolInfo pool_info1 = WorkspacePoolInfo("my_memory_pool_1", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info1 = tir::usmp::AllocatedPoolInfo(pool_info1, 100000); - PoolInfo pool_info2 = PoolInfo("my_memory_pool_2", {}); + PoolInfo pool_info2 = WorkspacePoolInfo("my_memory_pool_2", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info2 = tir::usmp::AllocatedPoolInfo(pool_info2, 200000); @@ -454,7 +454,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) { } TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) { - PoolInfo pool_info = PoolInfo("my_memory_pool+1", {}); + PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool+1", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info = tir::usmp::AllocatedPoolInfo(pool_info, 100000); @@ -483,10 +483,10 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) { } TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) { - PoolInfo pool_info1 = PoolInfo("my_memory_pool+", {}); + PoolInfo pool_info1 = WorkspacePoolInfo("my_memory_pool+", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info1 = tir::usmp::AllocatedPoolInfo(pool_info1, 100000); - PoolInfo pool_info2 = PoolInfo("my_memory_pool-", {}); + PoolInfo pool_info2 = WorkspacePoolInfo("my_memory_pool-", {}); tir::usmp::AllocatedPoolInfo allocated_pool_info2 = tir::usmp::AllocatedPoolInfo(pool_info2, 200000); diff --git a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py index 8a0d51d2ae0c..0bfc64fe041d 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py +++ b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py @@ -27,7 +27,7 @@ from tvm.micro import model_library_format as mlf from tvm.relay.op.contrib.ethosu import partition_for_ethosu import tvm -from tvm import WorkspaceMemoryPools, PoolInfo +from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, PoolInfoProperties from .. 
import infra @@ -63,13 +63,15 @@ def _get_ethosu_workspace_size( workspace_memory_pools = WorkspaceMemoryPools( [ - PoolInfo( + WorkspacePoolInfo( "SRAM", - {target: PoolInfo.READ_WRITE_ACCESS, ethosu_target: PoolInfo.READ_WRITE_ACCESS}, - size_hint_bytes=pool_size, - read_bandwidth_bytes_per_cycle=16, - write_bandwidth_bytes_per_cycle=16, - target_burst_bytes={ethosu_target: 1}, + [target, ethosu_target], + PoolInfoProperties( + size_hint_bytes=pool_size, + read_bandwidth_bytes_per_cycle=16, + write_bandwidth_bytes_per_cycle=16, + target_burst_bytes={ethosu_target: 1}, + ), ), ] ) diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 1f999781e3b1..315c2367c82a 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -45,7 +45,7 @@ from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.backend.contrib.ethosu import preprocess import tvm.relay.testing.tf as tf_testing -from tvm import WorkspaceMemoryPools, PoolInfo +from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, PoolInfoProperties from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.testing.aot import ( @@ -334,16 +334,15 @@ def compare_ethosu_with_reference( ethosu_target = tvm.target.Target("ethos-u") workspace_pools = WorkspaceMemoryPools( [ - PoolInfo( + WorkspacePoolInfo( pool_name, - { - host_target: PoolInfo.READ_WRITE_ACCESS, - ethosu_target: PoolInfo.READ_WRITE_ACCESS, - }, - size_hint_bytes=2400000, - read_bandwidth_bytes_per_cycle=16, - write_bandwidth_bytes_per_cycle=16, - target_burst_bytes={ethosu_target: 1}, + [host_target, ethosu_target], + PoolInfoProperties( + size_hint_bytes=2400000, + read_bandwidth_bytes_per_cycle=16, + write_bandwidth_bytes_per_cycle=16, + target_burst_bytes={ethosu_target: 1}, + ), ) ] ) diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index ca7a213be58b..c4081f911a5f 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -23,7 +23,7 @@ from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.micro import model_library_format as mlf -from tvm import WorkspaceMemoryPools, PoolInfo +from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, PoolInfoProperties import tvm from tvm.testing.aot import convert_to_relay @@ -105,16 +105,15 @@ def test_networks_with_usmp_and_cascader_wo_striping(accel_type, model_url, work ethosu_target = tvm.target.Target("ethos-u") workspace_pools = WorkspaceMemoryPools( [ - PoolInfo( + WorkspacePoolInfo( pool_name, - { - host_target: PoolInfo.READ_WRITE_ACCESS, - ethosu_target: PoolInfo.READ_WRITE_ACCESS, - }, - size_hint_bytes=2400000, - read_bandwidth_bytes_per_cycle=16, - write_bandwidth_bytes_per_cycle=16, - target_burst_bytes={ethosu_target: 1}, + [host_target, ethosu_target], + PoolInfoProperties( + size_hint_bytes=2400000, + read_bandwidth_bytes_per_cycle=16, + write_bandwidth_bytes_per_cycle=16, + target_burst_bytes={ethosu_target: 1}, + ), ) ] ) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 742b681ae619..3f641c995652 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -87,31 +87,33 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} type_dict = {p.name_hint: 
p.checked_type.dtype for p in main_func.params} - weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + weight_data = np.random.randint(1, 255, shape_dict["weight"]).astype(type_dict["weight"]) input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) - params = {"weight": weight_data} inputs = {"data": input_data} ref_outputs = generate_ref_data(ir_mod, inputs, params) with tvm.transform.PassContext( - opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp} + opt_level=3, + config={ + "tir.disable_vectorize": True, + "tir.usmp.enable": enable_usmp, + }, ): mod = tvm.relay.build( ir_mod, params=params, target=target_kind, - executor=backend.Executor("aot", {"interface-api": "packed"}), + executor=backend.Executor("aot", {"interface-api": "packed", "unpacked-api": False}), ) - temp_dir = tvm.contrib.utils.TempDirectory() test_so_path = temp_dir / "test.so" - mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + mod.export_library(test_so_path, cc="gcc", options=["-std=c11", "-g3", "-O0"]) loaded_mod = tvm.runtime.load_module(test_so_path) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) runner.run() - assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() + assert (runner.get_output(0).numpy() == list(ref_outputs.values())[0]).all() @pytest.mark.parametrize("enable_usmp", [True, False]) @@ -136,7 +138,7 @@ def test_mobilenet(enable_usmp, target_kind): temp_dir = tvm.contrib.utils.TempDirectory() test_so_path = temp_dir / "test.so" - mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + mod.export_library(test_so_path, cc="c++", options=["-std=gnu++14", "-g3", "-O0"]) loaded_mod = tvm.runtime.load_module(test_so_path) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) @@ -188,7 +190,7 @@ def test_pass_wrong_device_arg(): temp_dir = tvm.contrib.utils.TempDirectory() test_so_path = temp_dir / "test.so" - mod.export_library(test_so_path, cc="gcc", options=["-std=c11"]) + mod.export_library(test_so_path, cc="gcc", options=["-std=c11", "-g3", "-O0"]) loaded_mod = tvm.runtime.load_module(test_so_path) with pytest.raises(tvm.TVMError) as error: diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 4205b458177c..a2f9ee5eb0f7 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -27,7 +27,7 @@ from tvm.relay import transform from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.backend import Executor, Runtime -from tvm import WorkspaceMemoryPools, PoolInfo +from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, PoolInfoProperties from tvm.micro import model_library_format as mlf from tvm.micro.testing.aot_test_utils import parametrize_aot_options from tvm.testing.aot import ( @@ -48,15 +48,68 @@ def _check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module): ), "This is failing because USMP was unable to plan for every tir.allocate node." 
+# U1 test case +@parametrize_aot_options +def test_synthetic(interface_api, use_unpacked_api, test_runner): + """ + Simple U1 usecase test + """ + mod, params = tvm.relay.testing.synthetic.get_workload() + main_func = mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + params = {} + for name, _ in shape_dict.items(): + if name != "data": + params[name] = np.ones(shape_dict[name]).astype(type_dict[name]) + + inputs = {"data": input_data} + output_list = generate_ref_data(mod, inputs, params) + config = ( + { + "tir.disable_vectorize": True, + "tir.disable_storage_rewrite": True, + "tir.usmp.enable": True, + "tir.usmp.algorithm": "greedy_by_conflicts", + }, + ) + + test_runner = AOTTestRunner( + makefile=test_runner.makefile, + prologue=test_runner.prologue, + epilogue=test_runner.epilogue, + includes=test_runner.includes, + parameters=test_runner.parameters, + pass_config={**test_runner.pass_config}, + ) + test_runner.pass_config.update(*config) + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, + ) + + @pytest.mark.parametrize( - "workspace_byte_alignment,main_workspace_size", + "workspace_byte_alignment,constant_byte_alignment,main_workspace_size,main_constant_size", [ - (8, 17280), - (16, 17280), - (256, 17792), + (8, 8, 17280, 948), + (16, 8, 17280, 948), + (256, 8, 17792, 948), + (8, 16, 17280, 956), + (16, 16, 17280, 956), + (256, 16, 17792, 956), + (8, 256, 17280, 1804), + (16, 256, 17280, 1804), + (256, 256, 17792, 1804), ], ) -def test_memory_planning(workspace_byte_alignment, main_workspace_size): +def test_memory_planning( + workspace_byte_alignment, constant_byte_alignment, main_workspace_size, main_constant_size +): """Checks calculated workspace against known values""" mod, params = tvm.relay.testing.synthetic.get_workload() target = "c" @@ -65,6 +118,7 @@ def test_memory_planning(workspace_byte_alignment, main_workspace_size): "aot", { "workspace-byte-alignment": workspace_byte_alignment, + "constant-byte-alignment": constant_byte_alignment, }, ) with tvm.transform.PassContext( @@ -79,8 +133,10 @@ def test_memory_planning(workspace_byte_alignment, main_workspace_size): lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) # The workspace_size dictionary will have an entry for both the 'primitive' and 'host' # targets, though both are identical. 
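For reference, a minimal sketch (not part of the patch) of the two byte-alignment options that test_memory_planning above exercises; the option names are taken from the test changes and the resulting executor/runtime pair is then passed to tvm.relay.build exactly as in that test, so everything else here is illustrative.

# Illustrative only: configure both alignments on the AOT executor as the
# parametrized test above does.
from tvm.relay.backend import Executor, Runtime

executor = Executor(
    "aot",
    {
        "workspace-byte-alignment": 16,
        "constant-byte-alignment": 16,
    },
)
runtime = Runtime("crt")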
- for size in lib.function_metadata["__tvm_main__"].workspace_sizes.values(): - assert size == main_workspace_size + assert ( + sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size + ) + assert sum(lib.function_metadata["__tvm_main__"].constant_sizes.values()) == main_constant_size @parametrize_aot_options @@ -212,14 +268,14 @@ def test_byoc_microtvm(merge_compiler_regions): @pytest.mark.parametrize( - "model_url, usmp_algo, workspace_size,", + "model_url, usmp_algo, workspace_size, constant_size", [ - (MOBILENET_V1_URL, "greedy_by_size", 4845696), - (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288), - (MOBILENET_V1_URL, "hill_climb", 3240064), + (MOBILENET_V1_URL, "greedy_by_size", 4845696, 8468008), + (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288, 8468008), + (MOBILENET_V1_URL, "hill_climb", 3240064, 8468008), ], ) -def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size): +def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size, constant_size): """ This checks for ML models and the memory used by them when using USMP with different algorithms @@ -256,11 +312,19 @@ def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size): compiled_test_mods[0].executor_factory.function_metadata ) assert mlf_memory_map["main"][0]["workspace_size_bytes"] == workspace_size + assert mlf_memory_map["main"][0]["constants_size_bytes"] == constant_size # That should match to workspace size that will be codegen'd to the entry point. - allocated_pool_info = list( - dict(compiled_test_mods[0].executor_factory.executor_codegen_metadata.pool_inputs).values() - )[0] - assert allocated_pool_info.allocated_size == workspace_size + allocated_pool_info_size = sum( + [ + _.allocated_size + for _ in list( + dict( + compiled_test_mods[0].executor_factory.executor_codegen_metadata.pool_inputs + ).values() + ) + ] + ) + assert allocated_pool_info_size == workspace_size + constant_size run_and_check( models=compiled_test_mods, @@ -300,9 +364,7 @@ def test_tflite_model_u3_usecase_single_external_pool(model_url, usmp_algo): pool_name = "my_memory_pool" target = tvm.target.Target("c") - workspace_memory_pools = WorkspaceMemoryPools( - [PoolInfo(pool_name, {target: PoolInfo.READ_WRITE_ACCESS})] - ) + workspace_memory_pools = WorkspaceMemoryPools([WorkspacePoolInfo(pool_name, [target])]) test_runner = AOTTestRunner( pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo}, prologue=f""" @@ -355,10 +417,10 @@ def test_tflite_model_u3_usecase_two_external_pools(model_url, usmp_algo): target = tvm.target.Target("c") workspace_memory_pools = WorkspaceMemoryPools( [ - PoolInfo( - "my_memory_pool_1", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000 + WorkspacePoolInfo( + "my_memory_pool_1", [target], PoolInfoProperties(size_hint_bytes=2500000) ), - PoolInfo("my_memory_pool_2", {target: PoolInfo.READ_WRITE_ACCESS}), + WorkspacePoolInfo("my_memory_pool_2", [target]), ] ) test_runner = AOTTestRunner( @@ -413,9 +475,7 @@ def test_two_models_with_a_single_external_pool(model_urls, usmp_algo): interface_api = "c" target = tvm.target.Target("c") - workspace_memory_pools = WorkspaceMemoryPools( - [PoolInfo("my_memory_pool", {target: PoolInfo.READ_WRITE_ACCESS})] - ) + workspace_memory_pools = WorkspaceMemoryPools([WorkspacePoolInfo("my_memory_pool", [target])]) test_runner = AOTTestRunner( pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo}, prologue=f""" @@ -482,9 +542,7 @@ def 
test_tflite_model_u4_usecase_single_external_pool(model_url, usmp_algo): pool_name = "my_memory_pool" target = tvm.target.Target("c") - workspace_memory_pools = WorkspaceMemoryPools( - [PoolInfo(pool_name, {target: PoolInfo.READ_WRITE_ACCESS})] - ) + workspace_memory_pools = WorkspaceMemoryPools([WorkspacePoolInfo(pool_name, [target])]) tflite_model_file = tf_testing.get_workload_official( model_url[0], @@ -552,10 +610,10 @@ def test_tflite_model_u4_usecase_two_external_pools(model_url, usmp_algo): target = tvm.target.Target("c") workspace_memory_pools = WorkspaceMemoryPools( [ - PoolInfo( - "my_memory_pool_1", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000 + WorkspacePoolInfo( + "my_memory_pool_1", [target], PoolInfoProperties(size_hint_bytes=2500000) ), - PoolInfo("my_memory_pool_2", {target: PoolInfo.READ_WRITE_ACCESS}), + WorkspacePoolInfo("my_memory_pool_2", [target]), ] ) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 8449782f4589..a5408ef069e1 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -57,6 +57,7 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handle, placeholder_164: T.handle, T_cast_76: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True}) + sid_21 = T.allocate_const([0,1,2,3,4,5,6,7], "int8", [8]) placeholder_165 = T.match_buffer(placeholder_162, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_166 = T.match_buffer(placeholder_163, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_167 = T.match_buffer(placeholder_164, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1) @@ -90,19 +91,17 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl # fmt: on -@pytest.mark.parametrize("alignment_and_size", [(1, 663552), (10, 663560)]) -def test_global_allocates(alignment_and_size): - alignment = alignment_and_size[0] - size = alignment_and_size[1] +@pytest.mark.parametrize("alignment,size,consts", [(1, 663552, 0), (10, 663560, 0)]) +def test_global_allocates(alignment, size, consts): primfunc = primfunc_global_allocates + assert tvm.tir.analysis.calculate_constant_bytes(primfunc, alignment) == consts assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, alignment) == size -@pytest.mark.parametrize("alignment_and_size", [(1, 1566720), (100, 1567100)]) -def test_local_allocates(alignment_and_size): - alignment = alignment_and_size[0] - size = alignment_and_size[1] +@pytest.mark.parametrize("alignment,size,consts", [(1, 1566720, 8), (100, 1567100, 100)]) +def test_local_allocates(alignment, size, consts): primfunc = primfunc_local_allocates + assert tvm.tir.analysis.calculate_constant_bytes(primfunc, alignment) == consts assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, alignment) == size diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 548fd96676a0..9d30a0d19589 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -22,6 +22,7 @@ from tvm.tir import stmt_functor from tvm.tir.usmp import utils as usmp_utils from tvm.target import Target 
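As a usage note (illustrative, not part of the patch): the new constant-size analysis mirrors the existing workspace analysis, so both can be queried on the same PrimFunc with the same alignment, which is what the updated test_tir_analysis_calculate_workspace.py cases above assert.

# Illustrative only: `primfunc` stands for any tir.PrimFunc, e.g. one of the
# fixtures in the test file above; 16 is the byte alignment.
import tvm

workspace_bytes = tvm.tir.analysis.calculate_workspace_bytes(primfunc, 16)
constant_bytes = tvm.tir.analysis.calculate_constant_bytes(primfunc, 16)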
+from tvm import WorkspacePoolInfo, PoolInfoProperties def _replace_stmt_with_buf_var_names(buffer_info_map): @@ -98,10 +99,10 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): def test_no_pool_error(): target = Target("c") - tiny_workspace_pool = usmp_utils.PoolInfo( - pool_name="tiny_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, - size_hint_bytes=10, + tiny_workspace_pool = WorkspacePoolInfo( + "tiny_workspace", + [target], + PoolInfoProperties(size_hint_bytes=10), ) bi_a = usmp_utils.BufferInfo( name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool] @@ -129,9 +130,9 @@ def test_name_based_ordering(algorithm): def _test(): target = Target("c") - global_workspace_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + [target], ) bi_a = usmp_utils.BufferInfo( name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] @@ -183,9 +184,9 @@ def test_linear(algorithm, workspace_size): bi_f """ target = Target("c") - global_workspace_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + [target], ) bi_a = usmp_utils.BufferInfo( name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] @@ -250,9 +251,9 @@ def test_fanout(algorithm, workspace_size): bi_g """ target = Target("c") - global_workspace_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + targets=[target], ) bi_a = usmp_utils.BufferInfo( name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] @@ -372,13 +373,14 @@ def run_model(input: T.handle, output: T.handle) -> None: ) def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size): target = Target("c") - fast_memory_pool = usmp_utils.PoolInfo( - pool_name="fast_memory", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, - size_hint_bytes=200704, + fast_memory_pool = WorkspacePoolInfo( + "fast_memory", + [target], + PoolInfoProperties(size_hint_bytes=200704), ) - slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + slow_memory_pool = WorkspacePoolInfo( + "slow_memory", + [target], ) tir_mod = MobilenetStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) @@ -538,9 +540,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place ) def test_resnet_subgraph(algorithm, workspace_size): target = Target("c") - global_workspace_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + [target], ) tir_mod = ResnetStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) diff --git a/tests/python/unittest/test_tir_usmp_algo_hill_climb.py b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py index 44b4e6636b6c..b486581064f9 100644 --- a/tests/python/unittest/test_tir_usmp_algo_hill_climb.py +++ b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py @@ -19,7 +19,8 @@ import random import tvm import tvm.testing -from tvm.tir.usmp.utils import BufferInfo, PoolInfo 
+from tvm.tir.usmp.utils import BufferInfo +from tvm import WorkspacePoolInfo, PoolInfoProperties def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): @@ -63,7 +64,13 @@ def _verify_all_conflicts(buffer_pool_allocations): _verify_conflicts(buffer_info, pool_allocation, buffer_pool_allocations) -def test_bounded(random_len=150, pools=[PoolInfo("default", {}, 65535), PoolInfo("slow", {})]): +def test_bounded( + random_len=150, + pools=[ + WorkspacePoolInfo("default", [], PoolInfoProperties(65535)), + WorkspacePoolInfo("slow", []), + ], +): """Tests two pools, one is bounded and one is not limited""" random.seed(0) mem_range = [BufferInfo(str(i), random.randrange(1, 65535), pools) for i in range(random_len)] @@ -351,7 +358,7 @@ def test_random_intervals(interval_len=16): def run_intervals(intervals): """Helper to run intervals""" expected_mem = find_maximum_from_intervals(intervals) - pools = [PoolInfo("default", {})] + pools = [WorkspacePoolInfo("default", [])] buffers = [] # populate for i, (start, stop, size) in enumerate(intervals): diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 22b3d5826b3b..301dc16d2127 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -25,6 +25,7 @@ from tvm.tir import PrimFunc from tvm.tir.usmp import utils as usmp_utils from tvm.target import Target +from tvm import WorkspacePoolInfo, ConstantPoolInfo def _replace_stmt_with_buf_var_names(buffer_info_map): @@ -54,7 +55,7 @@ def get_allocate(stmt): return allocates -def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos, constant_pool_infos): """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" def set_poolinfos(stmt): @@ -67,16 +68,27 @@ def set_poolinfos(stmt): body=stmt.body, annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, ) + elif isinstance(stmt, tvm.tir.AllocateConst): + return tvm.tir.AllocateConst( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + data_or_idx=stmt.data, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: constant_pool_infos}, + ) return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) -def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos, constant_pool_infos=None): """helper to assign poolinfos to allocate nodes in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc): - ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc( + basefunc, pool_infos, constant_pool_infos + ) return ret @@ -165,12 +177,8 @@ def run_model(input: T.handle, output: T.handle) -> None: def test_linear(): target = Target("c") - fast_memory_pool = usmp_utils.PoolInfo( - pool_name="fast_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} - ) - slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} - ) + fast_memory_pool = WorkspacePoolInfo(pool_name="fast_memory", targets=[target]) + slow_memory_pool = WorkspacePoolInfo(pool_name="slow_memory", 
targets=[target]) tir_mod = LinearStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = _assign_poolinfos_to_allocates_in_irmodule( @@ -284,9 +292,9 @@ def run_model(input: T.handle, output: T.handle) -> None: def test_parallel_serial_mixed_for_loops(): target = Target("c") - global_ws_pool = usmp_utils.PoolInfo( + global_ws_pool = WorkspacePoolInfo( pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + targets=[target], ) all_serial_tir_mod = AllSerialForLoops all_serial_tir_mod = _assign_targets_to_primfuncs_irmodule(all_serial_tir_mod, target) @@ -651,9 +659,9 @@ def run_model(input: T.handle, output: T.handle) -> None: def test_inception_structure(): target = Target("c") - global_ws_pool = usmp_utils.PoolInfo( + global_ws_pool = WorkspacePoolInfo( pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + targets=[target], ) tir_mod = InceptionStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) @@ -1346,7 +1354,12 @@ def run_model(data: T.handle, output: T.handle) -> None: sid_18 = T.allocate([3456], "int8", "global.workspace") sid_19 = T.allocate([3456], "int8", "global.workspace") sid_20 = T.allocate([3456], "int8", "global.workspace") - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8.data, dtype="int32")) + + sid_21 = T.allocate_const([0,1,2,3,4,5,6,7,8,9], "int8", [10]) + sid_22 = T.allocate_const([1], "int8", [1]) + sid_23 = T.allocate_const([2,1], "int8", [3456]) + + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_23.data, dtype="int32")) T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8.data, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7.data, dtype="int32")) T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7.data, sid_6.data, dtype="int32")) T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12.data, dtype="int32")) @@ -1365,19 +1378,27 @@ def run_model(data: T.handle, output: T.handle) -> None: def test_multiple_calls_to_same_primfunc(): target = Target("c") - global_ws_pool = usmp_utils.PoolInfo( + global_ws_pool = WorkspacePoolInfo( pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + targets=[target], ) + global_const_pool = ConstantPoolInfo( + pool_name="global_constants", + targets=[target], + ) + tir_mod = MultipleCallsToSamePrimFuncModule tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [global_ws_pool], [global_const_pool] + ) main_func = tir_mod["run_model"] buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) assert buffer_info_analysis.memory_pressure == 11424 buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) # check conflicts + _verify_conflicts("sid_23", ["sid_22", "sid_21"], buffer_info_map) _verify_conflicts( "sid_6", [ diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index ce8675f575ee..0a3e39b52f46 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ 
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -22,6 +22,7 @@ from tvm.tir import stmt_functor from tvm.tir.usmp import utils as usmp_utils from tvm.target import Target +from tvm import WorkspacePoolInfo, PoolInfoProperties def _get_primfuncs_from_module(module): @@ -231,13 +232,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def test_mobilenet_subgraph(): target = Target("c") - fast_memory_pool = usmp_utils.PoolInfo( - pool_name="fast_memory", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, - size_hint_bytes=200704, + fast_memory_pool = WorkspacePoolInfo( + "fast_memory", + [target], + PoolInfoProperties(size_hint_bytes=200704), ) - slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + slow_memory_pool = WorkspacePoolInfo( + "slow_memory", + [target], ) tir_mod = LinearStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) @@ -557,9 +559,9 @@ def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output def test_resnet_subgraph(): target = Target("c") - global_workspace_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + [target], ) tir_mod = ResnetStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index e6add3a5cfd3..2034b072838d 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -22,7 +22,7 @@ from tvm.tir import stmt_functor from tvm.tir.usmp import utils as usmp_utils from tvm.target import Target - +from tvm import WorkspacePoolInfo, PoolInfoProperties # fmt: off @tvm.script.ir_module @@ -97,29 +97,27 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_create_pool_info(): target = Target("c") - pool_info = usmp_utils.PoolInfo( - pool_name="foo_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + pool_info = WorkspacePoolInfo( + "foo_workspace", + [target], ) assert pool_info.pool_name == "foo_workspace" - assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} # default pool size constraint assert pool_info.size_hint_bytes == -1 - pool_info = usmp_utils.PoolInfo( - pool_name="bar_workspace", - target_access={target: usmp_utils.PoolInfo.READ_ONLY_ACCESS}, - size_hint_bytes=1425, + pool_info = WorkspacePoolInfo( + "bar_workspace", + [target], + PoolInfoProperties(size_hint_bytes=1425), ) assert pool_info.pool_name == "bar_workspace" - assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_ONLY_ACCESS} assert pool_info.size_hint_bytes == 1425 def test_create_buffer_info(): - global_ws_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_ws_pool = WorkspacePoolInfo( + "global_workspace", + [Target("c")], ) buffer_info_obj = tvm.tir.usmp.BufferInfo( name_hint="buf1", size_bytes=256, pool_candidates=[global_ws_pool] @@ -138,9 +136,9 @@ def test_create_buffer_info(): def test_create_pool_allocation(): - pool_info = usmp_utils.PoolInfo( - pool_name="foo_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + pool_info = 
WorkspacePoolInfo( + "foo_workspace", + [Target("c")], ) pool_allocation = usmp_utils.PoolAllocation(pool_info=pool_info, byte_offset=64) assert pool_allocation.pool_info == pool_info @@ -184,9 +182,9 @@ def _assign_targets_to_primfuncs_irmodule(mod, target): def test_create_array_buffer_info(): target = Target("c") - global_ws_pool = usmp_utils.PoolInfo( - pool_name="global_workspace", - target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + global_ws_pool = WorkspacePoolInfo( + "global_workspace", + [target], ) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") tir_mod = LinearStructure
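The recurring migration across these tests, where pool targets are now passed as a list and size/bandwidth hints move into PoolInfoProperties, with ConstantPoolInfo covering read-only data, can be summarised with the following illustrative sketch; the pool names "sram" and "flash" are placeholders.

# Illustrative summary of the pool-construction pattern this patch moves to.
from tvm import WorkspacePoolInfo, ConstantPoolInfo, PoolInfoProperties
from tvm.target import Target

target = Target("c")

# Previously: PoolInfo("sram", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000)
sram = WorkspacePoolInfo(
    "sram",
    [target],
    PoolInfoProperties(size_hint_bytes=2500000),
)

# Read-only data now gets its own pool kind.
flash = ConstantPoolInfo(
    pool_name="flash",
    targets=[target],
)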