Skip to content

Commit 0df6961

Browse files
authored
[MetaSchedule] JSONDatabase Utilities (#11680)
This PR adds some utility to JSONDatabase to accelerate its loading/saving time.
1 parent d0da0b9 commit 0df6961

File tree

7 files changed

+526
-134
lines changed

7 files changed

+526
-134
lines changed

python/tvm/meta_schedule/utils.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
# under the License.
1717
"""Utilities for meta schedule"""
1818
import ctypes
19-
import json
2019
import logging
2120
import os
2221
import shutil
2322
from contextlib import contextmanager
24-
from typing import Any, List, Dict, Callable, Optional, Union
23+
from typing import Any, Callable, Dict, List, Optional, Union
2524

2625
import psutil # type: ignore
2726
from tvm._ffi import get_global_func, register_func
@@ -296,31 +295,6 @@ def _json_de_tvm(obj: Any) -> Any:
296295
raise TypeError("Not supported type: " + str(type(obj)))
297296

298297

299-
@register_func("meta_schedule.json_obj2str")
300-
def json_obj2str(json_obj: Any) -> str:
301-
json_obj = _json_de_tvm(json_obj)
302-
return json.dumps(json_obj)
303-
304-
305-
@register_func("meta_schedule.batch_json_str2obj")
306-
def batch_json_str2obj(json_strs: List[str]) -> List[Any]:
307-
"""Covert a list of JSON strings to a list of json objects.
308-
Parameters
309-
----------
310-
json_strs : List[str]
311-
The list of JSON strings
312-
Returns
313-
-------
314-
result : List[Any]
315-
The list of json objects
316-
"""
317-
return [
318-
json.loads(json_str)
319-
for json_str in map(str.strip, json_strs)
320-
if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//"))
321-
]
322-
323-
324298
def shash2hex(mod: IRModule) -> str:
325299
"""Get the structural hash of a module.
326300

src/meta_schedule/arg_info.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
8888
dtype = runtime::String2DLDataType(dtype_str);
8989
}
9090
// Load json[2] => shape
91-
shape = Downcast<Array<Integer>>(json_array->at(2));
91+
shape = AsIntArray(json_array->at(2));
9292
} catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
9393
LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
9494
<< "\nThe error is: " << e.what();

src/meta_schedule/database/database.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
115115
CHECK(json_array && json_array->size() == 4);
116116
// Load json[1] => run_secs
117117
if (json_array->at(1).defined()) {
118-
run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
118+
run_secs = AsFloatArray(json_array->at(1));
119119
}
120120
// Load json[2] => target
121121
if (json_array->at(2).defined()) {

0 commit comments

Comments
 (0)