diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index a93f1c66c395..076172a8b5f7 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -251,6 +251,14 @@ inline const ModuleNode* Module::operator->() const { return static_cast(get()); } +inline std::ostream& operator<<(std::ostream& out, const Module& module) { + out << "Module(type_key= "; + out << module->type_key(); + out << ")"; + + return out; +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 7292b996e4a5..46881d998404 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -53,7 +53,7 @@ HexagonHostModuleNode::HexagonHostModuleNode(std::string data, std::string fmt, PackedFunc HexagonHostModuleNode::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { LOG(FATAL) << "HexagonHostModuleNode::GetFunction is not implemented."; - return nullptr; + return PackedFunc(); } std::string HexagonHostModuleNode::GetSource(const std::string& format) { diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index ccf5e09ebcf1..85eab912024f 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -86,7 +86,6 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(); } - return nullptr; } /*! diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index a003367c3724..85dad2839a8a 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -105,7 +105,7 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr #include +#include #include +#include #include #include "../../support/str_escape.h" @@ -362,9 +364,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { Map linked_params; PrimFunc aot_executor_fn; + std::vector> funcs; for (auto kv : mod->functions) { // Make sure that the executor function is the last one to be code generated so that all the - // symbols are available to tvm_run_func + // symbols are available to __tvm_main__ auto fun_name = std::string(kv.first->name_hint); bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); @@ -372,12 +375,26 @@ runtime::Module BuildCHost(IRModule mod, Target target) { aot_executor_fn = Downcast(kv.second); continue; } + funcs.push_back(kv); + } + // Sort functions + std::sort(funcs.begin(), funcs.end(), + [](std::pair kv_a, + std::pair kv_b) { + std::string name_hint_a = kv_a.first->name_hint; + std::string name_hint_b = kv_b.first->name_hint; + return name_hint_a < name_hint_b; + }); + + // Add all functions except __tvm_main__ + for (auto& kv : funcs) { ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); cg.AddFunction(f); } + // Add __tvm_main__ if (aot_executor_fn.defined()) { cg.AddFunction(aot_executor_fn); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 81bc70473722..7347916fcada 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -26,6 +26,7 @@ #include #include +#include #include #include "codegen_c.h" @@ -42,7 +43,13 @@ class CodeGenCHost : public CodeGenC { void InitGlobalContext(); void AddFunction(const PrimFunc& f); - + /*! + * \brief Add functions from the (unordered) range to the current module in a deterministic + * order. This helps with debugging. + * + * \param functions A vector of unordered range of current module. + */ + void AddFunctionsOrdered(std::vector> functions); void DefineModuleName(); void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index f4cef74e8af9..9f10fd2881e7 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -90,7 +90,7 @@ class InterfaceCNode : public runtime::ModuleNode { } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - return PackedFunc(nullptr); + return PackedFunc(); } private: diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 7db5d8c83a84..80b4f1b970f3 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -171,7 +171,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::string GetFormat() { return fmt_; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - return PackedFunc(nullptr); + return PackedFunc(); } void SaveToFile(const std::string& file_name, const std::string& format) final { diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc new file mode 100644 index 000000000000..097de862a926 --- /dev/null +++ b/tests/cpp/c_codegen_test.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(CCodegen, MainFunctionOrder) { + using namespace tvm; + using namespace tvm::te; + + std::string tvm_module_main = std::string(runtime::symbol::tvm_module_main); + + tvm::Target target_c = tvm::Target("c -keys=cpu -link-params=0"); + + const int n = 4; + Array shape{n}; + + auto A = placeholder(shape, DataType::Float(32), "A"); + auto B = placeholder(shape, DataType::Float(32), "B"); + + auto elemwise_add = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add"); + + auto fcreate = [=]() { + With llvm_scope(target_c); + return create_schedule({elemwise_add->op}); + }; + + auto args = Array({A, B, elemwise_add}); + + std::unordered_map binds; + auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds); + Map inputs = {{target_c, lowered}}; + runtime::Module module = build(inputs, Target()); + Array functions = module->GetFunction("get_func_names", false)(); + + ICHECK(functions.back().compare(tvm_module_main) == 0); +} + +auto BuildLowered(std::string op_name, tvm::Target target) { + using namespace tvm; + using namespace tvm::te; + + // The shape of input tensors. + const int n = 4; + Array shape{n}; + + auto A = placeholder(shape, DataType::Float(32), "A"); + auto B = placeholder(shape, DataType::Float(32), "B"); + + auto op = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, op_name); + + auto fcreate_s = [=]() { + With llvm_scope(target); + return create_schedule({op->op}); + }; + + auto args = Array({A, B, op}); + std::unordered_map binds; + auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds); + return lowered_s; +} + +bool IsSorted(tvm::Map inputs) { + std::vector schedule_names; + for (auto const& module : inputs) { + for (auto const& func : module.second->functions) { + schedule_names.push_back(func.first->name_hint); + } + } + return std::is_sorted(schedule_names.begin(), schedule_names.end()); +} + +TEST(CCodegen, FunctionOrder) { + using testing::_; + using testing::ElementsAre; + using testing::StrEq; + using namespace tvm; + using namespace tvm::te; + + Target target = Target("c -keys=cpu -link-params=0"); + + // add schedules in reverse order + Map inputs; + inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered("op_2", target)); + inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered("op_1", target)); + + for (uint32_t counter = 99; IsSorted(inputs) && counter > 0; counter--) { + std::string op_name = "op_" + std::to_string(counter); + inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered(op_name, target)); + } + + EXPECT_FALSE(IsSorted(inputs)); + + auto module = build(inputs, Target()); + Array func_array = module->GetFunction("get_func_names", false)(); + std::vector functions{func_array.begin(), func_array.end()}; + EXPECT_THAT(functions, ElementsAre(StrEq("op_1"), _, StrEq("op_2"), _)); +}