@@ -277,17 +277,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
277277 return pass_list;
278278}
279279
280- IRModule LowerWithPassList (IRModule mod, Array<tvm::transform::Pass> pass_list) {
281- auto optimize = tvm::transform::Sequential (pass_list);
282- mod = optimize (std::move (mod));
283- return mod;
284- }
285-
286- IRModule ApplyPasses (IRModule mod, transform::Sequential seq) {
287- mod = seq (std::move (mod));
288- return mod;
289- }
290-
291280// Convert te schedule to IRModule
292281IRModule ScheduleToModule (te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
293282 const std::unordered_map<te::Tensor, tir::Buffer>& binds,
@@ -340,7 +329,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
340329
341330IRModule LowerModule (IRModule mod, bool simple_mode) {
342331 Array<transform::Pass> pass_list = CreatePassList (simple_mode);
343- return LowerWithPassList (std::move (mod), pass_list);
332+ tvm::transform::Sequential optimize (pass_list, " tvm.lower" );
333+ return optimize (std::move (mod));
344334}
345335
346336TVM_REGISTER_GLOBAL (" driver.lower_module" ).set_body_typed([](IRModule mod, bool simple_mode) {
@@ -357,10 +347,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
357347 f = WithAttr (std::move (f), " tir.noalias" , Bool (true ));
358348 }
359349 IRModule mod = IRModule (Map<GlobalVar, BaseFunc>({{GlobalVar (name), f}}));
360-
361- // Get the pass list
362- Array<transform::Pass> pass_list = CreatePassList (simple_mode);
363- return LowerWithPassList (std::move (mod), pass_list);
350+ return LowerModule (mod, simple_mode);
364351}
365352
366353TVM_REGISTER_GLOBAL (" driver.lower_primfunc" )
@@ -382,9 +369,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
382369 const std::unordered_map<te::Tensor, tir::Buffer>& binds,
383370 GlobalVarSupply global_var_supply, bool simple_mode) {
384371 IRModule mod = ScheduleToModule (std::move (sch), args, name, binds, global_var_supply);
385- // Get the legacy TE pass list
386- Array<transform::Pass> pass_list = CreatePassList (simple_mode);
387- return LowerWithPassList (mod, pass_list);
372+ return LowerModule (mod, simple_mode);
388373}
389374
390375TVM_REGISTER_GLOBAL (" driver.lower_schedule" )
@@ -401,35 +386,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
401386 simple_mode);
402387 });
403388
404- /* *
405- * This function takes the input module that contains both the device and host opts.
406- * Then, it applies transformation on the original module before splitting into separate modules for
407- * device and host. Then it also applies transformations on the new splitted modules.
408- */
409- std::pair<IRModule, IRModule> SplitMixedModule (IRModule mod_mixed, const Target& target_arg,
410- const Target& target_host_arg) {
411- Target target = target_arg, target_host = target_host_arg;
412- CheckAndUpdateHostConsistency (&target, &target_host);
413-
414- ICHECK (mod_mixed.defined ()) << " This module must be defined" ;
389+ IRModule MergeModules (const Map<Target, IRModule>& inputs) {
390+ if (inputs.size () == 1 ) {
391+ auto [target, mod] = *inputs.begin ();
392+ return tir::transform::BindTarget (target)(mod);
393+ }
415394
416- mod_mixed = ApplyPasses (mod_mixed, MixedModulePassManager (mod_mixed, target));
395+ // Take the attrs from the first module so the eventual modules have them.
396+ IRModule first_module = (*inputs.begin ()).second ;
397+ IRModule merged = IRModule (Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs );
417398
418- IRModule host_mod = ApplyPasses (mod_mixed, HostModulePassManager (mod_mixed, target_host));
399+ for (auto [target, mod] : inputs) {
400+ mod = tir::transform::BindTarget (target)(mod);
401+ merged->Update (mod);
402+ }
419403
420- IRModule device_mod = ApplyPasses (mod_mixed, DeviceModulePassManager (mod_mixed, target));
404+ return merged;
405+ }
421406
422- auto keys = target->GetKeys ();
407+ Map<Target, IRModule> SplitModule (const IRModule& module ) {
408+ Map<String, IRModule> split;
423409
424- CheckAndUpdateHostConsistency (&target, &target_host);
410+ for (auto [gvar, base_func] : module ->functions ) {
411+ auto target_str = base_func->GetAttr <Target>(tvm::attr::kTarget ).value ()->str ();
412+ if (auto it = split.find (target_str); it != split.end ()) {
413+ (*it).second ->Add (gvar, base_func);
414+ } else {
415+ split.Set (target_str, IRModule ({{gvar, base_func}}, {}, {}, {}, module ->attrs ));
416+ }
417+ }
425418
426- bool target_is_gpu = std::find (keys.begin (), keys.end (), " gpu" ) != keys.end ();
427- if (target_is_gpu && device_mod->functions .size () == 0 ) {
428- DLOG (WARNING) << " Specified target " << target->str ()
429- << " but cannot find device code. Did you forget to bind?" ;
419+ Map<Target, IRModule> out;
420+ for (auto [str, mod] : split) {
421+ out.Set (Target (str), mod);
430422 }
431423
432- return {host_mod, device_mod} ;
424+ return out ;
433425}
434426
435427/* !
@@ -476,52 +468,86 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
476468 // Update target host for all targets
477469 CheckAndUpdateHostConsistency (&inputs, &target_host);
478470
479- // Take the attrs from the first module so the eventual modules have them.
480- // Ideally this would just be one unified module all the way through;
481- IRModule first_module = (*inputs.begin ()).second ;
482- IRModule mhost_all = IRModule (Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs );
483-
484- ICHECK (mhost_all.defined ()) << " The host module must be defined" ;
485-
486- for (const auto & it : inputs) {
487- if (it.second .defined ()) {
488- const Target& target = it.first ;
489- const IRModule& ir_module = it.second ;
490- auto pair = SplitMixedModule (ir_module, target, target_host);
491- auto & host_mod = pair.first ;
492- auto & device_mod = pair.second ;
493-
494- ICHECK (host_mod.defined ()) << " The split host module must be defined" ;
495-
496- ICHECK (mhost_all.defined ()) << " The host module must be defined" ;
497-
498- // We don't want library modules going back into host codegen
499- // unless they're supposed to. Here if we overrode the target host
500- // to allow lowering previously we check that it's meant to be placed
501- // back into the host Module.
502- bool overrides_host_target =
503- target->GetTargetDeviceType () == target_host->GetTargetDeviceType ();
504- bool non_host_target_kind = target->kind != target_host->kind ;
505- if (overrides_host_target && non_host_target_kind) {
506- device_modules.push_back (codegen::Build (host_mod, it.first ));
507- } else {
508- mhost_all->Update (host_mod);
471+ auto has_gpu_function = [](const IRModule& mod) -> bool {
472+ for (const auto & [gvar, func] : mod->functions ) {
473+ if (auto target = func->GetAttr <Target>(tvm::attr::kTarget )) {
474+ if (target.value ()->HasKey (" gpu" )) {
475+ return true ;
476+ }
477+ }
478+ }
479+ return false ;
480+ };
481+
482+ IRModule merged = MergeModules (inputs);
483+
484+ bool contains_gpu_function_pre = has_gpu_function (merged);
485+ merged = MixedModulePassManager (merged)(merged);
486+ bool contains_gpu_function_post = has_gpu_function (merged);
487+ if (contains_gpu_function_pre && !contains_gpu_function_post) {
488+ DLOG (WARNING) << " Specified GPU targets, "
489+ << " but cannot find device code. Did you forget to bind?" ;
490+ }
491+
492+ Map<Target, IRModule> split = SplitModule (merged);
493+
494+ Map<Target, runtime::Module> built;
495+ for (const auto & [target, mod] : split) {
496+ built.Set (target, codegen::Build (mod, target));
497+ }
498+
499+ auto host_target = [&]() -> Target {
500+ // All targets that contain a kIsEntryFunc=True function
501+ Array<Target> targets_with_entry_func;
502+
503+ // All targets that can run on the CPU and contain at least one
504+ // function without kIsEntryFunc=False.
505+ Array<Target> cpu_targets;
506+ for (const auto & [target, mod] : split) {
507+ bool contains_entry_func = false ;
508+ bool may_contain_entry_func = false ;
509+ for (const auto & [gvar, func] : mod->functions ) {
510+ Optional<Bool> is_entry_func = func->attrs .GetAttr <Bool>(tvm::tir::attr::kIsEntryFunc );
511+ if (is_entry_func.defined () && is_entry_func.value ()->value ) {
512+ contains_entry_func = true ;
513+ } else if (!is_entry_func.defined ()) {
514+ may_contain_entry_func = true ;
515+ }
516+ }
517+
518+ if (contains_entry_func) {
519+ targets_with_entry_func.push_back (target);
509520 }
510521
511- if (device_mod-> functions . size () != 0 ) {
512- device_modules .push_back (codegen::Build (device_mod, it. first ) );
522+ if (may_contain_entry_func && target-> HasKey ( " cpu " ) ) {
523+ cpu_targets .push_back (target );
513524 }
514525 }
515- }
516526
517- runtime::Module mhost = codegen::Build (mhost_all, target_host);
518- for (const auto & it : device_modules) {
519- if (it.operator ->()) {
520- mhost.Import (it);
527+ if (targets_with_entry_func.size ()) {
528+ ICHECK_EQ (targets_with_entry_func.size (), 1 )
529+ << " Expected at most one function "
530+ << " annotated with tvm::tir::attr::kIsEntryFunc "
531+ << " (\" " << tvm::tir::attr::kIsEntryFunc << " \" ), "
532+ << " but found: " << targets_with_entry_func;
533+ return targets_with_entry_func[0 ];
534+ } else if (cpu_targets.size () == 1 ) {
535+ return cpu_targets[0 ];
536+ } else {
537+ LOG (FATAL) << " Could not determine which target is the host. "
538+ << " No function was annotated with tvm::tir::attr::kIsEntryFunc (\" "
539+ << tvm::tir::attr::kIsEntryFunc << " \" ), "
540+ << " and " << cpu_targets.size () << " targets have the 'cpu' key" ;
521541 }
522- }
542+ }();
523543
524- return mhost;
544+ auto runtime_module = built[host_target];
545+ for (const auto & [target, mod] : built) {
546+ if (!mod.same_as (runtime_module)) {
547+ runtime_module.Import (mod);
548+ }
549+ }
550+ return runtime_module;
525551}
526552
527553TVM_REGISTER_GLOBAL (" driver.tir_to_runtime" )
@@ -562,18 +588,20 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
562588 return TIRToRuntime (inputs, target_host);
563589}
564590
565- transform::Sequential MixedModulePassManager (IRModule mixed_mod, Target target) {
591+ transform::Sequential MixedModulePassManager (IRModule mixed_mod, Optional< Target> target) {
566592 transform::PassContext pass_ctx = transform::PassContext::Current ();
567593
568594 Array<Pass> mixed_pass_list;
569595
596+ if (target) {
597+ mixed_pass_list.push_back (tir::transform::BindTarget (target.value ()));
598+ }
599+
570600 // VerifyVTCMLimit must occur before LowerVtcmAlloc
571601 mixed_pass_list.push_back (tir::transform::VerifyVTCMLimit (target));
572602 // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
573603 mixed_pass_list.push_back (tir::transform::LowerVtcmAlloc ());
574604
575- mixed_pass_list.push_back (tir::transform::BindTarget (target));
576-
577605 mixed_pass_list.push_back (tir::transform::VerifyMemory ());
578606
579607 mixed_pass_list.push_back (tir::transform::AnnotateEntryFunc ());
@@ -619,7 +647,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
619647
620648 mixed_pass_list.push_back (tir::transform::LowerDeviceKernelLaunch ());
621649
622- return transform::Sequential (mixed_pass_list);
650+ // Only applies to the device functions, identified by inspection of
651+ // each function's tvm::attr::kTarget attribute.
652+ mixed_pass_list.push_back (tir::transform::LowerWarpMemory ());
653+
654+ // Only applies to the host functions, identified by inspection of
655+ // each function's tvm::attr::kTarget attribute.
656+ mixed_pass_list.push_back (tir::transform::LowerTVMBuiltin ());
657+
658+ // Apply to both host and device functions
659+ mixed_pass_list.push_back (tir::transform::Simplify ());
660+ mixed_pass_list.push_back (tir::transform::LowerCustomDatatypes ());
661+ mixed_pass_list.push_back (tir::transform::LowerIntrin ());
662+ mixed_pass_list.push_back (tir::transform::LowerDeviceStorageAccessInfo ());
663+
664+ // Only applies to the host functions, identified by inspection of
665+ // each function's tvm::attr::kTarget attribute.
666+ mixed_pass_list.push_back (tir::transform::CombineContextCall ());
667+ if (pass_ctx->GetConfig <Bool>(" tir.enable_debug" , Bool (false )).value ()) {
668+ mixed_pass_list.push_back (tir::transform::InstallDebugSpans ());
669+ }
670+
671+ return transform::Sequential (mixed_pass_list, " tvm.build" );
623672}
624673
625674TVM_REGISTER_GLOBAL (" driver.mixed_mod_passes" )
@@ -628,6 +677,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
628677 });
629678
630679transform::Sequential HostModulePassManager (IRModule mixed_mod, Target target_host) {
680+ LOG (WARNING) << " Use of driver.host_mod_passes is deprecated. "
681+ << " All lowering passes are now included "
682+ << " as part of driver.mixed_mod_passes." ;
683+
631684 transform::PassContext pass_ctx = transform::PassContext::Current ();
632685 bool enable_debug = pass_ctx->GetConfig <Bool>(" tir.enable_debug" , Bool (false )).value ();
633686
@@ -653,7 +706,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
653706 host_pass_list.push_back (tir::transform::InstallDebugSpans ());
654707 }
655708
656- return transform::Sequential (host_pass_list);
709+ return transform::Sequential (host_pass_list, " tir.host_mod_passes " );
657710}
658711
659712TVM_REGISTER_GLOBAL (" driver.host_mod_passes" )
@@ -662,6 +715,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")
662715 });
663716
664717transform::Sequential DeviceModulePassManager (IRModule mixed_mod, Target target) {
718+ LOG (WARNING) << " Use of driver.device_mod_passes is deprecated. "
719+ << " All lowering passes are now included "
720+ << " as part of driver.mixed_mod_passes." ;
721+
665722 Array<Pass> device_pass_list;
666723 runtime::TypedPackedFunc<bool (tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
667724 return f->GetAttr <Integer>(tvm::attr::kCallingConv , Integer (CallingConv::kDefault )) ==
@@ -677,7 +734,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
677734 device_pass_list.push_back (tir::transform::LowerDeviceStorageAccessInfo ());
678735 device_pass_list.push_back (tir::transform::LowerIntrin ());
679736
680- return transform::Sequential (device_pass_list);
737+ return transform::Sequential (device_pass_list, " tir.device_mod_passes " );
681738}
682739
683740TVM_REGISTER_GLOBAL (" driver.device_mod_passes" )
0 commit comments