diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 38242ff4d2d3..2f67318392f4 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -60,6 +60,7 @@ def f_zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm seq = tvm.transform.Sequential( [ transform.LegalizeOps(enable_warning=enable_warning), + tvm.tir.transform.DefaultGPUSchedule(), transform.AnnotateTIROpPattern(), transform.FoldConstant(), transform.FuseOps(), @@ -84,6 +85,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I backend.DispatchSampling(), backend.DispatchSortScan(), transform.LegalizeOps(), + tvm.tir.transform.DefaultGPUSchedule(), transform.RewriteDataflowReshape(), transform.ToNonDataflow(), transform.RemovePurityChecking(), diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index cfeb70b96388..fb265969c035 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" Instrument test cases. -""" + +""" Instrument test cases. """ import tvm + from tvm import relax from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll from tvm.script import ir as I @@ -56,7 +57,8 @@ def func(x: R.Tensor((16,), "float32"), y: R.Tensor((16,), "float32")): pipeline = relax.get_pipeline("default_build") with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(), PrintAfterAll()]): - pipeline(Module) + with tvm.target.Target("llvm"): + pipeline(Module) all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index ecf33aa9da1e..bf37ec2183a3 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -1309,5 +1309,25 @@ def func_llvm( tvm.testing.assert_allclose(cuda_output.numpy(), np_C) +@tvm.testing.requires_cuda(support_required="compile-only") +def test_relax_default_gpu_scheduling(): + """Legalization may produce operations that require scheduling + + The default build pipeline legalizes any Relax operators that are + not already legalized. This may produce TIR PrimFuncs that must + be scheduled for use on the GPU. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")): + C = R.add(A, B) + return C + + tvm.relax.build(Module, target="cuda") + + if __name__ == "__main__": tvm.testing.main()