Skip to content

Commit 38e3ac9

Browse files
committed
[Driver] Single-module lowering flow in driver_api.cc
Prior to this commit, a build that used multiple targets needed to provide `tvm::build` with a `Map<Target, IRModule>` specifying which target should be used to compile each `IRModule`. As a result, lowering passes could not introduce new targets based on a `PrimFunc`'s content (e.g. a `with T.target()` frame to delegate out to another device), nor could they simplify based on cross-device subroutines (e.g. simplify a host-side conditional based on the known output of a device-side internal subroutine). This commit makes the `tvm::attr::kTarget` attribute (`"target"`) the single source of truth for where a `PrimFunc` will be executed. Other existing methods for specifying the target (the `target` parameter for `tvm.build`, the keys in a `Map<Target, IRModule>`, the parameter to the pass `tir::transform::BindTarget`) are still accepted as inputs. They may provide a default value for `tvm::attr::kTarget` if the attribute is missing, but they may not overwrite a target attribute that is already present. This is part of a series of commits to simplify the handling of multi-target builds.
1 parent daa37e7 commit 38e3ac9

File tree

21 files changed

+756
-118
lines changed

21 files changed

+756
-118
lines changed

apps/extension/tests/test_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_ext_dev():
3939
def check_llvm():
4040
if not tvm.testing.device_enabled("llvm"):
4141
return
42-
f = tvm.build(s, [A, B], "ext_dev", "llvm")
42+
f = tvm.build(s, [A, B], "ext_dev", "ext_dev")
4343
dev = tvm.ext_dev(0)
4444
# launch the kernel.
4545
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)

include/tvm/driver/driver_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ using tvm::transform::Pass;
5454
* \param target The device Target.
5555
* \return The composite Pass for the fused module.
5656
// */
57-
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
57+
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
58+
Optional<Target> target = NullOpt);
5859

5960
/*!
6061
* \brief Configures and returns the composite Pass for the device Target after device/host from

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
209209
primfunc = tir_mod["main"]
210210
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
211211
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
212-
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
212+
primfunc = primfunc.with_attr(
213+
"target", tvm.target.Target(compiler_name, host=compiler_name)
214+
)
213215
return primfunc
214216

215217
def __call__(self, *args, **kwargs):

src/driver/driver_api.cc

Lines changed: 142 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
277277
return pass_list;
278278
}
279279

280-
IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
281-
auto optimize = tvm::transform::Sequential(pass_list);
282-
mod = optimize(std::move(mod));
283-
return mod;
284-
}
285-
286-
IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
287-
mod = seq(std::move(mod));
288-
return mod;
289-
}
290-
291280
// Convert te schedule to IRModule
292281
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
293282
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
@@ -340,7 +329,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
340329

341330
IRModule LowerModule(IRModule mod, bool simple_mode) {
342331
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
343-
return LowerWithPassList(std::move(mod), pass_list);
332+
tvm::transform::Sequential optimize(pass_list, "tvm.lower");
333+
return optimize(std::move(mod));
344334
}
345335

346336
TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
@@ -357,10 +347,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
357347
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
358348
}
359349
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
360-
361-
// Get the pass list
362-
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
363-
return LowerWithPassList(std::move(mod), pass_list);
350+
return LowerModule(mod, simple_mode);
364351
}
365352

366353
TVM_REGISTER_GLOBAL("driver.lower_primfunc")
@@ -382,9 +369,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
382369
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
383370
GlobalVarSupply global_var_supply, bool simple_mode) {
384371
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
385-
// Get the legacy TE pass list
386-
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
387-
return LowerWithPassList(mod, pass_list);
372+
return LowerModule(mod, simple_mode);
388373
}
389374

390375
TVM_REGISTER_GLOBAL("driver.lower_schedule")
@@ -401,35 +386,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
401386
simple_mode);
402387
});
403388

404-
/**
405-
* This function takes the input module that contains both the device and host opts.
406-
* Then, it applies transformation on the original module before splitting into separate modules for
407-
* device and host. Then it also applies transformations on the new splitted modules.
408-
*/
409-
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
410-
const Target& target_host_arg) {
411-
Target target = target_arg, target_host = target_host_arg;
412-
CheckAndUpdateHostConsistency(&target, &target_host);
413-
414-
ICHECK(mod_mixed.defined()) << "This module must be defined";
389+
IRModule MergeModules(const Map<Target, IRModule>& inputs) {
390+
if (inputs.size() == 1) {
391+
auto [target, mod] = *inputs.begin();
392+
return tir::transform::BindTarget(target)(mod);
393+
}
415394

416-
mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
395+
// Take the attrs from the first module so the eventual modules have them.
396+
IRModule first_module = (*inputs.begin()).second;
397+
IRModule merged = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);
417398

418-
IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
399+
for (auto [target, mod] : inputs) {
400+
mod = tir::transform::BindTarget(target)(mod);
401+
merged->Update(mod);
402+
}
419403

420-
IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
404+
return merged;
405+
}
421406

