Skip to content

Commit ee298db

Browse files
manupakdchauhan-arm
authored andcommitted
[TIR][USMP] Added buffer info extraction pass (apache#8468)
* [TIR][USMP] Added buffer info extraction pass This commit adds a pass that takes the main (call graph of operators) TIR PrimFunc and each operators also as TIR PrimFunc. The pass will traverse through all TIR PrimFunc starting the from main. Thereafter, it will extract information from tir.allocates. Among the information, the liveness conflicts are reported. * Added test for a linear model * Added test for parallel/serial mixed for loops * Added test for a substructure of inception-style model. * Exposed buffer_info creation to python * Added member functions to update pool info * Unit tests to cover functionality of buffer_info Change-Id: I5e163ac3e83c830629a5d34ed4407c9962701c60 * [TIR][USMP] Added buffer info extraction pass Swap key-value pairs of returned values of the buffer_info extraction pass. Change-Id: Ia4f7289592bc776ef6189a41a7891038751bf31f * [TIR][USMP] Added buffer info extraction pass Updating the USMP utility tests to include tests that test creation of PoolInfo and PoolAllocation Objects. Change-Id: I5d349d0ffcac6b0160072d832dd9d5418699228e * [TIR][USMP] Added buffer info extraction pass * Removing the unnecessary header : include/tvm/tir/usmp/analysis.h * Some nits and cleanup Change-Id: Iac3ddd9428c56cd8ef49cf643e797bf6fdf4e97a * [TIR][USMP] Added buffer info extraction pass * Change the class data members to have a trailing underscore Change-Id: I71809b3c73b0bc0cd133fad1392ae8c17c895ee4 * [TIR][USMP] Added buffer info extraction pass Adding more documentation for data structures and the approach Change-Id: Ide2bfffaeff9add86853b6992017264e5d796299 * [TIR][USMP] Added buffer info extraction pass * Added more documentation * Added functionality to handle multiple calls for the same PrimFunc with a test. Change-Id: Ib7c27b3cf17f415067a224f1e57d8b928f4c7c6f * [TIR][USMP] Added buffer info extraction pass * Attaching targets to PrimFuncs in the util test case Change-Id: I82960512659a346f6242b2b5789ec1120f8ea2cf
1 parent 4922a6c commit ee298db

File tree

14 files changed

+2839
-0
lines changed

14 files changed

+2839
-0
lines changed

include/tvm/tir/usmp/utils.h

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tir/usmp/utils.h
22+
* \brief Utilities for Unified Static Memory Planner
23+
*/
24+
25+
#ifndef TVM_TIR_USMP_UTILS_H_
26+
#define TVM_TIR_USMP_UTILS_H_
27+
28+
#include <tvm/ir/expr.h>
29+
#include <tvm/target/target.h>
30+
#include <tvm/tir/stmt.h>
31+
32+
namespace tvm {
33+
namespace tir {
34+
namespace usmp {
35+
36+
/*!
37+
* \brief The string parameter to indicate read and write access to a pool
38+
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
39+
* python/tvm/tir/usmp/utils.py
40+
*/
41+
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
42+
/*!
43+
* \brief The string parameter to indicate read only access to a pool
44+
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
45+
* python/tvm/tir/usmp/utils.py
46+
*/
47+
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
48+
49+
/*!
50+
* \brief Describes a pool of memory accessible by one or more targets.
51+
*/
52+
struct PoolInfoNode : public Object {
53+
/*! \brief The name of the memory pool */
54+
String pool_name;
55+
/*! \brief The expected size hint to be used by the allocator.
56+
* The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint
57+
* to indicate the pool is not size restricted.
58+
*/
59+
Integer size_hint_bytes;
60+
/*! \brief The accessibility from each Target*/
61+
Map<Target, String> target_access; // 'rw' or 'ro'
62+
63+
void VisitAttrs(tvm::AttrVisitor* v) {
64+
v->Visit("pool_name", &pool_name);
65+
v->Visit("size_hint_bytes", &size_hint_bytes);
66+
v->Visit("target_access", &target_access);
67+
}
68+
69+
bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
70+
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
71+
equal(target_access, other->target_access);
72+
}
73+
74+
void SHashReduce(SHashReducer hash_reduce) const {
75+
hash_reduce(pool_name);
76+
hash_reduce(size_hint_bytes);
77+
hash_reduce(target_access);
78+
}
79+
80+
static constexpr const char* _type_key = "tir.usmp.PoolInfo";
81+
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
82+
};
83+
84+
/*!
85+
* \brief The PoolSize is unrestricted for the memory planner
86+
*/
87+
static const int kUnrestrictedPoolSizeHint = -1;
88+
89+
class PoolInfo : public ObjectRef {
90+
public:
91+
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
92+
Integer size_hint_bytes = kUnrestrictedPoolSizeHint);
93+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
94+
};
95+
96+
/*!
97+
* \brief Describes an abstract memory buffer that will get allocated inside a pool.
98+
* The actual memory buffer in represented by PoolAllocationNode after static memory planning.
99+
*
100+
* See also for relay-level counterparts:
101+
* relay::StorageToken (graph_plan_memory.cc)
102+
* relay::backend::StorageInfoNode (relay/backend/utils.h)
103+
* Region (python/tvm/relay/transform/memory_plan.py)
104+
*/
105+
struct BufferInfoNode : public Object {
106+
/*! \brief The name of the buffer var */
107+
String name_hint;
108+
/*! \brief The size in terms of bytes */
109+
Integer size_bytes;
110+
/*! \brief The pool candidates that this buffer can get pooled to*/
111+
Array<PoolInfo> pool_candidates;
112+
/*! \brief The byte alignment required for buffers that will placed within the pool */
113+
Integer alignment;
114+
/*! \brief The liveness conflicting other buffer info objects */
115+
Array<ObjectRef> conflicts;
116+
117+
void VisitAttrs(tvm::AttrVisitor* v) {
118+
v->Visit("name_hint", &name_hint);
119+
v->Visit("size_bytes", &size_bytes);
120+
v->Visit("pool_candidates", &pool_candidates);
121+
v->Visit("alignment", &alignment);
122+
v->Visit("conflicts", &conflicts);
123+
}
124+
125+
bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
126+
return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
127+
equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) &&
128+
equal(conflicts, other->conflicts);
129+
}
130+
131+
void SHashReduce(SHashReducer hash_reduce) const {
132+
hash_reduce(name_hint);
133+
hash_reduce(size_bytes);
134+
hash_reduce(alignment);
135+
hash_reduce(conflicts);
136+
hash_reduce(pool_candidates);
137+
}
138+
/*!
139+
* \brief Set the liveness conflicts of this BufferInfo
140+
*
141+
* \param conflicting_buffer_info_objs An array of BufferInfo that conflicts in liveness
142+
*/
143+
TVM_DLL void SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs);
144+
145+
static constexpr const char* _type_key = "tir.usmp.BufferInfo";
146+
TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object);
147+
};
148+
149+
class BufferInfo : public ObjectRef {
150+
public:
151+
TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
152+
Integer alignment = runtime::kDefaultWorkspaceAlignment);
153+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
154+
};
155+
156+
/*!
157+
* \brief The pool allocation produced after the USMP algorithm
158+
*/
159+
struct PoolAllocationNode : public Object {
160+
/*! \brief The assigned PoolInfo object */
161+
PoolInfo pool_info;
162+
/*! \brief The byte offset where the tensor is supposed to be placed within the pool*/
163+
Integer byte_offset;
164+
165+
void VisitAttrs(tvm::AttrVisitor* v) {
166+
v->Visit("pool_info", &pool_info);
167+
v->Visit("byte_offset", &byte_offset);
168+
}
169+
170+
bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const {
171+
return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset);
172+
}
173+
174+
void SHashReduce(SHashReducer hash_reduce) const {
175+
hash_reduce(pool_info);
176+
hash_reduce(byte_offset);
177+
}
178+
179+
static constexpr const char* _type_key = "tir.usmp.PoolAllocation";
180+
TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object);
181+
};
182+
183+
class PoolAllocation : public ObjectRef {
184+
public:
185+
TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset);
186+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
187+
};
188+
189+
/*!
190+
* \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
191+
*
192+
* \param buffer_info_map IR-bound BufferInfo map
193+
*/
194+
Array<BufferInfo> CreateArrayBufferInfo(const Map<Stmt, BufferInfo>& buffer_info_map);
195+
196+
/*!
197+
* \brief The allocate node attribute to indicate candidate memory pools.
198+
* This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in
199+
* python/tvm/tir/usmp/utils.py.
200+
*/
201+
static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
202+
203+
/*!
204+
* \brief Calculate the size of the extents in bytes
205+
*
206+
* \param op the allocate node
207+
*/
208+
Integer CalculateExtentsSize(const AllocateNode* op);
209+
210+
} // namespace usmp
211+
} // namespace tir
212+
} // namespace tvm
213+
214+
#endif // TVM_TIR_USMP_UTILS_H_

