Skip to content
191 changes: 180 additions & 11 deletions include/tvm/ir/memory_pools.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>

struct TVMConstantInfo;
namespace tvm {

/*!
* \brief Describes a pool of memory accessible by one or more targets.
*/
struct PoolInfoNode : public Object {
public:
/*! \brief The name of the memory pool */
String pool_name;
/*! \brief The expected size hint to be used by the allocator.
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
* to indicate the pool is not size restricted.
*/
Integer size_hint_bytes;
/*! \brief The accessibility from each Target */
Map<Target, String> target_access; // 'rw' or 'ro'
/*! \brief The clock frequency of the memory in Hz */
Integer clock_frequency_hz;
/*! \brief The read bandwidth in bytes/cycle */
Expand All @@ -60,10 +60,12 @@ struct PoolInfoNode : public Object {
*/
bool is_internal = false;

/*! \brief The targets linked to the pool */
Array<Target> targets;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pool_name", &pool_name);
v->Visit("size_hint_bytes", &size_hint_bytes);
v->Visit("target_access", &target_access);
v->Visit("clock_frequency_hz", &clock_frequency_hz);
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
Expand All @@ -75,8 +77,6 @@ struct PoolInfoNode : public Object {

bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
equal(target_access, other->target_access) &&
equal(target_access, other->target_access) &&
equal(clock_frequency_hz, other->clock_frequency_hz) &&
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
Expand All @@ -89,7 +89,6 @@ struct PoolInfoNode : public Object {
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(pool_name);
hash_reduce(size_hint_bytes);
hash_reduce(target_access);
hash_reduce(clock_frequency_hz);
hash_reduce(read_bandwidth_bytes_per_cycle);
hash_reduce(write_bandwidth_bytes_per_cycle);
Expand All @@ -100,7 +99,7 @@ struct PoolInfoNode : public Object {
}

static constexpr const char* _type_key = "ir.PoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(PoolInfoNode, Object);
};

/*!
Expand Down Expand Up @@ -129,18 +128,166 @@ static const int kUnknownReadBandwidth = -1;
/*! \brief The write bandwidth is not known */
static const int kUnknownWriteBandwidth = -1;

/*! \brief Base class for WorkspacePoolInfo and ConstantPoolInfo */
class PoolInfo : public ObjectRef {
public:
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
protected:
TVM_DLL PoolInfo(String pool_name, Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
Integer clock_frequency_hz = kUnknownClockFrequency,
Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth,
Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth,
Integer read_latency_cycles = 0, Integer write_latency_cycles = 0,
Map<Target, Integer> target_burst_bytes = {}, Bool is_internal = Bool(false));
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);

public:
TVM_DEFINE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
};

/*!
* \brief Describes a pool of memory properties
*/
struct PoolInfoPropertiesNode : public Object {
/*! \brief The expected size hint to be used by the allocator.
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
* to indicate the pool is not size restricted.
*/
Integer size_hint_bytes = kUnrestrictedPoolSizeHint;
/*! \brief The clock frequency of the memory in Hz */
Integer clock_frequency_hz = kUnknownClockFrequency;
/*! \brief The read bandwidth in bytes/cycle */
Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth;
/*! \brief The write bandwidth in bytes/cycle */
Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth;
/*! \brief The read latency in cycles */
Integer read_latency_cycles = 0;
/*! \brief The write latency in cycles */
Integer write_latency_cycles = 0;
/*! \brief The burst length in bytes for each Target */
Map<Target, Integer> target_burst_bytes{};
/*! \brief Whether pool is internally generated.
* The internal pools will be generated as part of
* the entry point code generation of the executor
*/
bool is_internal = false;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("size_hint_bytes", &size_hint_bytes);
v->Visit("clock_frequency_hz", &clock_frequency_hz);
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
v->Visit("read_latency_cycles", &read_latency_cycles);
v->Visit("write_latency_cycles", &write_latency_cycles);
v->Visit("target_burst_bytes", &target_burst_bytes);
v->Visit("is_internal", &is_internal);
}

bool SEqualReduce(const PoolInfoPropertiesNode* other, SEqualReducer equal) const {
return equal(size_hint_bytes, other->size_hint_bytes) &&
equal(clock_frequency_hz, other->clock_frequency_hz) &&
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
equal(read_latency_cycles, other->read_latency_cycles) &&
equal(write_latency_cycles, other->write_latency_cycles) &&
equal(target_burst_bytes, other->target_burst_bytes) &&
equal(is_internal, other->is_internal);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(size_hint_bytes);
hash_reduce(clock_frequency_hz);
hash_reduce(read_bandwidth_bytes_per_cycle);
hash_reduce(write_bandwidth_bytes_per_cycle);
hash_reduce(read_latency_cycles);
hash_reduce(write_latency_cycles);
hash_reduce(target_burst_bytes);
hash_reduce(is_internal);
}

static constexpr const char* _type_key = "ir.PoolInfoProperties";
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoPropertiesNode, Object);
};

/*!
 * \brief Reference class for PoolInfoPropertiesNode.
 */
class PoolInfoProperties : public ObjectRef {
public:
/*!
 * \brief Construct a set of pool properties.
 *
 * Only size_hint_bytes is mandatory; every other parameter has a default.
 * The parameter names mirror the fields of PoolInfoPropertiesNode; the
 * constructor body is not visible here, so presumably each argument is
 * stored in the field of the same name (confirm in the implementation).
 *
 * \param size_hint_bytes The size hint for the allocator.
 * \param clock_frequency_hz Defaults to kUnknownClockFrequency.
 * \param read_bandwidth_bytes_per_cycle Defaults to kUnknownReadBandwidth.
 * \param write_bandwidth_bytes_per_cycle Defaults to kUnknownWriteBandwidth.
 * \param read_latency_cycles Defaults to 0.
 * \param write_latency_cycles Defaults to 0.
 * \param target_burst_bytes Defaults to an empty map.
 * \param is_internal Defaults to Bool(false).
 */
TVM_DLL PoolInfoProperties(Integer size_hint_bytes,
Integer clock_frequency_hz = kUnknownClockFrequency,
Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth,
Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth,
Integer read_latency_cycles = 0, Integer write_latency_cycles = 0,
Map<Target, Integer> target_burst_bytes = {},
Bool is_internal = Bool(false));
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfoProperties, ObjectRef, PoolInfoPropertiesNode);
};

