Skip to content

Commit d5dc496

Browse files
author
ibsidorenko
committed
Revert changes of GetPassPrefix interface.
1 parent c1b6399 commit d5dc496

File tree

5 files changed

+6
-8
lines changed

5 files changed

+6
-8
lines changed

src/relay/backend/build_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ class RelayBuildModule : public runtime::ModuleNode {
328328
backend::BindParamsInModule(relay_module, params_);
329329

330330
Array<Pass> pass_seqs =
331-
GetPassPrefix(/*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/false);
331+
GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false);
332332
transform::PassContext pass_ctx = PassContext::Current();
333333

334334
if (config_->optional_homogeneous_target.defined()) {

src/relay/backend/task_extraction.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
3636
backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
3737
backend::BindParamsInModule(mod, params);
3838
// is_vm=true for backward compatibility
39-
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(target, /*is_vm=*/true);
39+
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
4040
pass_seqs.push_back(transform::FuseOps());
4141

4242
mod = transform::Sequential(pass_seqs)(std::move(mod));

src/relay/backend/utils.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,13 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata(
219219

220220
TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode);
221221

222-
Array<Pass> GetPassPrefix(Target homogeneous_target, bool is_vm) {
222+
Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm) {
223223
Array<Pass> pass_seqs;
224224
// TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton
225225
// by most passes there's little utility in including this now. Plus we'd need to only do
226226
// this if there's no existing spans to work from.
227227
// pass_seqs.push_back(parser::AnnotateSpans());
228228
Array<runtime::String> entry_functions{"main"};
229-
// Can be undefined in case of heterogeneous execution
230-
bool is_homogeneous = homogeneous_target.defined();
231229
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
232230
pass_seqs.push_back(transform::ToBasicBlockNormalForm());
233231
// Run all dialect legalization passes.

src/relay/backend/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,11 +676,11 @@ inline IRModule PrimFuncToIRModule(tir::PrimFunc f) {
676676
* difference. This function unifies the shared optimization pass prefix between vm and graph
677677
* runtime, and returns the pass prefix given the backend type.
678678
*
679-
* \param homogeneous_target Execution target (can be undefined in case of heterogeneous execution).
679+
* \param is_homogeneous True if all primitives are to be executed on the same device and target.
680680
* \param is_vm True if passes are to be used for the vm executor.
681681
* \return An array of passes.
682682
*/
683-
Array<Pass> GetPassPrefix(Target homogeneous_target, bool is_vm);
683+
Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm);
684684

685685
/*! \brief Target hash function */
686686
struct TargetStrHash {

src/relay/backend/vm/compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig&
10541054
IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
10551055
backend::BindParamsInModule(mod, params_);
10561056
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
1057-
/*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/true);
1057+
/*is_homogeneous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);
10581058

10591059
// Always plan devices so the remaining passes don't need to distinguish homogeneous vs
10601060
// heterogeneous execution.

0 commit comments

Comments
 (0)