Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
shell: bash -l {0}
run: >-
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params'
- name: Minimal Metal Compile-and-Run
shell: bash -l {0}
run: >-
Expand Down
89 changes: 55 additions & 34 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ namespace codegen {

void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
Expand All @@ -57,15 +55,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
<< "};\n\n";
}

void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) {
void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
// NOTE: There is no inter-function calls among Metal kernels.
// For now we keep the metal codegen without inter-function call
// process.
// We can switch to follow the flow with inter-function call process
// after the Metal function declaration is properly printed.
// In Metal, for PrimFuncs with signature
// def func(A: Buffer, B: Buffer, x: int, y: float) -> None
// where there are trailing pod parameters, the codegen emits a struct
// struct func_params{ x: int; y: float; }
// for the function. In the flow of inter-function call process,
// the struct will be emitted for every time a function is declared.
// So consequently there are duplicate appearances of a same struct,
// which makes the Metal compiler unable to recognize.

// clear previous generated state.
this->InitFuncState(func);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");

// add to alloc buffer type.
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";

// Function header.
os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";

// Buffer arguments
size_t num_buffer = 0;
Expand All @@ -77,13 +93,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
Var v = func->params[i];
if (!v.dtype().is_handle()) break;
os << " ";
this->stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
PrintStorageScope(it->second, this->stream);
}
PrintType(GetType(v), os);
PrintType(GetType(v), this->stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
Expand All @@ -92,14 +108,15 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
RegisterHandleType(v.get(), prim->dtype);
}
}
os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = func->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
this->stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
<< ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < func->params.size(); ++i) {
Expand Down Expand Up @@ -141,16 +158,22 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri

if (work_dim != 0) {
// use ushort by default for now
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " blockIdx [[threadgroup_position_in_grid]],\n";
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " threadIdx [[thread_position_in_threadgroup]]\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;

os << ")";
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
this->PrintStmt(func->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}

void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
Expand Down Expand Up @@ -295,6 +318,9 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
}

void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
CHECK(!op->op.as<GlobalVarNode>())
<< "CodegenMetal does not support inter-function calls, "
<< "but expression " << GetRef<Call>(op) << " calls PrimFunc " << op->op;
if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
Expand Down Expand Up @@ -337,33 +363,28 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
std::string func_name = global_symbol.value();

for (auto [gvar, prim_func] : functions) {
source_maker << "// Function: " << gvar->name_hint << "\n";
source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

for (auto [other_gvar, other_prim_func] : functions) {
cg.DeclareFunction(other_gvar, other_prim_func);
}
cg.AddFunction(gvar, prim_func);
cg.AddFunction(kv.first, f);

std::string fsource = cg.Finish();
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).operator std::string();
}
smap[cg.GetFunctionName(gvar)] = fsource;
smap[func_name] = fsource;
}

return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
Expand Down
3 changes: 1 addition & 2 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC {
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) override;
void AddFunction(const GlobalVar& gvar, const PrimFunc& func) final;
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,28 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)


@tvm.testing.requires_metal(support_required="compile-only")
def test_func_with_trailing_pod_params():
from tvm.contrib import xcode # pylint: disable=import-outside-toplevel

@T.prim_func
def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32):
for i in T.thread_binding(16, thread="threadIdx.x"):
with T.block("block"):
vi = T.axis.spatial(16, i)
B[vi] = A[vi] + x

@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src, target):
return xcode.compile_metal(src)

mod = tvm.IRModule({"main": func})

f = tvm.build(mod, target="metal")
src: str = f.imported_modules[0].get_source()
occurrences = src.count("struct func_kernel_args_t")
assert occurrences == 1, occurrences


if __name__ == "__main__":
tvm.testing.main()