From 3cb96de8aef714c58e3fd6a6d9df795a2b18046a Mon Sep 17 00:00:00 2001 From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Date: Fri, 26 Dec 2025 05:57:30 -0800 Subject: [PATCH 1/2] feat: adapt gemm v2 for cutedsl backend --- .../cache/test_tilelang_kernel_cache.py | 3 - .../python/jit/test_tilelang_jit_cutedsl.py | 11 --- tilelang/jit/execution_backend.py | 21 ------ tilelang/tileop/gemm/__init__.py | 7 ++ tilelang/tileop/gemm/gemm_cutedsl.py | 52 ++++++++++++++ tilelang/utils/target.py | 72 +++++++++++++------ 6 files changed, 108 insertions(+), 58 deletions(-) create mode 100644 tilelang/tileop/gemm/gemm_cutedsl.py diff --git a/testing/python/cache/test_tilelang_kernel_cache.py b/testing/python/cache/test_tilelang_kernel_cache.py index 0324f80ec..9f6683a8d 100644 --- a/testing/python/cache/test_tilelang_kernel_cache.py +++ b/testing/python/cache/test_tilelang_kernel_cache.py @@ -69,7 +69,6 @@ def setup_module_env(): # Save original env values original_cache_dir = env.TILELANG_CACHE_DIR original_tmp_dir = env.TILELANG_TMP_DIR - original_use_gemm_v1 = env.TILELANG_USE_GEMM_V1 # Enable cache once for entire module tilelang.enable_cache() @@ -79,7 +78,6 @@ def setup_module_env(): # Restore env at module end env.TILELANG_CACHE_DIR = original_cache_dir env.TILELANG_TMP_DIR = original_tmp_dir - env.TILELANG_USE_GEMM_V1 = original_use_gemm_v1 # Restore default postproc callbacks tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=lambda code, _: code, override=True) @@ -113,7 +111,6 @@ def clean_cache_env(tmp_path, request): # Patch env variables to point to isolated directories env.TILELANG_CACHE_DIR = str(cache_dir) env.TILELANG_TMP_DIR = str(tmp_dir) - env.TILELANG_USE_GEMM_V1 = "1" if backend == "cutedsl" else "0" # Clear memory caches to force disk I/O _dispatch_map[backend]._memory_cache.clear() diff --git a/testing/python/jit/test_tilelang_jit_cutedsl.py b/testing/python/jit/test_tilelang_jit_cutedsl.py index 564f99eb6..202bbf117 100644 --- a/testing/python/jit/test_tilelang_jit_cutedsl.py +++ b/testing/python/jit/test_tilelang_jit_cutedsl.py @@ -3,18 +3,7 @@ import tilelang.testing import tilelang import torch -import pytest from tilelang.utils.tensor import map_torch_type -from tilelang.env import env - - -@pytest.fixture(scope="module", autouse=True) -def restore_env(): - """Save and restore env settings for this test module""" - original_value = env.TILELANG_USE_GEMM_V1 - env.TILELANG_USE_GEMM_V1 = "1" - yield - env.TILELANG_USE_GEMM_V1 = original_value def matmul( diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py index e2604f24b..fa3c1ecb0 100644 --- a/tilelang/jit/execution_backend.py +++ b/tilelang/jit/execution_backend.py @@ -4,7 +4,6 @@ from tvm.target import Target from tilelang.jit.adapter.utils import is_cutedsl_target -from tilelang.env import env as _env # Canonical names for execution backends used internally _CANONICAL_MAP = { @@ -76,25 +75,9 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: allowed_all = allowed_backends_for_target(target, include_unavailable=True) allowed_avail = allowed_backends_for_target(target, include_unavailable=False) - def _require_gemm_v1_for_cutedsl(): - if not _env.use_gemm_v1(): - raise ValueError( - "CuTeDSL backend requires GEMM v1. Please set environment variable TILELANG_USE_GEMM_V1=1 before importing tilelang." - ) - # Fail fast with a clear error if CuTeDSL dependencies are missing or incompatible. - try: - from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy - - check_cutedsl_available() - except ImportError as e: - # Keep resolve_execution_backend's error semantics (ValueError) while - # preserving the actionable ImportError message. - raise ValueError(str(e)) from e - # Default selection for auto/None if req in (None, "auto"): if is_cutedsl_target(target): - _require_gemm_v1_for_cutedsl() return "cutedsl" kind = _target_kind(target) if kind == "cuda": @@ -122,8 +105,4 @@ def _require_gemm_v1_for_cutedsl(): f"Try one of: {_format_options(allowed_avail)}." ) - # CuTeDSL requires GEMM v1 - if req == "cutedsl": - _require_gemm_v1_for_cutedsl() - return req diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 850b6f3b0..bdb1ac0c6 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -11,8 +11,10 @@ from .gemm_wgmma import GemmWGMMA from .gemm_tcgen05 import GemmTCGEN5 from .gemm_mfma import GemmMFMA +from .gemm_cutedsl import GemmCuTeDSL from tilelang import _ffi_api from tilelang.utils.target import target_is_volta +from tilelang.jit.adapter.utils import is_cutedsl_target @tvm_ffi.register_global_func("tl.gemm_py.infer_layout") @@ -168,6 +170,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): Args: gemm_inst: The selected GEMM instruction type + target: Target architecture Returns: The implementation class for the instruction type @@ -176,6 +179,10 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): NotImplementedError: If the instruction type is not supported ValueError: If the instruction type is unknown """ + # CuTeDSL backend uses direct intrinsic call, bypass complex lowering + if is_cutedsl_target(target): + return GemmCuTeDSL + if gemm_inst.is_mma(): if target_is_volta(target): return GemmMMASm70 diff --git a/tilelang/tileop/gemm/gemm_cutedsl.py b/tilelang/tileop/gemm/gemm_cutedsl.py new file mode 100644 index 000000000..40da798e7 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_cutedsl.py @@ -0,0 +1,52 @@ +"""GEMM implementation for CuTeDSL backend - directly calls tl::gemm intrinsic.""" + +from tilelang.tileop.gemm.gemm_base import GemmBase +from tilelang import language as T +from tvm import tir +from tvm.target import Target + + +class GemmCuTeDSL(GemmBase): + """GEMM implementation for CuTeDSL that directly calls tl::gemm intrinsic. + + This implementation bypasses the complex lowering logic of MMA/WGMMA + and directly emits a call to tl::gemm, similar to gemm_v1 behavior. + This is necessary for CuTeDSL backend which requires simpler IR. + """ + + def infer_layout(self, target: Target, thread_nums: int): + """For CuTeDSL, we still need proper layout inference for A, B, C buffers. + + CuTeDSL uses the same underlying hardware instructions (WGMMA/MMA), + so it needs the same layout information. We delegate to the appropriate + implementation based on the instruction type. + """ + from tilelang.tileop.gemm import GemmInst + from tilelang.tileop.gemm.gemm_wgmma import GemmWGMMA + from tilelang.tileop.gemm.gemm_mma import GemmMMA + from tilelang import _ffi_api + + # Determine which GEMM instruction will be used + gemm_inst = GemmInst(_ffi_api.GemmPyGemmInst(self.gemm_node, int(thread_nums), target)) + + # Use WGMMA or MMA layout inference based on instruction type + if gemm_inst.is_wgmma(): + return GemmWGMMA(self.gemm_node).infer_layout(target, thread_nums) + else: + return GemmMMA(self.gemm_node).infer_layout(target, thread_nums) + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + """Lower to a direct gemm_v1 call without complex MMA/WGMMA lowering.""" + from tilelang.language.gemm_op import gemm_v1 + from tilelang.transform.simplify import _Simplify + from tilelang.tileop.base import GemmWarpPolicy as PyGemmWarpPolicy + + # Convert C++ GemmWarpPolicy to Python enum value (int) + policy_int = self.policy.policy_type + + @T.prim_func + def _gemm_cutedsl() -> None: + gemm_v1(self.A, self.B, self.C, self.trans_A, self.trans_B, PyGemmWarpPolicy(policy_int), self.clear_accum, self.k_pack, self.wg_wait, self.mbar) + + # Simplify and return + return _Simplify(_gemm_cutedsl, inline_let=True) diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index a2b88f5e8..72077c26c 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -60,6 +60,28 @@ def check_metal_availability() -> bool: return arch == "arm64" +def normalize_cutedsl_target(target: str | Target) -> Target | None: + if isinstance(target, Target): + if target.kind.name == "cuda" and "cutedsl" in target.keys: + return target + return None + + if target.startswith("cutedsl"): + cuda_target_str = target.replace("cutedsl", "cuda", 1) + + try: + temp_target = Target(cuda_target_str) + + target_dict = dict(temp_target.export()) + target_dict["keys"] = list(set(target_dict["keys"]) | {"cutedsl"}) + + return Target(target_dict) + except Exception: + return None + + return None + + def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target: """ Determine the appropriate target for compilation (CUDA, HIP, or manual selection). @@ -96,33 +118,37 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj return_var = "metal" else: raise ValueError("No CUDA or HIP or MPS available on this system.") - elif isinstance(target, str) and target.startswith("cutedsl"): - cuda_target_str = target.replace("cutedsl", "cuda", 1) - temp_target = Target(cuda_target_str) - target_dict = dict(temp_target.export()) - target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"] - - return_var = Target(target_dict) else: - # Validate the target if it's not "auto" - if isinstance(target, Target): - return_var = target - elif isinstance(target, str): - normalized_target = target.strip() - if not normalized_target: - raise AssertionError(f"Target {target} is not supported") + possible_cutedsl_target = normalize_cutedsl_target(target) + if possible_cutedsl_target is not None: try: - Target(normalized_target) - except Exception as err: - examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS) - raise AssertionError( - f"Target {target} is not supported. Supported targets include: {examples}. " - "Pass additional options after the base name, e.g. `cuda -arch=sm_80`." - ) from err - return_var = normalized_target + from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy + + check_cutedsl_available() + except ImportError as e: + raise AssertionError(f"CuTeDSL backend is not available. Please install tilelang-cutedsl package. {str(e)}") from e + + return_var = possible_cutedsl_target else: - raise AssertionError(f"Target {target} is not supported") + # Validate the target if it's not "auto" + if isinstance(target, Target): + return_var = target + elif isinstance(target, str): + normalized_target = target.strip() + if not normalized_target: + raise AssertionError(f"Target {target} is not supported") + try: + Target(normalized_target) + except Exception as err: + examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS) + raise AssertionError( + f"Target {target} is not supported. Supported targets include: {examples}. " + "Pass additional options after the base name, e.g. `cuda -arch=sm_80`." + ) from err + return_var = normalized_target + else: + raise AssertionError(f"Target {target} is not supported") if isinstance(return_var, Target): return return_var From 28dcb30c5e33c06a4962c50325507ef43fc7ba2f Mon Sep 17 00:00:00 2001 From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Date: Fri, 26 Dec 2025 05:59:31 -0800 Subject: [PATCH 2/2] fix: ruff --- tilelang/tileop/gemm/gemm_cutedsl.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tilelang/tileop/gemm/gemm_cutedsl.py b/tilelang/tileop/gemm/gemm_cutedsl.py index 40da798e7..1c6d4488c 100644 --- a/tilelang/tileop/gemm/gemm_cutedsl.py +++ b/tilelang/tileop/gemm/gemm_cutedsl.py @@ -8,7 +8,7 @@ class GemmCuTeDSL(GemmBase): """GEMM implementation for CuTeDSL that directly calls tl::gemm intrinsic. - + This implementation bypasses the complex lowering logic of MMA/WGMMA and directly emits a call to tl::gemm, similar to gemm_v1 behavior. This is necessary for CuTeDSL backend which requires simpler IR. @@ -16,7 +16,7 @@ class GemmCuTeDSL(GemmBase): def infer_layout(self, target: Target, thread_nums: int): """For CuTeDSL, we still need proper layout inference for A, B, C buffers. - + CuTeDSL uses the same underlying hardware instructions (WGMMA/MMA), so it needs the same layout information. We delegate to the appropriate implementation based on the instruction type. @@ -25,10 +25,10 @@ def infer_layout(self, target: Target, thread_nums: int): from tilelang.tileop.gemm.gemm_wgmma import GemmWGMMA from tilelang.tileop.gemm.gemm_mma import GemmMMA from tilelang import _ffi_api - + # Determine which GEMM instruction will be used gemm_inst = GemmInst(_ffi_api.GemmPyGemmInst(self.gemm_node, int(thread_nums), target)) - + # Use WGMMA or MMA layout inference based on instruction type if gemm_inst.is_wgmma(): return GemmWGMMA(self.gemm_node).infer_layout(target, thread_nums) @@ -40,13 +40,24 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: from tilelang.language.gemm_op import gemm_v1 from tilelang.transform.simplify import _Simplify from tilelang.tileop.base import GemmWarpPolicy as PyGemmWarpPolicy - + # Convert C++ GemmWarpPolicy to Python enum value (int) policy_int = self.policy.policy_type @T.prim_func def _gemm_cutedsl() -> None: - gemm_v1(self.A, self.B, self.C, self.trans_A, self.trans_B, PyGemmWarpPolicy(policy_int), self.clear_accum, self.k_pack, self.wg_wait, self.mbar) - + gemm_v1( + self.A, + self.B, + self.C, + self.trans_A, + self.trans_B, + PyGemmWarpPolicy(policy_int), + self.clear_accum, + self.k_pack, + self.wg_wait, + self.mbar, + ) + # Simplify and return return _Simplify(_gemm_cutedsl, inline_let=True)