From cc6481d3ae6719a63994a656afd8905421cf9bb7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 10 Jun 2022 22:02:26 -0700 Subject: [PATCH] [MetaSchedule] JSONDatabase Utilities --- python/tvm/meta_schedule/utils.py | 28 +- src/meta_schedule/arg_info.cc | 2 +- src/meta_schedule/database/database.cc | 2 +- src/meta_schedule/database/database_utils.cc | 377 ++++++++++++++++++ src/meta_schedule/database/json_database.cc | 80 +++- src/meta_schedule/utils.h | 103 +++-- .../unittest/test_meta_schedule_database.py | 68 ++-- 7 files changed, 526 insertions(+), 134 deletions(-) create mode 100644 src/meta_schedule/database/database_utils.cc diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 919a29e6cf6c..26bf20670955 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -16,12 +16,11 @@ # under the License. """Utilities for meta schedule""" import ctypes -import json import logging import os import shutil from contextlib import contextmanager -from typing import Any, List, Dict, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import psutil # type: ignore from tvm._ffi import get_global_func, register_func @@ -296,31 +295,6 @@ def _json_de_tvm(obj: Any) -> Any: raise TypeError("Not supported type: " + str(type(obj))) -@register_func("meta_schedule.json_obj2str") -def json_obj2str(json_obj: Any) -> str: - json_obj = _json_de_tvm(json_obj) - return json.dumps(json_obj) - - -@register_func("meta_schedule.batch_json_str2obj") -def batch_json_str2obj(json_strs: List[str]) -> List[Any]: - """Covert a list of JSON strings to a list of json objects. - Parameters - ---------- - json_strs : List[str] - The list of JSON strings - Returns - ------- - result : List[Any] - The list of json objects - """ - return [ - json.loads(json_str) - for json_str in map(str.strip, json_strs) - if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//")) - ] - - def shash2hex(mod: IRModule) -> str: """Get the structural hash of a module. diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 104662b6aad0..9b225e8bea99 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -88,7 +88,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { dtype = runtime::String2DLDataType(dtype_str); } // Load json[2] => shape - shape = Downcast>(json_array->at(2)); + shape = AsIntArray(json_array->at(2)); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj << "\nThe error is: " << e.what(); diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 86d999e4fdf5..9905ff73c792 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -115,7 +115,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w CHECK(json_array && json_array->size() == 4); // Load json[1] => run_secs if (json_array->at(1).defined()) { - run_secs = Downcast>(json_array->at(1)); + run_secs = AsFloatArray(json_array->at(1)); } // Load json[2] => target if (json_array->at(2).defined()) { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc new file mode 100644 index 000000000000..278c5267ea93 --- /dev/null +++ b/src/meta_schedule/database/database_utils.cc @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "../../support/str_escape.h" +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { + if (!json_obj.defined()) { + os << "null"; + } else if (const auto* int_imm = json_obj.as()) { + if (int_imm->dtype == DataType::Bool()) { + if (int_imm->value) { + os << "true"; + } else { + os << "false"; + } + } else { + os << int_imm->value; + } + } else if (const auto* float_imm = json_obj.as()) { + os << std::setprecision(20) << float_imm->value; + } else if (const auto* str = json_obj.as()) { + os << '"' << support::StrEscape(str->data, str->size) << '"'; + } else if (const auto* array = json_obj.as()) { + os << "["; + int n = array->size(); + for (int i = 0; i < n; ++i) { + if (i != 0) { + os << ","; + } + JSONDumps(array->at(i), os); + } + os << "]"; + } else if (const auto* dict = json_obj.as()) { + int n = dict->size(); + std::vector> key_values; + key_values.reserve(n); + for (const auto& kv : *dict) { + if (const auto* k = kv.first.as()) { + key_values.emplace_back(GetRef(k), kv.second); + } else { + LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " + << kv.first->GetTypeKey(); + } + } + std::sort(key_values.begin(), key_values.end()); + os << "{"; + for (int i = 0; i < n; ++i) { + const auto& kv = key_values[i]; + if (i != 0) { + os << ","; + } + os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"'; + os << ":"; + JSONDumps(kv.second, os); + } + os << "}"; + } else { + LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey(); + } +} + +std::string JSONDumps(ObjectRef json_obj) { + std::ostringstream os; + JSONDumps(json_obj, os); + return os.str(); +} + +class JSONTokenizer { + public: + enum class TokenType : int32_t { + kEOF = 0, // end of file + kNull = 1, // null + kTrue = 2, // true + kFalse = 3, // false + kLeftSquare = 4, // [ + kRightSquare = 5, // ] + kLeftCurly = 6, // { + kRightCurly = 7, // } + kComma = 8, // , + kColon = 9, // : + kInteger = 10, // integers + kFloat = 11, // floating point numbers + kString = 12, // string + }; + + struct Token { + TokenType type; + ObjectRef value{nullptr}; + }; + + explicit JSONTokenizer(const char* st, const char* ed) : cur_(st), end_(ed) {} + + Token Next() { + for (; cur_ != end_ && std::isspace(*cur_); ++cur_) { + } + if (cur_ == end_) return Token{TokenType::kEOF}; + if (NextLeftSquare()) return Token{TokenType::kLeftSquare}; + if (NextRightSquare()) return Token{TokenType::kRightSquare}; + if (NextLeftCurly()) return Token{TokenType::kLeftCurly}; + if (NextRightCurly()) return Token{TokenType::kRightCurly}; + if (NextComma()) return Token{TokenType::kComma}; + if (NextColon()) return Token{TokenType::kColon}; + if (NextNull()) return Token{TokenType::kNull}; + if (NextTrue()) return Token{TokenType::kTrue}; + if (NextFalse()) return Token{TokenType::kFalse}; + Token token; + if (NextString(&token)) return token; + if (NextNumber(&token)) return token; + LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_); + throw; + } + + private: + bool NextLeftSquare() { return NextLiteral('['); } + bool NextRightSquare() { return NextLiteral(']'); } + bool NextLeftCurly() { return NextLiteral('{'); } + bool NextRightCurly() { return NextLiteral('}'); } + bool NextComma() { return NextLiteral(','); } + bool NextColon() { return NextLiteral(':'); } + bool NextNull() { return NextLiteral("null", 4); } + bool NextTrue() { return NextLiteral("true", 4); } + bool NextFalse() { return NextLiteral("false", 5); } + + bool NextNumber(Token* token) { + using runtime::DataType; + bool is_float = false; + const char* st = cur_; + for (; cur_ != end_; ++cur_) { + if (std::isdigit(*cur_) || *cur_ == '+' || *cur_ == '-') { + continue; + } else if (*cur_ == '.' || *cur_ == 'e' || *cur_ == 'E') { + is_float = true; + } else { + break; + } + } + if (st == cur_) { + return false; + } + // TODO(@junrushao1994): error checking + if (is_float) { + *token = Token{TokenType::kFloat, + FloatImm(DataType::Float(64), // + std::stod(std::string(st, cur_)))}; + } else { + *token = Token{TokenType::kInteger, // + Integer(std::stoi(std::string(st, cur_)))}; + } + return true; + } + + bool NextString(Token* token) { + if (cur_ == end_ || *cur_ != '"') return false; + ++cur_; + std::string str; + for (; cur_ != end_ && *cur_ != '\"'; ++cur_) { + if (*cur_ != '\\') { + str.push_back(*cur_); + continue; + } + ++cur_; + if (cur_ == end_) { + LOG(FATAL) << "ValueError: Unexpected end of string: \\"; + throw; + } + switch (*cur_) { + case '\"': + str.push_back('\"'); + break; + case '\\': + str.push_back('\\'); + break; + case '/': + str.push_back('/'); + break; + case 'b': + str.push_back('\b'); + break; + case 'f': + str.push_back('\f'); + break; + case 'n': + str.push_back('\n'); + break; + case 'r': + str.push_back('\r'); + break; + case 't': + str.push_back('\t'); + break; + default: + LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_; + } + } + if (cur_ == end_) { + LOG(FATAL) << "ValueError: Unexpected end of string"; + } + ++cur_; + *token = Token{TokenType::kString, String(str)}; + return true; + } + + bool NextLiteral(char c) { + if (cur_ != end_ && *cur_ == c) { + ++cur_; + return true; + } + return false; + } + + bool NextLiteral(const char* str, int len) { + if (cur_ + len <= end_ && std::strncmp(cur_, str, len) == 0) { + cur_ += len; + return true; + } + return false; + } + /*! \brief The current pointer */ + const char* cur_; + /*! \brief End of the string */ + const char* end_; + + friend class JSONParser; +}; + +class JSONParser { + public: + using TokenType = JSONTokenizer::TokenType; + using Token = JSONTokenizer::Token; + + explicit JSONParser(const char* st, const char* ed) : tokenizer_(st, ed) {} + + ObjectRef Get() { + Token token = tokenizer_.Next(); + if (token.type == TokenType::kEOF) { + return ObjectRef(nullptr); + } + return ParseObject(std::move(token)); + } + + private: + ObjectRef ParseObject(Token token) { + switch (token.type) { + case TokenType::kNull: + return ObjectRef(nullptr); + case TokenType::kTrue: + return Bool(true); + case TokenType::kFalse: + return Bool(false); + case TokenType::kLeftSquare: + return ParseArray(); + case TokenType::kLeftCurly: + return ParseDict(); + case TokenType::kString: + case TokenType::kInteger: + case TokenType::kFloat: + return token.value; + case TokenType::kRightSquare: + LOG(FATAL) << "ValueError: Unexpected token: ]"; + case TokenType::kRightCurly: + LOG(FATAL) << "ValueError: Unexpected token: }"; + case TokenType::kComma: + LOG(FATAL) << "ValueError: Unexpected token: ,"; + case TokenType::kColon: + LOG(FATAL) << "ValueError: Unexpected token: :"; + case TokenType::kEOF: + LOG(FATAL) << "ValueError: Unexpected EOF"; + default: + throw; + } + } + + Array ParseArray() { + bool is_first = true; + Array results; + for (;;) { + Token token; + if (is_first) { + is_first = false; + token = Token{TokenType::kComma}; + } else { + token = tokenizer_.Next(); + } + // Three cases overall: + // - Case 1. 1 token: "]" + // - Case 2. 2 tokens: ",", "]" + // - Case 3. 2 tokens: ",", "obj" + if (token.type == TokenType::kRightSquare) { // Case 1 + break; + } else if (token.type == TokenType::kComma) { + token = tokenizer_.Next(); + if (token.type == TokenType::kRightSquare) { // Case 2 + break; + } + // Case 3 + results.push_back(ParseObject(std::move(token))); + continue; + } else { + LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; + } + } + return results; + } + + Map ParseDict() { + bool is_first = true; + Map results; + for (;;) { + Token token; + if (is_first) { + is_first = false; + token = Token{TokenType::kComma}; + } else { + token = tokenizer_.Next(); + } + // Three cases overall: + // - Case 1. 1 token: "}" + // - Case 2. 2 tokens: ",", "}" + // - Case 3. 2 tokens: ",", "key", ":", "value" + if (token.type == TokenType::kRightCurly) { // Case 1 + break; + } else if (token.type == TokenType::kComma) { + token = tokenizer_.Next(); + if (token.type == TokenType::kRightCurly) { // Case 2 + break; + } + // Case 3 + ObjectRef key = ParseObject(std::move(token)); + ICHECK(key->IsInstance()) + << "ValueError: key must be a string, but gets: " << key; + token = tokenizer_.Next(); + CHECK(token.type == TokenType::kColon) + << "ValueError: Unexpected token before: " << tokenizer_.cur_; + ObjectRef value = ParseObject(tokenizer_.Next()); + results.Set(Downcast(key), value); + continue; + } else { + LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; + } + } + return results; + } + + JSONTokenizer tokenizer_; +}; + +ObjectRef JSONLoads(std::string str) { + const char* st = str.c_str(); + const char* ed = st + str.length(); + return JSONParser(st, ed).Get(); +} + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 155d223217da..4f5bd9b13613 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include "../utils.h" @@ -46,6 +47,45 @@ struct SortTuningRecordByMeanRunSecs { } }; +/*! + * \brief Read lines from a json file. + * \param path The path to the json file. + * \param num_lines The number of threads used to concurrently parse the lines. + * \param allow_missing Whether to create new file when the given path is not found. + * \return An array containing lines read from the json file. + */ +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing) { + std::ifstream is(path); + if (is.good()) { + std::vector json_strs; + for (std::string str; std::getline(is, str);) { + json_strs.push_back(str); + } + int n = json_strs.size(); + std::vector json_objs; + json_objs.resize(n); + support::parallel_for_dynamic(0, n, num_threads, [&](int thread_id, int task_id) { + json_objs[task_id] = JSONLoads(json_strs[task_id]); + }); + return json_objs; + } + CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; + std::ofstream os(path); + CHECK(os.good()) << "ValueError: Cannot create new file: " << path; + return {}; +} + +/*! + * \brief Append a line to a json file. + * \param path The path to the json file. + * \param line The line to append. + */ +void JSONFileAppendLine(const String& path, const std::string& line) { + std::ofstream os(path, std::ofstream::app); + CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; + os << line << std::endl; +} + /*! \brief The default database implementation, which mimics two database tables with two files. */ class JSONDatabaseNode : public DatabaseNode { public: @@ -83,7 +123,7 @@ class JSONDatabaseNode : public DatabaseNode { // If `mod` is new in `workloads2idx_`, append it to the workload file if (inserted) { it->second = static_cast(this->workloads2idx_.size()) - 1; - JSONFileAppendLine(this->path_workload, JSONObj2Str(workload->AsJSON())); + JSONFileAppendLine(this->path_workload, JSONDumps(workload->AsJSON())); } return it->first; } @@ -91,7 +131,7 @@ class JSONDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) { this->tuning_records_.insert(record); JSONFileAppendLine(this->path_tuning_record, - JSONObj2Str(Array{ + JSONDumps(Array{ /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), /*tuning_record=*/record->AsJSON() // })); @@ -121,11 +161,12 @@ class JSONDatabaseNode : public DatabaseNode { Database Database::JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing) { + int num_threads = std::thread::hardware_concurrency(); ObjectPtr n = make_object(); // Load `n->workloads2idx_` from `path_workload` std::vector workloads; { - Array json_objs = JSONStr2Obj(JSONFileReadLines(path_workload, allow_missing)); + std::vector json_objs = JSONFileReadLines(path_workload, num_threads, allow_missing); int n_objs = json_objs.size(); n->workloads2idx_.reserve(n_objs); workloads.reserve(n_objs); @@ -137,20 +178,25 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, } // Load `n->tuning_records_` from `path_tuning_record` { - Array json_objs = JSONStr2Obj(JSONFileReadLines(path_tuning_record, allow_missing)); - for (const ObjectRef& json_obj : json_objs) { - int workload_index = -1; - ObjectRef tuning_record{nullptr}; - try { - const ArrayNode* arr = json_obj.as(); - ICHECK_EQ(arr->size(), 2); - workload_index = Downcast(arr->at(0)); - tuning_record = arr->at(1); - } catch (std::runtime_error& e) { - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); - } - n->tuning_records_.insert(TuningRecord::FromJSON(tuning_record, workloads[workload_index])); + std::vector json_objs = + JSONFileReadLines(path_tuning_record, num_threads, allow_missing); + std::vector records; + records.resize(json_objs.size(), TuningRecord{nullptr}); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 2); + records[task_id] = TuningRecord::FromJSON(arr->at(1), // + workloads[Downcast(arr->at(0))]); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + for (const TuningRecord& record : records) { + n->tuning_records_.insert(record); } } n->path_workload = path_workload; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index be7745f23d2c..40c301c6174f 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -107,38 +107,6 @@ class PyLogMessage { /*! \brief The type of the random state */ using TRandState = support::LinearCongruentialEngine::TRandState; -/*! - * \brief Read lines from a json file. - * \param path The path to the json file. - * \param allow_missing Whether to create new file when the given path is not found. - * \return An array containing lines read from the json file. - */ -inline Array JSONFileReadLines(const String& path, bool allow_missing) { - std::ifstream is(path); - if (is.good()) { - Array results; - for (std::string str; std::getline(is, str);) { - results.push_back(str); - } - return results; - } - CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; - std::ofstream os(path); - CHECK(os.good()) << "ValueError: Cannot create new file: " << path; - return {}; -} - -/*! - * \brief Append a line to a json file. - * \param path The path to the json file. - * \param line The line to append. - */ -inline void JSONFileAppendLine(const String& path, const std::string& line) { - std::ofstream os(path, std::ofstream::app); - CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; - os << line << std::endl; -} - /*! * \brief Get the base64 encoded result of a string. * \param str The string to encode. @@ -168,31 +136,18 @@ inline std::string Base64Decode(std::string str) { } /*! - * \brief Parse lines of json string into a json object. - * \param lines The lines of json string. - * \return Array of json objects parsed. - * \note The function calls the python-side json parser in runtime registry. + * \brief Parses a json string into a json object. + * \param json_str The json string. + * \return The json object */ -inline Array JSONStr2Obj(const Array& lines) { - static const runtime::PackedFunc* f_to_obj = - runtime::Registry::Get("meta_schedule.batch_json_str2obj"); - ICHECK(f_to_obj) << "IndexError: Cannot find the packed function " - "`meta_schedule.batch_json_str2obj` in the global registry"; - return (*f_to_obj)(lines); -} +ObjectRef JSONLoads(std::string json_str); /*! - * \brief Serialize a json object into a json string. - * \param json_obj The json object to serialize. - * \return A string containing the serialized json object. - * \note The function calls the python-side json obj serializer in runtime registry. + * \brief Dumps a json object into a json string. + * \param json_obj The json object. + * \return The json string */ -inline String JSONObj2Str(const ObjectRef& json_obj) { - static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("meta_schedule.json_obj2str"); - ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`meta_schedule.json_obj2str` in the global registry"; - return (*f_to_str)(json_obj); -} +std::string JSONDumps(ObjectRef json_obj); /*! * \brief Converts a structural hash code to string @@ -447,6 +402,48 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { } } +/*! + * \brief Convert the given object to an array of floating point numbers + * \param obj The object to be converted + * \return The array of floating point numbers + */ +inline Array AsFloatArray(const ObjectRef& obj) { + const ArrayNode* arr = obj.as(); + ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); + Array results; + results.reserve(arr->size()); + for (const ObjectRef& elem : *arr) { + if (const auto* int_imm = elem.as()) { + results.push_back(FloatImm(DataType::Float(32), int_imm->value)); + } else if (const auto* float_imm = elem.as()) { + results.push_back(FloatImm(DataType::Float(32), float_imm->value)); + } else { + LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); + } + } + return results; +} + +/*! + * \brief Convert the given object to an array of integers + * \param obj The object to be converted + * \return The array of integers + */ +inline Array AsIntArray(const ObjectRef& obj) { + const ArrayNode* arr = obj.as(); + ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); + Array results; + results.reserve(arr->size()); + for (const ObjectRef& elem : *arr) { + if (const auto* int_imm = elem.as()) { + results.push_back(Integer(int_imm->value)); + } else { + LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); + } + } + return results; +} + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index 1edfbe6c7a78..ff0f350d8914 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -17,20 +17,18 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring """Test Meta Schedule Database""" import os.path as osp -import sys import tempfile from typing import Callable -import pytest import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm import tir from tvm.ir.module import IRModule -from tvm.meta_schedule.arg_info import ArgInfo -from tvm.meta_schedule.database import JSONDatabase, TuningRecord from tvm.script import tir as T from tvm.tir import Schedule + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @tvm.script.ir_module @@ -92,13 +90,13 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched return sch -def _create_tmp_database(tmpdir: str) -> JSONDatabase: +def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase: path_workload = osp.join(tmpdir, "workloads.json") path_tuning_record = osp.join(tmpdir, "tuning_records.json") - return JSONDatabase(path_workload, path_tuning_record) + return ms.database.JSONDatabase(path_workload, path_tuning_record) -def _equal_record(a: TuningRecord, b: TuningRecord): +def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): assert str(a.trace) == str(b.trace) assert str(a.run_secs) == str(b.run_secs) # AWAIT(@zxybazh): change to export after fixing "(bool)0" @@ -113,15 +111,15 @@ def test_meta_schedule_tuning_record_round_trip(): with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) - record = TuningRecord( + record = ms.database.TuningRecord( _create_schedule(mod, _schedule_matmul).trace, workload, [1.5, 2.5, 1.8], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ) database.commit_tuning_record(record) - new_record = TuningRecord.from_json(record.as_json(), workload) + new_record = ms.database.TuningRecord.from_json(record.as_json(), workload) _equal_record(record, new_record) @@ -138,12 +136,12 @@ def test_meta_schedule_database_has_workload(): with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) - record = TuningRecord( + record = ms.database.TuningRecord( _create_schedule(mod, _schedule_matmul).trace, workload, [1.5, 2.5, 1.8], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ) database.commit_tuning_record(record) assert len(database) == 1 @@ -156,12 +154,12 @@ def test_meta_schedule_database_add_entry(): with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) - record = TuningRecord( + record = ms.database.TuningRecord( _create_schedule(mod, _schedule_matmul).trace, workload, [1.5, 2.5, 1.8], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ) database.commit_tuning_record(record) assert len(database) == 1 @@ -176,12 +174,12 @@ def test_meta_schedule_database_missing(): database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) workload_2 = database.commit_workload(mod_2) - record = TuningRecord( + record = ms.database.TuningRecord( _create_schedule(mod, _schedule_matmul).trace, workload, [1.5, 2.5, 1.8], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ) database.commit_tuning_record(record) ret = database.get_top_k(workload_2, 3) @@ -195,47 +193,47 @@ def test_meta_schedule_database_sorting(): token = database.commit_workload(mod) trace = _create_schedule(mod, _schedule_matmul).trace records = [ - TuningRecord( + ms.database.TuningRecord( trace, token, [7.0, 8.0, 9.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [1.0, 2.0, 3.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [4.0, 5.0, 6.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [1.1, 1.2, 600.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [1.0, 100.0, 6.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [4.0, 9.0, 8.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), ] for record in records: @@ -257,31 +255,31 @@ def test_meta_schedule_database_reload(): token = database.commit_workload(mod) trace = _create_schedule(mod, _schedule_matmul).trace records = [ - TuningRecord( + ms.database.TuningRecord( trace, token, [7.0, 8.0, 9.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [1.0, 2.0, 3.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), - TuningRecord( + ms.database.TuningRecord( trace, token, [4.0, 5.0, 6.0], tvm.target.Target("llvm"), - ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), ), ] for record in records: database.commit_tuning_record(record) - new_database = JSONDatabase( # pylint: disable=unused-variable + new_database = ms.database.JSONDatabase( path_workload=database.path_workload, path_tuning_record=database.path_tuning_record, )