Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions testing/python/cache/test_tilelang_kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def setup_module_env():
# Save original env values
original_cache_dir = env.TILELANG_CACHE_DIR
original_tmp_dir = env.TILELANG_TMP_DIR
original_use_gemm_v1 = env.TILELANG_USE_GEMM_V1

# Enable cache once for entire module
tilelang.enable_cache()
Expand All @@ -79,7 +78,6 @@ def setup_module_env():
# Restore env at module end
env.TILELANG_CACHE_DIR = original_cache_dir
env.TILELANG_TMP_DIR = original_tmp_dir
env.TILELANG_USE_GEMM_V1 = original_use_gemm_v1

# Restore default postproc callbacks
tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=lambda code, _: code, override=True)
Expand Down Expand Up @@ -113,7 +111,6 @@ def clean_cache_env(tmp_path, request):
# Patch env variables to point to isolated directories
env.TILELANG_CACHE_DIR = str(cache_dir)
env.TILELANG_TMP_DIR = str(tmp_dir)
env.TILELANG_USE_GEMM_V1 = "1" if backend == "cutedsl" else "0"

# Clear memory caches to force disk I/O
_dispatch_map[backend]._memory_cache.clear()
Expand Down
11 changes: 0 additions & 11 deletions testing/python/jit/test_tilelang_jit_cutedsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,7 @@
import tilelang.testing
import tilelang
import torch
import pytest
from tilelang.utils.tensor import map_torch_type
from tilelang.env import env


@pytest.fixture(scope="module", autouse=True)
def restore_env():
"""Save and restore env settings for this test module"""
original_value = env.TILELANG_USE_GEMM_V1
env.TILELANG_USE_GEMM_V1 = "1"
yield
env.TILELANG_USE_GEMM_V1 = original_value


def matmul(
Expand Down
21 changes: 0 additions & 21 deletions tilelang/jit/execution_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from tvm.target import Target
from tilelang.jit.adapter.utils import is_cutedsl_target
from tilelang.env import env as _env

# Canonical names for execution backends used internally
_CANONICAL_MAP = {
Expand Down Expand Up @@ -76,25 +75,9 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
allowed_all = allowed_backends_for_target(target, include_unavailable=True)
allowed_avail = allowed_backends_for_target(target, include_unavailable=False)

def _require_gemm_v1_for_cutedsl():
if not _env.use_gemm_v1():
raise ValueError(
"CuTeDSL backend requires GEMM v1. Please set environment variable TILELANG_USE_GEMM_V1=1 before importing tilelang."
)
# Fail fast with a clear error if CuTeDSL dependencies are missing or incompatible.
try:
from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy

check_cutedsl_available()
except ImportError as e:
# Keep resolve_execution_backend's error semantics (ValueError) while
# preserving the actionable ImportError message.
raise ValueError(str(e)) from e

# Default selection for auto/None
if req in (None, "auto"):
if is_cutedsl_target(target):
_require_gemm_v1_for_cutedsl()
return "cutedsl"
kind = _target_kind(target)
if kind == "cuda":
Expand Down Expand Up @@ -122,8 +105,4 @@ def _require_gemm_v1_for_cutedsl():
f"Try one of: {_format_options(allowed_avail)}."
)

# CuTeDSL requires GEMM v1
if req == "cutedsl":
_require_gemm_v1_for_cutedsl()

return req
7 changes: 7 additions & 0 deletions tilelang/tileop/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA
from .gemm_cutedsl import GemmCuTeDSL
from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta
from tilelang.jit.adapter.utils import is_cutedsl_target


@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
Expand Down Expand Up @@ -168,6 +170,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):

Args:
gemm_inst: The selected GEMM instruction type
target: Target architecture

