Skip to content

Commit e79fac6

Browse files
authored
[Metaschedule] get_top_k should not return not built records (#13824)
* [Metaschedule] get_top_k should not return not built records * [Metaschedule][NFC] GetTopK extra polishing
1 parent 1bc6dd4 commit e79fac6

File tree

3 files changed

+38
-31
lines changed

3 files changed

+38
-31
lines changed

src/meta_schedule/database/json_database.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,14 @@ class JSONDatabaseNode : public DatabaseNode {
127127
Array<TuningRecord> results;
128128
results.reserve(top_k);
129129
for (const TuningRecord& record : this->tuning_records_) {
130-
if (!record->run_secs.defined() || record->run_secs.value().empty()) {
130+
auto run_secs = record->run_secs;
131+
if (!run_secs.defined() || run_secs.value().empty() ||
132+
std::all_of(run_secs.value().begin(), run_secs.value().end(),
133+
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
134+
[](tvm::FloatImm v) {
135+
return v.defined() &&
136+
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
137+
})) {
131138
continue;
132139
}
133140
if (record->workload.same_as(workload) ||

src/meta_schedule/database/memory_database.cc

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,42 +65,34 @@ class MemoryDatabaseNode : public DatabaseNode {
6565
if (top_k == 0) {
6666
return {};
6767
}
68-
std::vector<std::pair<double, TuningRecord>> results;
68+
std::vector<TuningRecord> results;
6969
results.reserve(records.size());
7070
for (const TuningRecord& record : records) {
71-
if (!record->run_secs.defined()) {
72-
continue;
73-
}
74-
Array<FloatImm> run_secs = record->run_secs.value();
75-
if (run_secs.empty()) {
71+
auto run_secs = record->run_secs;
72+
if (!run_secs.defined() || run_secs.value().empty() ||
73+
std::all_of(run_secs.value().begin(), run_secs.value().end(),
74+
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
75+
[](tvm::FloatImm v) {
76+
return v.defined() &&
77+
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
78+
})) {
7679
continue;
7780
}
7881
if (record->workload.same_as(workload) ||
7982
WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
80-
double sum = 0.0;
81-
for (const FloatImm& i : run_secs) {
82-
sum += i->value;
83-
}
84-
results.emplace_back(sum / run_secs.size(), record);
83+
results.emplace_back(record);
8584
}
8685
}
87-
std::sort(results.begin(), results.end());
88-
auto begin = results.begin();
89-
auto end = results.end();
86+
std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
9087
if (results.size() > static_cast<size_t>(top_k)) {
91-
end = begin + top_k;
92-
}
93-
Array<TuningRecord> ret;
94-
ret.reserve(end - begin);
95-
while (begin != end) {
96-
ret.push_back(begin->second);
97-
++begin;
98-
}
99-
if (ret.size() < static_cast<size_t>(top_k)) {
100-
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
101-
"enough valid records in the database for this workload.";
88+
return {results.begin(), results.end() + top_k};
89+
} else {
90+
if (results.size() < static_cast<size_t>(top_k)) {
91+
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
92+
"enough valid records in the database for this workload.";
93+
}
94+
return results;
10295
}
103-
return ret;
10496
}
10597

10698
Array<TuningRecord> GetAllTuningRecords() final { return records; }

tests/python/unittest/test_meta_schedule_database.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,21 +554,29 @@ def call_get_top_k(run_secs_list, database, k):
554554

555555
@pytest.mark.parametrize(
556556
"k,expected",
557-
[(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
557+
[
558+
(0, []),
559+
(4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
560+
(5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
561+
],
558562
)
559563
def test_memory_database_get_top_k(k, expected):
560-
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
564+
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
561565
database = ms.database.MemoryDatabase()
562566
result = call_get_top_k(run_secs_list, database, k)
563567
assert result == expected
564568

565569

566570
@pytest.mark.parametrize(
567571
"k,expected",
568-
[(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
572+
[
573+
(0, []),
574+
(4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
575+
(5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]),
576+
],
569577
)
570578
def test_json_database_get_top_k(k, expected):
571-
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
579+
run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]]
572580
with tempfile.TemporaryDirectory() as tmpdir:
573581
database = _create_tmp_database(tmpdir)
574582
result = call_get_top_k(run_secs_list, database, k)

0 commit comments

Comments
 (0)