/*!
 * \brief Represents a read-write (RW) memory area.
 *
 * Declares no members of its own beyond the type key; all data members
 * come from PoolInfoNode.
 */
struct WorkspacePoolInfoNode : public PoolInfoNode {
static constexpr const char* _type_key = "ir.WorkspacePoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(WorkspacePoolInfoNode, PoolInfoNode);
};

/*!
 * \brief Reference class for WorkspacePoolInfoNode.
 */
class WorkspacePoolInfo : public PoolInfo {
public:
/*!
 * \brief Construct a workspace pool.
 * \param pool_name The name of the pool.
 * \param targets The targets associated with the pool.
 * \param properties The pool properties. When omitted, a PoolInfoProperties
 *        built from kUnrestrictedPoolSizeHint alone is used.
 */
TVM_DLL WorkspacePoolInfo(
String pool_name, Array<Target> targets,
PoolInfoProperties properties = PoolInfoProperties(kUnrestrictedPoolSizeHint));
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspacePoolInfo, PoolInfo, WorkspacePoolInfoNode);
};

/*!
 * \brief The ConstantInfoNode contains a numeric literal in an RO pool.
 *  Used to initialise RO memory in ConstantPoolInfo.
 */
struct ConstantInfoNode : public Object {
  /*! \brief Name hint of the constant */
  String name_hint;
  /*! \brief Byte offset of the constant (presumably within its pool; confirm against users) */
  Integer byte_offset;
  /*! \brief The constant's data */
  runtime::NDArray data;

  /*!
   * \brief Report the three fields to the attribute visitor under their own names.
   * \param v The visitor receiving the (name, field pointer) pairs.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("name_hint", &name_hint);
    v->Visit("byte_offset", &byte_offset);
    v->Visit("data", &data);
  }

  /*!
   * \brief Field-by-field structural comparison against another node.
   * \param other The node to compare with.
   * \param equal The reducer used to compare each pair of fields.
   * \return true only if all three fields compare equal; stops at the first mismatch.
   */
  bool SEqualReduce(const ConstantInfoNode* other, SEqualReducer equal) const {
    if (!equal(name_hint, other->name_hint)) return false;
    if (!equal(byte_offset, other->byte_offset)) return false;
    return equal(data, other->data);
  }