Returns:
The implementation class for the instruction type
Expand All @@ -176,6 +179,10 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
NotImplementedError: If the instruction type is not supported
ValueError: If the instruction type is unknown
"""
# CuTeDSL backend uses direct intrinsic call, bypass complex lowering
if is_cutedsl_target(target):
return GemmCuTeDSL

if gemm_inst.is_mma():
if target_is_volta(target):
return GemmMMASm70
Expand Down
63 changes: 63 additions & 0 deletions tilelang/tileop/gemm/gemm_cutedsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""GEMM implementation for CuTeDSL backend - directly calls tl::gemm intrinsic."""

from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang import language as T
from tvm import tir
from tvm.target import Target


class GemmCuTeDSL(GemmBase):
    """GEMM implementation for CuTeDSL that directly calls tl::gemm intrinsic.

    This implementation bypasses the complex lowering logic of MMA/WGMMA
    and directly emits a call to tl::gemm, similar to gemm_v1 behavior.
    This is necessary for CuTeDSL backend which requires simpler IR.
    """

    def infer_layout(self, target: Target, thread_nums: int):
        """For CuTeDSL, we still need proper layout inference for A, B, C buffers.

        CuTeDSL uses the same underlying hardware instructions (WGMMA/MMA),
        so it needs the same layout information. We delegate to the appropriate
        implementation based on the instruction type.

        Args:
            target: Target architecture the kernel is compiled for.
            thread_nums: Number of threads available to the GEMM tile
                (passed through to the FFI instruction selector).

        Returns:
            The layout map produced by the delegated ``GemmWGMMA`` or
            ``GemmMMA`` implementation's ``infer_layout``.
        """
        # Imports are deferred to avoid circular imports between the gemm
        # tileop modules (this module is imported from the package __init__).
        from tilelang.tileop.gemm import GemmInst
        from tilelang.tileop.gemm.gemm_wgmma import GemmWGMMA
        from tilelang.tileop.gemm.gemm_mma import GemmMMA
        from tilelang import _ffi_api

        # Determine which GEMM instruction will be used; the selection is made
        # on the C++ side via the FFI entry point, based on the gemm node,
        # thread count, and target.
        gemm_inst = GemmInst(_ffi_api.GemmPyGemmInst(self.gemm_node, int(thread_nums), target))

        # Use WGMMA or MMA layout inference based on instruction type.
        # NOTE(review): instruction kinds other than WGMMA (e.g. TCGEN5/MFMA)
        # fall through to the MMA path here — confirm that is intended for
        # every target CuTeDSL can run on.
        if gemm_inst.is_wgmma():
            return GemmWGMMA(self.gemm_node).infer_layout(target, thread_nums)
        else:
            return GemmMMA(self.gemm_node).infer_layout(target, thread_nums)

    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower to a direct gemm_v1 call without complex MMA/WGMMA lowering.

        Args:
            layout_map: Buffer layout map (unused here — the direct gemm_v1
                call carries its own layout handling).
            target: Target architecture (unused here; kept for interface
                parity with the other Gemm implementations).
            thread_nums: Thread count (unused here, same reason).
            thread_var: TIR thread variable (unused here, same reason).

        Returns:
            The simplified TIR PrimFunc wrapping a single ``gemm_v1`` call.
        """
        # Deferred imports: keep this backend-specific dependency chain out of
        # module import time.
        from tilelang.language.gemm_op import gemm_v1
        from tilelang.transform.simplify import _Simplify
        from tilelang.tileop.base import GemmWarpPolicy as PyGemmWarpPolicy

        # Convert C++ GemmWarpPolicy to Python enum value (int) — the TVM
        # script closure below re-wraps it as the Python-side enum.
        policy_int = self.policy.policy_type

        # The prim_func body is intentionally a single gemm_v1 call: CuTeDSL
        # consumes this simple IR directly instead of the fully lowered
        # MMA/WGMMA form.
        @T.prim_func
        def _gemm_cutedsl() -> None:
            gemm_v1(
                self.A,
                self.B,
                self.C,
                self.trans_A,
                self.trans_B,
                PyGemmWarpPolicy(policy_int),
                self.clear_accum,
                self.k_pack,
                self.wg_wait,
                self.mbar,
            )

        # Simplify (with let-inlining) and return the resulting PrimFunc.
        return _Simplify(_gemm_cutedsl, inline_let=True)
72 changes: 49 additions & 23 deletions tilelang/utils/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ def check_metal_availability() -> bool:
return arch == "arm64"


def normalize_cutedsl_target(target: str | Target) -> Target | None:
    """Normalize a ``cutedsl`` target spec into a CUDA :class:`Target` tagged
    with the ``"cutedsl"`` key.

    Accepts either an already-constructed :class:`Target` or a target string
    such as ``"cutedsl"`` or ``"cutedsl -arch=sm_90"``.

    Args:
        target: A :class:`Target` object or a target string.

    Returns:
        A CUDA :class:`Target` whose ``keys`` include ``"cutedsl"``, or
        ``None`` if ``target`` is not a CuTeDSL spec (or the string fails to
        parse as a valid CUDA target).
    """
    if isinstance(target, Target):
        # Already a Target: only pass it through if it is a CUDA target that
        # was previously tagged with the "cutedsl" key.
        if target.kind.name == "cuda" and "cutedsl" in target.keys:
            return target
        return None

    if not target.startswith("cutedsl"):
        return None

    # Rewrite the kind name ("cutedsl ..." -> "cuda ...") so TVM can parse it,
    # then tag the exported target dict with the "cutedsl" key.
    cuda_target_str = target.replace("cutedsl", "cuda", 1)
    try:
        temp_target = Target(cuda_target_str)

        target_dict = dict(temp_target.export())
        # Preserve the original key order: TVM consults `keys` in order when
        # resolving target attributes, so building the list through a set
        # would make the priority nondeterministic. Append only if missing.
        keys = list(target_dict["keys"])
        if "cutedsl" not in keys:
            keys.append("cutedsl")
        target_dict["keys"] = keys

        return Target(target_dict)
    except Exception:
        # Best-effort normalization: an unparsable target string is signalled
        # to the caller with None rather than an exception.
        return None


def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
Expand Down Expand Up @@ -96,33 +118,37 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
return_var = "metal"
else:
raise ValueError("No CUDA or HIP or MPS available on this system.")
elif isinstance(target, str) and target.startswith("cutedsl"):
cuda_target_str = target.replace("cutedsl", "cuda", 1)
temp_target = Target(cuda_target_str)

target_dict = dict(temp_target.export())
target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]

return_var = Target(target_dict)
else:
# Validate the target if it's not "auto"
if isinstance(target, Target):
return_var = target
elif isinstance(target, str):
normalized_target = target.strip()
if not normalized_target:
raise AssertionError(f"Target {target} is not supported")
possible_cutedsl_target = normalize_cutedsl_target(target)
if possible_cutedsl_target is not None:
try:
Target(normalized_target)
except Exception as err:
examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS)
raise AssertionError(
f"Target {target} is not supported. Supported targets include: {examples}. "
"Pass additional options after the base name, e.g. `cuda -arch=sm_80`."
) from err
return_var = normalized_target
from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy

check_cutedsl_available()
except ImportError as e:
raise AssertionError(f"CuTeDSL backend is not available. Please install tilelang-cutedsl package. {str(e)}") from e

return_var = possible_cutedsl_target
else:
raise AssertionError(f"Target {target} is not supported")
# Validate the target if it's not "auto"
if isinstance(target, Target):
return_var = target
elif isinstance(target, str):
normalized_target = target.strip()
if not normalized_target:
raise AssertionError(f"Target {target} is not supported")
try:
Target(normalized_target)
except Exception as err:
examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS)
raise AssertionError(
f"Target {target} is not supported. Supported targets include: {examples}. "
"Pass additional options after the base name, e.g. `cuda -arch=sm_80`."
) from err
return_var = normalized_target
else:
raise AssertionError(f"Target {target} is not supported")

if isinstance(return_var, Target):
return return_var
Expand Down
Loading