diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7b4b8d826f36..fdd2db9c4a81 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -71,8 +71,9 @@ jobs: python -m pytest -v tests/python/all-platform-minimal-test - name: Minimal Metal Compile-Only shell: bash -l {0} - run: >- + run: | python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile' + python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params' - name: Minimal Metal Compile-and-Run shell: bash -l {0} run: >- diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ddd7d25f3b5f..86d5956dec19 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -39,8 +39,6 @@ namespace codegen { void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("v_"); // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -57,15 +55,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) { << "};\n\n"; } -void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) { +void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { + // NOTE: There are no inter-function calls among Metal kernels. + // For now we keep the metal codegen without inter-function call + // process. + // We can switch to follow the flow with inter-function call process + // after the Metal function declaration is properly printed. + // In Metal, for PrimFuncs with signature + // def func(A: Buffer, B: Buffer, x: int, y: float) -> None + // where there are trailing pod parameters, the codegen emits a struct + // struct func_params{ x: int; y: float; } + // for the function. 
In the flow of inter-function call process, + // the struct will be emitted every time a function is declared. + // Consequently, the same struct would appear multiple times, + // which the Metal compiler is unable to handle. + + // clear previous generated state. + this->InitFuncState(func); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); + // add to alloc buffer type. auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - os << "kernel void " << static_cast(global_symbol.value()) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -77,13 +93,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; if (!v.dtype().is_handle()) break; - os << " "; + this->stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, os); + PrintStorageScope(it->second, this->stream); } - PrintType(GetType(v), os); + PrintType(GetType(v), this->stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -92,14 +108,15 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri RegisterHandleType(v.get(), prim->dtype); } } - os << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. 
size_t nargs = func->params.size() - num_buffer; std::string varg = name_supply_->FreshName("arg"); if (nargs != 0) { std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; - os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; + this->stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < func->params.size(); ++i) { @@ -141,16 +158,22 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri if (work_dim != 0) { // use ushort by default for now - os << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), os); - os << " blockIdx [[threadgroup_position_in_grid]],\n"; - os << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), os); - os << " threadIdx [[thread_position_in_threadgroup]]\n"; + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " blockIdx [[threadgroup_position_in_grid]],\n"; + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; - os << ")"; + // the function scope. 
+ stream << ") {\n"; + int func_scope = this->BeginScope(); + this->PrintStmt(func->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { @@ -295,6 +318,9 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + CHECK(!op->op.as()) + << "CodegenMetal does not support inter-function calls, " + << "but expression " << GetRef(op) << " calls PrimFunc " << op->op; if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; @@ -337,33 +363,28 @@ runtime::Module BuildMetal(IRModule mod, Target target) { const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; - Map functions; - for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto calling_conv = base_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - - auto prim_func = Downcast(base_func); - functions.Set(gvar, prim_func); - } + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + std::string func_name = global_symbol.value(); - for (auto [gvar, prim_func] : functions) { - source_maker << "// Function: " << gvar->name_hint << "\n"; + source_maker << "// Function: " << func_name << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch"; - for (auto [other_gvar, other_prim_func] : functions) { - cg.DeclareFunction(other_gvar, other_prim_func); - } - cg.AddFunction(gvar, prim_func); + cg.AddFunction(kv.first, f); std::string fsource = cg.Finish(); source_maker << fsource << "\n"; if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).operator std::string(); } - smap[cg.GetFunctionName(gvar)] = fsource; + smap[func_name] = fsource; } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 26c991e60df9..9cff3211ce44 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC { explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, - std::ostream& os) override; + void AddFunction(const GlobalVar& gvar, const PrimFunc& func) final; void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index dcbbba8c9c9f..b4e747a0b4d8 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -169,5 +169,28 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) +@tvm.testing.requires_metal(support_required="compile-only") +def test_func_with_trailing_pod_params(): + from tvm.contrib import xcode # pylint: disable=import-outside-toplevel + + @T.prim_func + def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32): + for i 
in T.thread_binding(16, thread="threadIdx.x"): + with T.block("block"): + vi = T.axis.spatial(16, i) + B[vi] = A[vi] + x + + @tvm.register_func("tvm_callback_metal_compile") + def compile_metal(src, target): + return xcode.compile_metal(src) + + mod = tvm.IRModule({"main": func}) + + f = tvm.build(mod, target="metal") + src: str = f.imported_modules[0].get_source() + occurrences = src.count("struct func_kernel_args_t") + assert occurrences == 1, occurrences + + if __name__ == "__main__": tvm.testing.main()