Skip to content

Commit 55849e6

Browse files
authored
[USMP] adding support for U2 and U3 usecases (#10193)
This commit adds a MemoryPools argument for the compilation flow according to RFC0029. Moreover, it is used to provide support for external pools from the application layer that could be pinned for different memories and/or be reused between multiple inferences of a model.
1 parent bb60ee9 commit 55849e6

File tree

28 files changed

+1041
-400
lines changed

28 files changed

+1041
-400
lines changed

include/tvm/ir/memory_pools.h

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/ir/memory_pools.h
22+
* \brief The object definition for relay.build argument type of memory pools
23+
*/
24+
#ifndef TVM_IR_MEMORY_POOLS_H_
25+
#define TVM_IR_MEMORY_POOLS_H_
26+
27+
#include <tvm/runtime/registry.h>
28+
#include <tvm/target/target.h>
29+
30+
namespace tvm {
31+
32+
/*!
33+
* \brief Describes a pool of memory accessible by one or more targets.
34+
*/
35+
struct PoolInfoNode : public Object {
36+
/*! \brief The name of the memory pool */
37+
String pool_name;
38+
/*! \brief The expected size hint to be used by the allocator.
39+
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
40+
* to indicate the pool is not size restricted.
41+
*/
42+
Integer size_hint_bytes;
43+
/*! \brief The accessibility from each Target */
44+
Map<Target, String> target_access; // 'rw' or 'ro'
45+
/*! \brief The clock frequency of the memory in Hz */
46+
Integer clock_frequency_hz;
47+
/*! \brief The read bandwidth in bytes/cycle */
48+
Integer read_bandwidth_bytes_per_cycle;
49+
/*! \brief The write bandwidth in bytes/cycle */
50+
Integer write_bandwidth_bytes_per_cycle;
51+
/*! \brief The read latency in cycles */
52+
Integer read_latency_cycles;
53+
/*! \brief The write latency in cycles */
54+
Integer write_latency_cycles;
55+
/*! \brief The burst length in bytes for each Target */
56+
Map<Target, Integer> target_burst_bytes;
57+
/*! \brief Whether pool is internally generated.
58+
* The internal pools will be generated as part of
59+
* the entry point code generation of the executor
60+
*/
61+
bool is_internal = false;
62+
63+
void VisitAttrs(tvm::AttrVisitor* v) {
64+
v->Visit("pool_name", &pool_name);
65+
v->Visit("size_hint_bytes", &size_hint_bytes);
66+
v->Visit("target_access", &target_access);
67+
v->Visit("clock_frequency_hz", &clock_frequency_hz);
68+
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
69+
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
70+
v->Visit("read_latency_cycles", &read_latency_cycles);
71+
v->Visit("write_latency_cycles", &write_latency_cycles);
72+
v->Visit("target_burst_bytes", &target_burst_bytes);
73+
v->Visit("is_internal", &is_internal);
74+
}
75+
76+
bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
77+
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
78+
equal(target_access, other->target_access) &&
79+
equal(target_access, other->target_access) &&
80+
equal(clock_frequency_hz, other->clock_frequency_hz) &&
81+
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
82+
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
83+
equal(read_latency_cycles, other->read_latency_cycles) &&
84+
equal(write_latency_cycles, other->write_latency_cycles) &&
85+
equal(target_burst_bytes, other->target_burst_bytes) &&
86+
equal(is_internal, other->is_internal);
87+
}
88+
89+
void SHashReduce(SHashReducer hash_reduce) const {
90+
hash_reduce(pool_name);
91+
hash_reduce(size_hint_bytes);
92+
hash_reduce(target_access);
93+
hash_reduce(clock_frequency_hz);
94+
hash_reduce(read_bandwidth_bytes_per_cycle);
95+
hash_reduce(write_bandwidth_bytes_per_cycle);
96+
hash_reduce(read_latency_cycles);
97+
hash_reduce(write_latency_cycles);
98+
hash_reduce(target_burst_bytes);
99+
hash_reduce(is_internal);
100+
}
101+
102+
static constexpr const char* _type_key = "ir.PoolInfo";
103+
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
104+
};
105+
106+
class PoolInfo : public ObjectRef {
107+
public:
108+
/*!
109+
* \brief The string parameter to indicate read and write access to a pool
110+
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
111+
* python/tvm/ir/memory_pools.py
112+
*/
113+
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
114+
/*!
115+
* \brief The string parameter to indicate read only access to a pool
116+
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
117+
* python/tvm/ir/memory_pools.py
118+
*/
119+
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
120+
/*! \brief The PoolSize is unrestricted for the memory planner */
121+
static const int kUnrestrictedPoolSizeHint = -1;
122+
/*! \brief The clock frequency is not known */
123+
static const int kUnknownClockFrequency = -1;
124+
/*! \brief The read bandwidth is not known */
125+
static const int kUnknownReadBandwidth = -1;
126+
/*! \brief The write bandwidth is not known */
127+
static const int kUnknownWriteBandwidth = -1;
128+
129+
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
130+
Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
131+
Integer clock_frequency_hz = kUnknownClockFrequency,
132+
Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth,
133+
Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth,
134+
Integer read_latency_cycles = 0, Integer write_latency_cycles = 0,
135+
Map<Target, Integer> target_burst_bytes = {}, Bool is_internal = Bool(false));
136+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
137+
};
138+
139+
struct WorkspaceMemoryPoolsNode : public Object {
140+
Array<PoolInfo> pools;
141+
142+
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pools", &pools); }
143+
144+
bool SEqualReduce(const WorkspaceMemoryPoolsNode* other, SEqualReducer equal) const {
145+
return equal(pools, other->pools);
146+
}
147+
148+
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pools); }
149+
150+
static constexpr const char* _type_key = "ir.WorkspaceMemoryPools";
151+
TVM_DECLARE_FINAL_OBJECT_INFO(WorkspaceMemoryPoolsNode, Object);
152+
};
153+
154+
class WorkspaceMemoryPools : public ObjectRef {
155+
public:
156+
TVM_DLL WorkspaceMemoryPools(Array<PoolInfo> pools);
157+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspaceMemoryPools, ObjectRef, WorkspaceMemoryPoolsNode);
158+
};
159+
160+
} // namespace tvm
161+
162+
#endif // TVM_IR_MEMORY_POOLS_H_

