Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def f_zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm
seq = tvm.transform.Sequential(
[
transform.LegalizeOps(enable_warning=enable_warning),
tvm.tir.transform.DefaultGPUSchedule(),
transform.AnnotateTIROpPattern(),
transform.FoldConstant(),
transform.FuseOps(),
Expand All @@ -84,6 +85,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
backend.DispatchSampling(),
backend.DispatchSortScan(),
transform.LegalizeOps(),
tvm.tir.transform.DefaultGPUSchedule(),
transform.RewriteDataflowReshape(),
transform.ToNonDataflow(),
transform.RemovePurityChecking(),
Expand Down
8 changes: 5 additions & 3 deletions tests/python/ir/test_pass_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Instrument test cases.
"""

""" Instrument test cases. """

import tvm

from tvm import relax
from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll
from tvm.script import ir as I
Expand Down Expand Up @@ -56,7 +57,8 @@ def func(x: R.Tensor((16,), "float32"), y: R.Tensor((16,), "float32")):

pipeline = relax.get_pipeline("default_build")
with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(), PrintAfterAll()]):
pipeline(Module)
with tvm.target.Target("llvm"):
pipeline(Module)
all_passes_output = capsys.readouterr().out
assert "Before Running Pass:" in all_passes_output
assert "After Running Pass:" in all_passes_output
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relax/test_vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,5 +1309,25 @@ def func_llvm(
tvm.testing.assert_allclose(cuda_output.numpy(), np_C)


@tvm.testing.requires_cuda(support_required="compile-only")
def test_relax_default_gpu_scheduling():
"""Legalization may produce operations that require scheduling

The default build pipeline legalizes any Relax operators that are
not already legalized. This may produce TIR PrimFuncs that must
be scheduled for use on the GPU.

"""

@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
C = R.add(A, B)
return C

tvm.relax.build(Module, target="cuda")


if __name__ == "__main__":
tvm.testing.main()
Loading