  /*!
   * \brief Feed the three fields into the structural hash, in declaration order.
   * \param hash_reduce The reducer accumulating the hash.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name_hint);
    hash_reduce(byte_offset);
    hash_reduce(data);
  }

  static constexpr const char* _type_key = "ir.ConstantInfo";
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoNode, Object);
};

/*!
 * \brief Reference class for ConstantInfoNode.
 */
class ConstantInfo : public ObjectRef {
public:
/*!
 * \brief Construct from a C-level TVMConstantInfo record.
 * \param data Pointer to the ::TVMConstantInfo to read from.
 */
TVM_DLL ConstantInfo(const struct ::TVMConstantInfo* data);
/*!
 * \brief Construct from individual values.
 * \param name The name of the constant (presumably stored as name_hint; confirm in the implementation).
 * \param byte_offset The byte offset of the constant.
 * \param data The constant's data.
 *
 * NOTE(review): unlike the constructor above, this one is not marked
 * TVM_DLL; confirm whether that is intentional.
 */
ConstantInfo(String name, Integer byte_offset, runtime::NDArray data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantInfo, ObjectRef, ConstantInfoNode);
};

/*!
 * \brief ConstantPoolInfoNode represents an RO memory area initialized with
 *  data from constant_info_array.
 *
 * NOTE(review): this struct declares no VisitAttrs / SEqualReduce /
 * SHashReduce of its own, so constant_info_array is not covered by any
 * reflection method declared here; confirm this is intended.
 */
struct ConstantPoolInfoNode : public PoolInfoNode {
/*! \brief The constants used to initialise this pool */
Array<ConstantInfo> constant_info_array;
static constexpr const char* _type_key = "ir.ConstantPoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPoolInfoNode, PoolInfoNode);
};

/*!
 * \brief Reference class for ConstantPoolInfoNode.
 */
class ConstantPoolInfo : public PoolInfo {
public:
/*!
 * \brief Construct a constant pool.
 * \param pool_name The name of the pool.
 * \param targets The targets associated with the pool.
 * \param constant_info_array The constants associated with the pool.
 * \param properties The pool properties. When omitted, a PoolInfoProperties
 *        built from kUnrestrictedPoolSizeHint alone is used.
 */
TVM_DLL ConstantPoolInfo(
String pool_name, Array<Target> targets, Array<ConstantInfo> constant_info_array,
PoolInfoProperties properties = PoolInfoProperties(kUnrestrictedPoolSizeHint));
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantPoolInfo, PoolInfo, ConstantPoolInfoNode);
};

/* \brief A container for WorkspacePoolInfo objects */
struct WorkspaceMemoryPoolsNode : public Object {
Array<PoolInfo> pools;

Expand All @@ -162,6 +309,28 @@ class WorkspaceMemoryPools : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspaceMemoryPools, ObjectRef, WorkspaceMemoryPoolsNode);
};

/*! \brief A container for ConstantPoolInfo objects */
struct ConstantMemoryPoolsNode : public Object {
  /*! \brief The constant pools held by this container */
  Array<ConstantPoolInfo> pools;

  /*!
   * \brief Report `pools` to the attribute visitor under the key "pools".
   * \param v The visitor receiving the (name, field pointer) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pools", &pools);
  }

  /*!
   * \brief Structural comparison against another container node.
   * \param other The node to compare with.
   * \param equal The reducer used to compare the pool arrays.
   * \return Whether the two `pools` arrays compare equal.
   */
  bool SEqualReduce(const ConstantMemoryPoolsNode* other, SEqualReducer equal) const {
    const bool same_pools = equal(pools, other->pools);
    return same_pools;
  }

  /*!
   * \brief Feed `pools` into the structural hash.
   * \param hash_reduce The reducer accumulating the hash.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(pools);
  }

  static constexpr const char* _type_key = "ir.ConstantMemoryPools";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantMemoryPoolsNode, Object);
};

/*!
 * \brief Reference class for ConstantMemoryPoolsNode.
 */
class ConstantMemoryPools : public ObjectRef {
public:
/*!
 * \brief Construct a container of constant pools.
 * \param pools The ConstantPoolInfo objects to hold.
 */
TVM_DLL ConstantMemoryPools(Array<ConstantPoolInfo> pools);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantMemoryPools, ObjectRef, ConstantMemoryPoolsNode);
};

} // namespace tvm