include/tvm/ir/module.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,15 @@ constexpr const char* kExecutor = "executor";
495495
*/
496496
constexpr const char* kRuntime = "runtime";
497497

498+
/*!
499+
* \brief workspace memory pools of the module
500+
*
501+
* Type: WorkspaceMemoryPools
502+
*
503+
* \sa tvm::WorkspaceMemoryPools
504+
*/
505+
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";
506+
498507
} // namespace attr
499508
} // namespace tvm
500509
#endif // TVM_IR_MODULE_H_

include/tvm/tir/usmp/utils.h

Lines changed: 1 addition & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#define TVM_TIR_USMP_UTILS_H_
2727

2828
#include <tvm/ir/expr.h>
29+
#include <tvm/ir/memory_pools.h>
2930
#include <tvm/runtime/device_api.h>
3031
#include <tvm/target/target.h>
3132
#include <tvm/tir/stmt.h>
@@ -44,111 +45,6 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
4445
namespace tir {
4546
namespace usmp {
4647

47-
/*!
48-
* \brief Describes a pool of memory accessible by one or more targets.
49-
*/
50-
struct PoolInfoNode : public Object {
51-
/*! \brief The name of the memory pool */
52-
String pool_name;
53-
/*! \brief The expected size hint to be used by the allocator.
54-
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
55-
* to indicate the pool is not size restricted.
56-
*/
57-
Integer size_hint_bytes;
58-
/*! \brief The accessibility from each Target */
59-
Map<Target, String> target_access; // 'rw' or 'ro'
60-
/*! \brief The clock frequency of the memory in Hz */
61-
Integer clock_frequency_hz;
62-
/*! \brief The read bandwidth in bytes/cycle */
63-
Integer read_bandwidth_bytes_per_cycle;
64-
/*! \brief The write bandwidth in bytes/cycle */
65-
Integer write_bandwidth_bytes_per_cycle;
66-
/*! \brief The read latency in cycles */
67-
Integer read_latency_cycles;
68-
/*! \brief The write latency in cycles */
69-
Integer write_latency_cycles;
70-
/*! \brief The burst length in bytes for each Target */
71-
Map<Target, Integer> target_burst_bytes;
72-
/*! \brief Whether pool is internally generated.
73-
* The internal pools will be generated as part of
74-
* the entry point code generation of the executor
75-
*/
76-
bool is_internal = false;
77-
78-
void VisitAttrs(tvm::AttrVisitor* v) {
79-
v->Visit("pool_name", &pool_name);
80-
v->Visit("size_hint_bytes", &size_hint_bytes);
81-
v->Visit("target_access", &target_access);
82-
v->Visit("clock_frequency_hz", &clock_frequency_hz);
83-
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
84-
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
85-
v->Visit("read_latency_cycles", &read_latency_cycles);
86-
v->Visit("write_latency_cycles", &write_latency_cycles);
87-
v->Visit("target_burst_bytes", &target_burst_bytes);
88-
v->Visit("is_internal", &is_internal);
89-
}
90-
91-
bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
92-
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
93-
equal(target_access, other->target_access) &&
94-
equal(target_access, other->target_access) &&
95-
equal(clock_frequency_hz, other->clock_frequency_hz) &&
96-
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
97-
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
98-
equal(read_latency_cycles, other->read_latency_cycles) &&
99-
equal(write_latency_cycles, other->write_latency_cycles) &&
100-
equal(target_burst_bytes, other->target_burst_bytes) &&
101-
equal(is_internal, other->is_internal);
102-
}
103-
104-
void SHashReduce(SHashReducer hash_reduce) const {
105-
hash_reduce(pool_name);
106-
hash_reduce(size_hint_bytes);
107-
hash_reduce(target_access);
108-
hash_reduce(clock_frequency_hz);
109-
hash_reduce(read_bandwidth_bytes_per_cycle);
110-
hash_reduce(write_bandwidth_bytes_per_cycle);
111-
hash_reduce(read_latency_cycles);
112-
hash_reduce(write_latency_cycles);
113-
hash_reduce(target_burst_bytes);
114-
hash_reduce(is_internal);
115-
}
116-
117-
static constexpr const char* _type_key = "tir.usmp.PoolInfo";
118-
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
119-
};
120-
121-
class PoolInfo : public ObjectRef {
122-
public:
123-
/*!
124-
* \brief The string parameter to indicate read and write access to a pool
125-
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
126-
* python/tvm/tir/usmp/utils.py
127-
*/
128-
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
129-
/*!
130-
* \brief The string parameter to indicate read only access to a pool
131-
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
132-
* python/tvm/tir/usmp/utils.py
133-
*/
134-
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
135-
/*! \brief The PoolSize is unrestricted for the memory planner */
136-
static const int kUnrestrictedPoolSizeHint = -1;
137-
/*! \brief The clock frequency is not known */
138-
static const int kUnknownClockFrequency = -1;
139-
/*! \brief The read bandwidth is not known */
140-
static const int kUnknownReadBandwidth = -1;
141-
/*! \brief The write bandwidth is not known */
142-
static const int kUnknownWriteBandwidth = -1;
143-
144-
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
145-
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
146-
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
147-
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
148-
Bool is_internal);
149-
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
150-
};
151-
15248
/*!
15349
* \brief Describes an abstract memory buffer that will get allocated inside a pool.
15450
* The actual memory buffer in represented by PoolAllocationNode after static memory planning.

python/tvm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
from .ir import transform
4444
from .ir import instrument
4545
from .ir import container
46+
from .ir import PoolInfo
47+
from .ir import WorkspaceMemoryPools
4648
from . import ir
4749

4850
# tvm.tir

python/tvm/ir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .module import IRModule
3131
from .attrs import Attrs, DictAttrs, make_node
3232
from .container import Array, Map
33+
from .memory_pools import PoolInfo, WorkspaceMemoryPools
3334

3435
from . import transform
3536
from . import instrument

0 commit comments

Comments
 (0)