From 3a7b15a6ac7d62abf7ae7a44cd8f433ec29ebd98 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Apr 2024 10:36:24 -0500 Subject: [PATCH] [DLight] Check for target in function attributes Prior to this commit, the `dlight` scheduling rules were applied solely based on the global `tvm.target.Target.current()`. However, a TIR PrimFunc may be annotated with the target, rather than using the global `Target.current()`. In this case, the `dlight` scheduling may produce a scheduled PrimFunc that is not compatible with its target. For example, the schedule may introduce a thread binding to `"threadIdx.x"` on a CPU target. This commit updates `dlight` to check for a TIR PrimFunc's annotations when scheduling, matching the behavior of `tvm.build`. --- python/tvm/dlight/base/transform.py | 11 +++- tests/python/dlight/test_gpu_fallback.py | 78 ++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py index d697e9440b31..0f2895164d5b 100644 --- a/python/tvm/dlight/base/transform.py +++ b/python/tvm/dlight/base/transform.py @@ -36,6 +36,14 @@ def _is_scheduled(func: tir.PrimFunc) -> bool: return func.attrs["tir.is_scheduled"] == 1 +def _get_target(func: tir.PrimFunc) -> Target: + target = func.attrs.get("target") + if target is None: + return Target.current(allow_none=False) + else: + return target + + @module_pass(opt_level=0, name="ApplyDefaultSchedule") class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" @@ -55,10 +63,11 @@ def transform_module( # pylint: disable=missing-function-docstring mod: IRModule, _: PassContext, ) -> IRModule: - target = Target.current(allow_none=False) updated_functions = {} for g_var, func in mod.functions_items(): if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + target = _get_target(func) + sch = _apply_rules(func, target, self.rules, 
tunable=False) if sch is not None: assert len(sch) == 1 diff --git a/tests/python/dlight/test_gpu_fallback.py b/tests/python/dlight/test_gpu_fallback.py index 4457e627bd58..43fac3ad4148 100644 --- a/tests/python/dlight/test_gpu_fallback.py +++ b/tests/python/dlight/test_gpu_fallback.py @@ -179,5 +179,83 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl assert_structural_equal(mod["main"], expected) +def test_gpu_fallback_ignores_non_gpu_functions(): + @I.ir_module + class Before: + # This function has no "target" attribute, and is scheduled + # using the `Target.current`. + @T.prim_func + def gpu_func( + A: T.Buffer((1, 32, 1, 128), "float16"), + C: T.Buffer((1, 1, 4096), "float16"), + ): + B = T.alloc_buffer((1, 1, 32, 128), "float16") + for i, j, k, l in T.grid(1, 1, 32, 128): + with T.block("T_transpose"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vk, vj, vl] + for i, j, k in T.grid(1, 1, 4096): + with T.block("T_reshape"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] + + # This function is identical, except that it is explicitly + # annotated with the "target" attribute, and is scheduled + # based on the annotation's target. 
+ @T.prim_func + def cpu_func( + A: T.Buffer((1, 32, 1, 128), "float16"), + C: T.Buffer((1, 1, 4096), "float16"), + ): + T.func_attr({"target": T.target("llvm")}) + B = T.alloc_buffer((1, 1, 32, 128), "float16") + for i, j, k, l in T.grid(1, 1, 32, 128): + with T.block("T_transpose"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vk, vj, vl] + for i, j, k in T.grid(1, 1, 4096): + with T.block("T_reshape"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] + + @I.ir_module + class After: + @T.prim_func + def gpu_func( + A: T.Buffer((1, 32, 1, 128), "float16"), + C: T.Buffer((1, 1, 4096), "float16"), + ): + T.func_attr({"tir.is_scheduled": 1}) + for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1) + T.reads(A[0, v0 // 128, 0, v0 % 128]) + T.writes(C[0, 0, v0]) + C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128] + + @T.prim_func + def cpu_func( + A: T.Buffer((1, 32, 1, 128), "float16"), + C: T.Buffer((1, 1, 4096), "float16"), + ): + T.func_attr({"target": T.target("llvm")}) + B = T.alloc_buffer((1, 1, 32, 128), "float16") + for i, j, k, l in T.grid(1, 1, 32, 128): + with T.block("T_transpose"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vk, vj, vl] + for i, j, k in T.grid(1, 1, 4096): + with T.block("T_reshape"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] + + with Target("cuda"): + mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.Fallback(), + )(Before) + assert_structural_equal(mod, After) + + if __name__ == "__main__": tvm.testing.main()