#endif // TVM_IR_MEMORY_POOLS_H_
9 changes: 9 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,15 @@ constexpr const char* kRuntime = "runtime";
*/
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*!
 * \brief Constant memory pools of the module.
 *
 * Type: ConstantMemoryPools
 *
 * \sa tvm::ConstantMemoryPools
 */
constexpr const char* kConstantMemoryPools = "constant_memory_pools";

/*
* \brief Module attribute for tir constants
*/
Expand Down
46 changes: 37 additions & 9 deletions include/tvm/runtime/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RUNTIME_METADATA_H_
#define TVM_RUNTIME_METADATA_H_

#include <dmlc/memory_io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/metadata_base.h>
#include <tvm/runtime/metadata_types.h>
Expand All @@ -45,16 +46,10 @@ namespace metadata {
* Should be populated into the `version` field of all TVMMetadata.
*/
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
} // namespace metadata
} // namespace runtime
} // namespace tvm

namespace tvm {
namespace runtime {
namespace metadata {

class Metadata;
class TensorInfo;
class ConstantInfoMetadata;

class MetadataNode : public MetadataBaseNode {
public:
Expand All @@ -66,10 +61,12 @@ class MetadataNode : public MetadataBaseNode {
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
inline int64_t num_outputs() const { return data_->num_outputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
inline int64_t num_pools() const { return data_->num_pools; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> pools();
inline int64_t num_workspace_pools() const { return data_->num_workspace_pools; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> workspace_pools();
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); }
const struct ::TVMMetadata* data() const { return data_; }
ArrayAccessor<struct TVMConstantInfo, ConstantInfoMetadata> constant_pools();
inline int64_t num_constant_pools() const { return data_->num_constant_pools; }
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);

private:
Expand Down Expand Up @@ -107,6 +104,37 @@ class TensorInfo : public MetadataBase {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
};

/*!
 * \brief Metadata node exposing the fields of a C-level ::TVMConstantInfo record.
 *
 * The node keeps the pointer it was constructed with; nothing in this class frees it.
 */
class ConstantInfoMetadataNode : public MetadataBaseNode {
 public:
  /*!
   * \brief Wrap a TVMConstantInfo record.
   * \param data Pointer to the record read by the accessors below.
   */
  explicit ConstantInfoMetadataNode(const struct ::TVMConstantInfo* data) : data_{data} {}

  // This name should match TVMConstantInfo after processing
  static constexpr const char* _type_key = "metadata.ConstantInfoNode";

  const char* get_c_struct_name() const override;

  /*! \return The record's name_hint, converted to a runtime String. */
  inline ::tvm::runtime::String name_hint() const {
    return ::tvm::runtime::String(data_->name_hint);
  }

  /*! \return The record's byte_offset. */
  inline size_t byte_offset() const { return data_->byte_offset; }

  /*!
   * \brief Materialise the record's payload as an NDArray.
   * \return A default-constructed NDArray when data_len is zero; otherwise an
   *         NDArray loaded from the data_len bytes starting at data_bytes.
   */
  inline ::tvm::runtime::NDArray data() const {
    ::tvm::runtime::NDArray result;
    if (!data_->data_len) {
      // Nothing to load: hand back the default-constructed array.
      return result;
    }
    // The stream type takes a non-const buffer pointer, hence the const_cast.
    dmlc::MemoryFixedSizeStream stream(const_cast<void*>(data_->data_bytes), data_->data_len);
    result.Load(&stream);
    return result;
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoMetadataNode, MetadataBaseNode);

 protected:
  /*! \brief The wrapped C-level record. */
  const struct ::TVMConstantInfo* data_;
};

/*!
 * \brief Reference class for ConstantInfoMetadataNode.
 */
class ConstantInfoMetadata : public MetadataBase {
public:
/*!
 * \brief Construct from a C-level TVMConstantInfo record.
 * \param data Pointer to the ::TVMConstantInfo to wrap.
 */
explicit ConstantInfoMetadata(const struct ::TVMConstantInfo* data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantInfoMetadata, MetadataBase,
ConstantInfoMetadataNode);
};

} // namespace metadata
} // namespace runtime
} // namespace tvm
Expand Down
Loading