422-
auto keys = target->GetKeys();
407+
Map<Target, IRModule> SplitModule(const IRModule& module) {
408+
Map<String, IRModule> split;
423409

424-
CheckAndUpdateHostConsistency(&target, &target_host);
410+
for (auto [gvar, base_func] : module->functions) {
411+
auto target_str = base_func->GetAttr<Target>(tvm::attr::kTarget).value()->str();
412+
if (auto it = split.find(target_str); it != split.end()) {
413+
(*it).second->Add(gvar, base_func);
414+
} else {
415+
split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs));
416+
}
417+
}
425418

426-
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
427-
if (target_is_gpu && device_mod->functions.size() == 0) {
428-
DLOG(WARNING) << "Specified target " << target->str()
429-
<< " but cannot find device code. Did you forget to bind?";
419+
Map<Target, IRModule> out;
420+
for (auto [str, mod] : split) {
421+
out.Set(Target(str), mod);
430422
}
431423

432-
return {host_mod, device_mod};
424+
return out;
433425
}
434426

435427
/*!
@@ -476,52 +468,86 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
476468
// Update target host for all targets
477469
CheckAndUpdateHostConsistency(&inputs, &target_host);
478470

479-
// Take the attrs from the first module so the eventual modules have them.
480-
// Ideally this would just be one unified module all the way through;
481-
IRModule first_module = (*inputs.begin()).second;
482-
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);
483-
484-
ICHECK(mhost_all.defined()) << "The host module must be defined";
485-
486-
for (const auto& it : inputs) {
487-
if (it.second.defined()) {
488-
const Target& target = it.first;
489-
const IRModule& ir_module = it.second;
490-
auto pair = SplitMixedModule(ir_module, target, target_host);
491-
auto& host_mod = pair.first;
492-
auto& device_mod = pair.second;
493-
494-
ICHECK(host_mod.defined()) << "The split host module must be defined";
495-
496-
ICHECK(mhost_all.defined()) << "The host module must be defined";
497-
498-
// We don't want library modules going back into host codegen
499-
// unless they're supposed to. Here if we overrode the target host
500-
// to allow lowering previously we check that it's meant to be placed
501-
// back into the host Module.
502-
bool overrides_host_target =
503-
target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
504-
bool non_host_target_kind = target->kind != target_host->kind;
505-
if (overrides_host_target && non_host_target_kind) {
506-
device_modules.push_back(codegen::Build(host_mod, it.first));
507-
} else {
508-
mhost_all->Update(host_mod);
471+
auto has_gpu_function = [](const IRModule& mod) -> bool {
472+
for (const auto& [gvar, func] : mod->functions) {
473+
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
474+
if (target.value()->HasKey("gpu")) {
475+
return true;
476+
}
477+
}
478+
}
479+
return false;
480+
};
481+
482+
IRModule merged = MergeModules(inputs);
483+
484+
bool contains_gpu_function_pre = has_gpu_function(merged);
485+
merged = MixedModulePassManager(merged)(merged);
486+
bool contains_gpu_function_post = has_gpu_function(merged);
487+
if (contains_gpu_function_pre && !contains_gpu_function_post) {
488+
DLOG(WARNING) << "Specified GPU targets, "
489+
<< "but cannot find device code. Did you forget to bind?";
490+
}
491+
492+
Map<Target, IRModule> split = SplitModule(merged);
493+
494+
Map<Target, runtime::Module> built;
495+
for (const auto& [target, mod] : split) {
496+
built.Set(target, codegen::Build(mod, target));
497+
}
498+
499+
auto host_target = [&]() -> Target {
500+
// All targets that contain a kIsEntryFunc=True function
501+
Array<Target> targets_with_entry_func;
502+
503+
// All targets that can run on the CPU and contain at least one
504+
// function without kIsEntryFunc=False.
505+
Array<Target> cpu_targets;
506+
for (const auto& [target, mod] : split) {
507+
bool contains_entry_func = false;
508+
bool may_contain_entry_func = false;
509+
for (const auto& [gvar, func] : mod->functions) {
510+
Optional<Bool> is_entry_func = func->attrs.GetAttr<Bool>(tvm::tir::attr::kIsEntryFunc);
511+
if (is_entry_func.defined() && is_entry_func.value()->value) {
512+
contains_entry_func = true;
513+
} else if (!is_entry_func.defined()) {
514+
may_contain_entry_func = true;
515+
}
516+
}
517+
518+
if (contains_entry_func) {
519+
targets_with_entry_func.push_back(target);
509520
}
510521

511-
if (device_mod->functions.size() != 0) {
512-
device_modules.push_back(codegen::Build(device_mod, it.first));
522+
if (may_contain_entry_func && target->HasKey("cpu")) {
523+
cpu_targets.push_back(target);
513524
}
514525
}
515-
}
516526

517-
runtime::Module mhost = codegen::Build(mhost_all, target_host);
518-
for (const auto& it : device_modules) {
519-
if (it.operator->()) {
520-
mhost.Import(it);
527+
if (targets_with_entry_func.size()) {
528+
ICHECK_EQ(targets_with_entry_func.size(), 1)
529+
<< "Expected at most one function "
530+
<< "annotated with tvm::tir::attr::kIsEntryFunc "
531+
<< "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), "
532+
<< "but found: " << targets_with_entry_func;
533+
return targets_with_entry_func[0];
534+
} else if (cpu_targets.size() == 1) {
535+
return cpu_targets[0];
536+
} else {
537+
LOG(FATAL) << "Could not determine which target is the host. "
538+
<< "No function was annotated with tvm::tir::attr::kIsEntryFunc (\""
539+
<< tvm::tir::attr::kIsEntryFunc << "\"), "
540+
<< "and " << cpu_targets.size() << " targets have the 'cpu' key";
521541
}
522-
}
542+
}();
523543

524-
return mhost;
544+
auto runtime_module = built[host_target];
545+
for (const auto& [target, mod] : built) {
546+
if (!mod.same_as(runtime_module)) {
547+
runtime_module.Import(mod);
548+
}
549+
}
550+
return runtime_module;
525551
}
526552

527553
TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
@@ -562,18 +588,20 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
562588
return TIRToRuntime(inputs, target_host);
563589
}
564590

565-
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
591+
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional<Target> target) {
566592
transform::PassContext pass_ctx = transform::PassContext::Current();
567593

568594
Array<Pass> mixed_pass_list;
569595

596+
if (target) {
597+
mixed_pass_list.push_back(tir::transform::BindTarget(target.value()));
598+
}
599+
570600
// VerifyVTCMLimit must occur before LowerVtcmAlloc
571601
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
572602
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
573603
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
574604

575-
mixed_pass_list.push_back(tir::transform::BindTarget(target));
576-
577605
mixed_pass_list.push_back(tir::transform::VerifyMemory());
578606

579607
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
@@ -619,7 +647,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
619647

620648
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
621649

622-
return transform::Sequential(mixed_pass_list);
650+
// Only applies to the device functions, identified by inspection of
651+
// each function's tvm::attr::kTarget attribute.
652+
mixed_pass_list.push_back(tir::transform::LowerWarpMemory());
653+
654+
// Only applies to the host functions, identified by inspection of
655+
// each function's tvm::attr::kTarget attribute.
656+
mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin());
657+
658+
// Apply to both host and device functions
659+
mixed_pass_list.push_back(tir::transform::Simplify());
660+
mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes());
661+
mixed_pass_list.push_back(tir::transform::LowerIntrin());
662+
mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
663+
664+
// Only applies to the host functions, identified by inspection of
665+
// each function's tvm::attr::kTarget attribute.
666+
mixed_pass_list.push_back(tir::transform::CombineContextCall());
667+
if (pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value()) {
668+
mixed_pass_list.push_back(tir::transform::InstallDebugSpans());
669+
}
670+
671+
return transform::Sequential(mixed_pass_list, "tvm.build");
623672
}
624673

625674
TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
@@ -628,6 +677,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
628677
});
629678

630679
transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
680+
LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. "
681+
<< "All lowering passes are now included "
682+
<< "as part of driver.mixed_mod_passes.";
683+
631684
transform::PassContext pass_ctx = transform::PassContext::Current();
632685
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();
633686

@@ -653,7 +706,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
653706
host_pass_list.push_back(tir::transform::InstallDebugSpans());
654707
}
655708

656-
return transform::Sequential(host_pass_list);
709+
return transform::Sequential(host_pass_list, "tir.host_mod_passes");
657710
}
658711

659712
TVM_REGISTER_GLOBAL("driver.host_mod_passes")
@@ -662,6 +715,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")
662715
});
663716

664717
transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
718+
LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. "
719+
<< "All lowering passes are now included "
720+
<< "as part of driver.mixed_mod_passes.";
721+
665722
Array<Pass> device_pass_list;
666723
runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
667724
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
@@ -677,7 +734,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
677734
device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
678735
device_pass_list.push_back(tir::transform::LowerIntrin());
679736

680-
return transform::Sequential(device_pass_list);
737+
return transform::Sequential(device_pass_list, "tir.device_mod_passes");
681738
}
682739

683740
TVM_REGISTER_GLOBAL("driver.device_mod_passes")

src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
6464
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
6565
: ir_module_(ir_module),
6666
host_target_(host_target),
67-
custom_target_(Target("example_target_hook")) {}
67+
custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {}
6868

6969
IRModule Mutate() {
7070
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");

src/target/llvm/llvm_module.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
459459
}
460460
}
461461

462-
TVM_REGISTER_GLOBAL("target.build.llvm")
463-
.set_body_typed([](IRModule mod, Target target) -> runtime::Module {
464-
auto n = make_object<LLVMModuleNode>();
465-
n->Init(mod, target);
466-
return runtime::Module(n);
467-
});
462+
namespace {
463+
runtime::Module BuildLLVM(IRModule mod, Target target) {
464+
auto n = make_object<LLVMModuleNode>();
465+
n->Init(mod, target);
466+
return runtime::Module(n);
467+
}
468+
} // namespace
469+
470+
TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM);
471+
TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM);
468472

469473
TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
470474
.set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {

0 commit comments

Comments
 (0)