Skip to content

Commit 6440310

Browse files
committed
[Relax] Apply DefaultGPUSchedule() in default build pipeline
This is a follow-up to #15864, which added `LegalizeOps` to the default Relax build pipeline. Since legalization may produce additional TIR PrimFuncs that require scheduling, the output of `LegalizeOps` typically must also be passed through `tir.transform.DefaultGPUSchedule()`. This PR adds `DefaultGPUSchedule()` to the relax build pipeline to handle these cases. Scheduled PrimFunc have the `"tir.is_scheduled"` attribute set to true, and are ignored by `DefaultGPUSchedule()`. In addition, the `DefaultGPUSchedule` transform has no effect on non-GPU targets. Therefore, this change should only impact `tvm.relax.build` calls that previously resulted in an error due to unscheduled GPU functions, and should not have any impact on existing calls.
1 parent 72b75fe commit 6440310

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

python/tvm/relax/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def f_zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm
6060
seq = tvm.transform.Sequential(
6161
[
6262
transform.LegalizeOps(enable_warning=enable_warning),
63+
tvm.tir.transform.DefaultGPUSchedule(),
6364
transform.AnnotateTIROpPattern(),
6465
transform.FoldConstant(),
6566
transform.FuseOps(),
@@ -84,6 +85,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
8485
backend.DispatchSampling(),
8586
backend.DispatchSortScan(),
8687
transform.LegalizeOps(),
88+
tvm.tir.transform.DefaultGPUSchedule(),
8789
transform.RewriteDataflowReshape(),
8890
transform.ToNonDataflow(),
8991
transform.RemovePurityChecking(),

tests/python/ir/test_pass_instrument.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
""" Instrument test cases.
18-
"""
17+
18+
""" Instrument test cases. """
1919

2020
import tvm
21+
2122
from tvm import relax
2223
from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll
2324
from tvm.script import ir as I
@@ -56,7 +57,8 @@ def func(x: R.Tensor((16,), "float32"), y: R.Tensor((16,), "float32")):
5657

5758
pipeline = relax.get_pipeline("default_build")
5859
with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(), PrintAfterAll()]):
59-
pipeline(Module)
60+
with tvm.target.Target("llvm"):
61+
pipeline(Module)
6062
all_passes_output = capsys.readouterr().out
6163
assert "Before Running Pass:" in all_passes_output
6264
assert "After Running Pass:" in all_passes_output

tests/python/relax/test_vm_build.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,5 +1309,25 @@ def func_llvm(
13091309
tvm.testing.assert_allclose(cuda_output.numpy(), np_C)
13101310

13111311

1312+
@tvm.testing.requires_cuda(support_required="compile-only")
1313+
def test_relax_default_gpu_scheduling():
1314+
"""Legalization may produce operations that require scheduling
1315+
1316+
The default build pipeline legalizes any Relax operators that are
1317+
not already legalized. This may produce TIR PrimFuncs that must
1318+
be scheduled for use on the GPU.
1319+
1320+
"""
1321+
1322+
@I.ir_module
1323+
class Module:
1324+
@R.function
1325+
def main(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
1326+
C = R.add(A, B)
1327+
return C
1328+
1329+
tvm.relax.build(Module, target="cuda")
1330+
1331+
13121332
if __name__ == "__main__":
13131333
tvm.testing.main()

0 commit comments

Comments
 (0)