Skip to content

Commit 321e30f

Browse files
committed
[TIR][USMP] further changes for extract buffer info
* moved the comparison to a lambda * lint fixes Change-Id: If917a3ec12d2a5689eb584e0ac5918a39f9ac12e
1 parent 0271916 commit 321e30f

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

python/tvm/tir/ir_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,11 @@ def allocate(self, dtype, shape, name="buf", scope="", pinned_memory=""):
422422
buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
423423
if not isinstance(shape, (list, tuple, _container.Array)):
424424
shape = [shape]
425-
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory))
425+
self.emit(
426+
lambda x: _stmt.Allocate(
427+
buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory
428+
)
429+
)
426430
return BufferVar(self, buffer_var, shape, dtype)
427431

428432
def pointer(self, content_type, name="ptr", scope=""):

src/tir/usmp/analysis/extract_buffer_info.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,6 @@ Map<tir::Stmt, BufferInfo> BufferInfoExtractor::operator()(const PrimFunc& main_
212212
size_t tick;
213213
LivenessEventType le_type;
214214
Allocate allocate;
215-
bool operator<(const LivenessEvent& other) {
216-
if (tick < other.tick) {
217-
return true;
218-
} else if (tick == other.tick && le_type == START && other.le_type == END) {
219-
return true;
220-
}
221-
return false;
222-
}
223-
224215
bool operator==(const LivenessEvent& other) {
225216
if (tick == other.tick && le_type == other.le_type && allocate == other.allocate) {
226217
return true;
@@ -252,7 +243,15 @@ Map<tir::Stmt, BufferInfo> BufferInfoExtractor::operator()(const PrimFunc& main_
252243
le_events.push_back(le_event_end);
253244
}
254245

255-
std::sort(le_events.begin(), le_events.end());
246+
std::sort(le_events.begin(), le_events.end(),
247+
[](const LivenessEvent& lhs, const LivenessEvent& rhs) {
248+
if (lhs.tick < rhs.tick) {
249+
return true;
250+
} else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) {
251+
return true;
252+
}
253+
return false;
254+
});
256255
std::unordered_set<Allocate, ObjectPtrHash, ObjectPtrEqual> open_set;
257256
for (const auto& le_event : le_events) {
258257
if (le_event.le_type == START) {

0 commit comments

Comments
 (0)