python/tvm/script/tir/scope_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self):
110110
def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
111111
condition = tvm.runtime.convert(condition)
112112
scope = tvm.runtime.convert(scope)
113+
113114
return tvm.tir.Allocate(
114115
self.buffer_var,
115116
dtype,

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@
5555
from . import transform
5656
from . import analysis
5757
from . import stmt_functor
58+
from . import usmp

python/tvm/tir/ir_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def allocate(self, dtype, shape, name="buf", scope=""):
411411
scope : str, optional
412412
The scope of the buffer.
413413
414+
414415
Returns
415416
-------
416417
buffer : BufferVar

python/tvm/tir/usmp/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-import, redefined-builtin
18+
"""Namespace for Unified Static Memory Planner"""
19+
20+
from . import analysis
21+
from .utils import BufferInfo

python/tvm/tir/usmp/_ffi_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""FFI APIs for tvm.tir.usmp"""
18+
import tvm._ffi
19+
20+
21+
tvm._ffi._init_api("tir.usmp", __name__)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-import, redefined-builtin
18+
"""Namespace for Unified Static Memory Planner"""
19+
20+
from .analysis import extract_buffer_info
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""FFI APIs for tvm.tir.usmp.analysis"""
18+
import tvm._ffi
19+
20+
21+
tvm._ffi._init_api("tir.usmp.analysis", __name__)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""USMP Analysis Python API for passes"""
18+
# pylint: disable=invalid-name
19+
from . import _ffi_api
20+
from ...function import PrimFunc
21+
from ....ir.module import IRModule
22+
23+
24+
def extract_buffer_info(main_func: PrimFunc, mod: IRModule):
25+
"""Convert Parallel For Loop to Serial.
26+
27+
Parameters
28+
----------
29+
main_func: tvm.tir.PrimFunc
30+
The main function containing calls to operator PrimFuncs.
31+
mod : tvm.ir.IRModule
32+
The full IRModule containing all PrimFuncs
33+
34+
Returns
35+
-------
36+
Map<tir::Stmt, BufferInfo>
37+
extracted buffer info objects
38+
"""
39+
return _ffi_api.extract_buffer_info(main_func, mod)

0 commit comments

Comments
 (0)