Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,965 changes: 1,965 additions & 0 deletions examples/jitv2/jitv2.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@
from .builtin import * # noqa: F401

from .utils import index_to_coordinates # noqa: F401
from .dtypes import (
AnyDType, # noqa: F401
get_tvm_dtype, # noqa: F401
get_torch_dtype, # noqa: F401
)


def symbolic(name: str, dtype: str = "int32"):
def symbolic(name: str, dtype: AnyDType = "int32"):
"""
Create a TIR symbolic variable.

Expand All @@ -93,6 +98,7 @@ def symbolic(name: str, dtype: str = "int32"):
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
dtype = get_tvm_dtype(dtype)
return tir.Var(name, dtype)


Expand Down
24 changes: 16 additions & 8 deletions tilelang/language/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

from tilelang import tvm as tvm
from tvm.script import tir as T
from tilelang.language.dtypes import get_tvm_dtype, AnyDType


def alloc_shared(shape, dtype, scope="shared.dyn"):
def alloc_shared(shape, dtype: AnyDType, scope="shared.dyn"):
"""Allocate a shared memory buffer for inter-thread communication.

Args:
Expand All @@ -29,14 +30,15 @@ def alloc_shared(shape, dtype, scope="shared.dyn"):
Returns:
T.Buffer: A TVM buffer object allocated in shared memory
"""
if dtype == "bool":
dtype = get_tvm_dtype(dtype)
if dtype == tvm.DataType("bool"):
# lei: This is a hack to handle bool type.
# Because tilelang's merge smem pass cannot merge bool type currently.
scope = "shared"
return T.alloc_buffer(shape, dtype, scope=scope)


def alloc_local(shape, dtype, scope="local"):
def alloc_local(shape, dtype: AnyDType, scope="local"):
"""Allocate a local memory buffer for thread-private storage.

Args:
Expand All @@ -47,10 +49,11 @@ def alloc_local(shape, dtype, scope="local"):
Returns:
T.Buffer: A TVM buffer object allocated in local memory
"""
dtype = get_tvm_dtype(dtype)
return T.alloc_buffer(shape, dtype, scope=scope)


def alloc_fragment(shape, dtype, scope="local.fragment"):
def alloc_fragment(shape, dtype: AnyDType, scope="local.fragment"):
"""Allocate a fragment memory buffer for specialized operations.

Args:
Expand All @@ -61,10 +64,11 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
Returns:
T.Buffer: A TVM buffer object allocated in fragment memory
"""
dtype = get_tvm_dtype(dtype)
return T.alloc_buffer(shape, dtype, scope=scope)


def alloc_var(dtype, scope="local.var"):
def alloc_var(dtype: AnyDType, scope="local.var"):
"""Allocate a single-element variable buffer.

Args:
Expand All @@ -74,6 +78,7 @@ def alloc_var(dtype, scope="local.var"):
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
dtype = get_tvm_dtype(dtype)
return T.alloc_buffer([1], dtype, scope=scope)


Expand All @@ -89,7 +94,7 @@ def alloc_barrier(arrive_count: int):
return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier")


def alloc_tmem(shape, dtype):
def alloc_tmem(shape, dtype: AnyDType):
"""
Allocate a Tensor Memory (TMEM) buffer for use with 5th generation Tensor Core operations (e.g., TCGEN5.MMA).

Expand All @@ -114,11 +119,12 @@ def alloc_tmem(shape, dtype):
- The buffer returned should be used according to TMEM access restrictions and deallocated appropriately.
"""

dtype = get_tvm_dtype(dtype)
assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation"
return T.alloc_buffer(shape, dtype, scope="shared.tmem")


def alloc_reducer(shape, dtype, op="sum", replication=None):
def alloc_reducer(shape, dtype: AnyDType, op="sum", replication=None):
"""
Allocate a reducer buffer.

Expand Down Expand Up @@ -149,16 +155,18 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
replication = "none"
assert replication in ["all", "none"]

dtype = get_tvm_dtype(dtype)
reducer = T.alloc_buffer(shape, dtype, scope="local.fragment")
TL.block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})

return reducer


def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
def alloc_descriptor(dtype: AnyDType = "uint64", scope="local.descriptor"):
"""Allocate a descriptor buffer for wgmma and utcmma.

Returns:
T.Buffer: A TVM buffer object allocated as a descriptor
"""
dtype = get_tvm_dtype(dtype)
return T.alloc_buffer([1], dtype, scope=scope)
4 changes: 3 additions & 1 deletion tilelang/language/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tvm.tir import PrimExpr, Buffer, op
from typing import List, Union
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
from .dtypes import get_tvm_dtype, AnyDType


def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
Expand Down Expand Up @@ -51,7 +52,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:

def view(src: Buffer,
shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer:
dtype: Union[AnyDType, None] = None) -> Buffer:
"""
Return a Tensor view of the input buffer with an optional new shape and dtype.

Expand All @@ -61,6 +62,7 @@ def view(src: Buffer,
shape = src.shape
if dtype is None:
dtype = src.dtype
dtype = get_tvm_dtype(dtype)
return T.Tensor(shape, dtype, src.data)


Expand Down
151 changes: 151 additions & 0 deletions tilelang/language/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from tilelang import tvm
from tvm import ir
import torch
import ctypes
from typing import Any


class VoidPtr:
    """Sentinel class: pass the class itself (not an instance) to
    ``get_tvm_dtype`` to request a void pointer (TVM handle) type."""
    ...


# Union of all dtype spellings accepted by the converters in this module:
# a TVM IR type, a dtype string, a Python type (e.g. float/int), a
# torch.dtype, or an already-constructed tvm.DataType.
AnyDType = ir.Type | str | type | torch.dtype | tvm.DataType

# Forward mapping: Python / torch dtype objects -> TVM dtype strings.
_dtype_torch2tvm = {
    # Special (non-canonical) keys must come first: this dict is inverted
    # below, and for duplicate values the last key listed wins, so the
    # canonical torch dtypes below take precedence in the reverse map.
    float: "float32",
    int: "int32",
    # NOTE(review): torch.long IS torch.int64 and torch.half IS
    # torch.float16 (aliases, same objects), so these two keys are
    # overwritten by the canonical entries further down.
    torch.long: "int64",
    torch.half: "half",
    # other dtypes
    torch.bool: "bool",
    torch.int8: "int8",
    torch.int16: "int16",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.uint8: "uint8",
    torch.uint16: "uint16",
    torch.uint32: "uint32",
    torch.uint64: "uint64",
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float32",
    torch.float64: "float64",
    torch.float8_e4m3fn: "float8_e4m3fn",
    torch.float8_e4m3fnuz: "float8_e4m3fnuz",
    torch.float8_e5m2: "float8_e5m2",
    torch.float8_e5m2fnuz: "float8_e5m2fnuz",
    torch.float8_e8m0fnu: "float8_e8m0fnu",
}

# Reverse mapping: tvm.DataType -> torch/Python dtype. For duplicate TVM
# dtypes the last forward entry wins, hence the ordering note above.
_dtype_tvm2torch = {tvm.DataType(v): k for k, v in _dtype_torch2tvm.items()}

# Mapping: tvm.DataType -> ctypes scalar type, for marshalling scalar
# values through a ctypes-based FFI boundary.
_dtype_tvm2ctype = {
    tvm.DataType("bool"):
        ctypes.c_bool,
    tvm.DataType("int8"):
        ctypes.c_int8,
    tvm.DataType("int16"):
        ctypes.c_int16,
    tvm.DataType("int32"):
        ctypes.c_int32,
    tvm.DataType("int64"):
        ctypes.c_int64,
    tvm.DataType("uint8"):
        ctypes.c_uint8,
    tvm.DataType("uint16"):
        ctypes.c_uint16,
    tvm.DataType("uint32"):
        ctypes.c_uint32,
    tvm.DataType("uint64"):
        ctypes.c_uint64,
    # Sub-32-bit floats and fp8 variants have no native ctypes scalar;
    # bit-reinterpretation entries are deliberately left disabled for now.
    # tvm.DataType("float16"): ctypes.c_uint16,
    # tvm.DataType("bfloat16"): ctypes.c_uint16,
    tvm.DataType("float32"):
        ctypes.c_float,
    tvm.DataType("float64"):
        ctypes.c_double,
    # tvm.DataType("float8_e4m3fn"): ctypes.c_uint8,
    # tvm.DataType("float8_e4m3fnuz"): ctypes.c_uint8,
    # tvm.DataType("float8_e5m2"): ctypes.c_uint8,
    # tvm.DataType("float8_e5m2fnuz"): ctypes.c_uint8,
    # tvm.DataType("float8_e8m0fnu"): ctypes.c_uint8,
    tvm.DataType("handle"):
        ctypes.c_void_p,
}

# Mapping: tvm.DataType -> C type name string, for cffi-style signatures.
_dtype_tvm2cffi = {
    tvm.DataType("bool"):
        "bool",
    tvm.DataType("int8"):
        "char",
    tvm.DataType("int16"):
        "short",
    tvm.DataType("int32"):
        "int",
    tvm.DataType("int64"):
        "long long",
    tvm.DataType("uint8"):
        "unsigned char",
    tvm.DataType("uint16"):
        "unsigned short",
    tvm.DataType("uint32"):
        "unsigned int",
    tvm.DataType("uint64"):
        "unsigned long long",
    tvm.DataType("float32"):
        "float",
    tvm.DataType("float64"):
        "double",
    # Sub-32-bit floats and fp8 variants are disabled, as in the ctypes
    # table above; they would be passed as reinterpreted unsigned ints.
    # tvm.DataType("float16"): 'uint16_t',
    # tvm.DataType("bfloat16"): 'uint16_t',
    # tvm.DataType("float8_e4m3fn"): 'uint8_t',
    # tvm.DataType("float8_e4m3fnuz"): 'uint8_t',
    # tvm.DataType("float8_e5m2"): 'uint8_t',
    # tvm.DataType("float8_e5m2fnuz"): 'uint8_t',
    # tvm.DataType("float8_e8m0fnu"): 'uint8_t',
    # NOTE(review): "handle" maps to "long" here but to c_void_p in the
    # ctypes table — assumes an LP64 platform; confirm intended width.
    tvm.DataType("handle"):
        "long",
}


def get_tvm_dtype(ty: AnyDType) -> tvm.DataType:
    """Normalize any supported dtype spelling to a TVM dtype.

    Args:
        ty: ``None``, the ``VoidPtr`` sentinel class, a TVM ``ir.Type`` or
            ``tvm.DataType`` (returned as-is), a dtype string, a Python
            type such as ``float``/``int``, or a ``torch.dtype``.

    Returns:
        The corresponding TVM dtype. ``None`` passes through unchanged and
        the ``VoidPtr`` sentinel yields a void pointer (handle) type.

    Raises:
        KeyError: if *ty* is a Python/torch dtype with no known mapping.
    """
    if ty is None:
        return ty
    # Identity test, checked first: `ty == VoidPtr` would invoke an
    # arbitrary __eq__ (e.g. tvm.DataType.__eq__ / torch.dtype.__eq__)
    # with a class object as the operand.
    if ty is VoidPtr:
        return get_tvm_ptr_type()
    if isinstance(ty, (ir.Type, tvm.DataType)):
        return ty
    if isinstance(ty, str):
        return tvm.DataType(ty)
    # Python type (float/int) or torch.dtype: go through the lookup table.
    return tvm.DataType(_dtype_torch2tvm[ty])


def get_tvm_dtype_str(ty: AnyDType) -> str:
    """Return the TVM dtype string for any ``AnyDType`` spelling.

    Previously only ``str`` inputs were handled and every other allowed
    variant (``tvm.DataType``, ``ir.PrimType``, the ``VoidPtr`` sentinel)
    raised ``KeyError`` from the torch lookup table. Normalize those
    cases first and only fall back to the table for Python/torch dtypes.

    Raises:
        KeyError: if *ty* is a Python/torch dtype with no known mapping.
    """
    if isinstance(ty, str):
        return ty
    if isinstance(ty, (tvm.DataType, ir.PrimType)):
        # NOTE(review): assumes str() renders the canonical dtype string —
        # true for tvm.DataType; verify the PrimType repr matches.
        return str(ty)
    if ty is VoidPtr:
        return "handle"
    return _dtype_torch2tvm[ty]
Comment on lines +124 to +127
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Handle tvm.DataType (and other AnyDType cases) before mapping.

get_tvm_dtype_str currently only returns early for str inputs. For every other allowed AnyDType (e.g., the tvm.DataType objects produced by get_tvm_dtype, ir.PrimType, or the new VoidPtr sentinel), this code immediately indexes _dtype_torch2tvm. Passing any of those values now raises KeyError, which breaks call sites such as call_intrin(get_tvm_dtype(torch.float32), ...) and any of the intrinsic wrappers updated in this PR. Please normalize non-string inputs first—for example, accept tvm.DataType/ir.PrimType by returning str(ty) (or ty.dtype), detect VoidPtr, and only fall back to _dtype_torch2tvm for Python/torch dtypes.

A minimal patch might look like:

 def get_tvm_dtype_str(ty: AnyDType) -> str:
-    if isinstance(ty, str):
-        return ty
-    return _dtype_torch2tvm[ty]
+    if isinstance(ty, str):
+        return ty
+    if isinstance(ty, tvm.DataType):
+        return str(ty)
+    if isinstance(ty, ir.PrimType):
+        return str(ty)
+    if isinstance(ty, VoidPtr) or ty == VoidPtr:
+        return "handle"
+    return _dtype_torch2tvm[ty]

(Feel free to tailor the exact pointer handling, but the key is to cover all members of AnyDType without throwing.)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_tvm_dtype_str(ty: AnyDType) -> str:
if isinstance(ty, str):
return ty
return _dtype_torch2tvm[ty]
def get_tvm_dtype_str(ty: AnyDType) -> str:
if isinstance(ty, str):
return ty
if isinstance(ty, tvm.DataType):
return str(ty)
if isinstance(ty, ir.PrimType):
return str(ty)
if isinstance(ty, VoidPtr) or ty == VoidPtr:
return "handle"
return _dtype_torch2tvm[ty]
🤖 Prompt for AI Agents
In tilelang/language/dtypes.py around lines 124 to 127, get_tvm_dtype_str
currently only special-cases str and then indexes _dtype_torch2tvm, which raises
KeyError for tvm.DataType, ir.PrimType, VoidPtr, and other AnyDType variants;
change it to first normalize non-string inputs: if ty is tvm.DataType or
ir.PrimType (or has a .dtype attr) return str(ty) or ty.dtype as appropriate,
detect the VoidPtr sentinel and return the proper pointer/type string, and only
if ty is a Python/torch dtype fall back to looking up _dtype_torch2tvm; this
prevents KeyError and covers all AnyDType members before mapping.



def get_torch_dtype(ty: AnyDType) -> torch.dtype:
    """Convert any supported dtype spelling to the matching ``torch.dtype``.

    torch dtypes pass through unchanged; everything else is first
    normalized with ``get_tvm_dtype`` — so plain Python types such as
    ``float``/``int`` are accepted too, not only strings and
    ``tvm.DataType`` (the original str-only branch raised ``KeyError``
    for them). This mirrors ``get_ctypes_dtype`` / ``get_cffi_dtype``.

    Raises:
        KeyError: if the normalized dtype has no torch equivalent.
    """
    if isinstance(ty, torch.dtype):
        return ty
    ty = get_tvm_dtype(ty)
    return _dtype_tvm2torch[ty]


def get_ctypes_dtype(ty: AnyDType) -> Any:
    """Map any supported dtype spelling to its ctypes scalar type."""
    normalized = get_tvm_dtype(ty)
    return _dtype_tvm2ctype[normalized]


def get_cffi_dtype(ty: AnyDType) -> str:
    """Map any supported dtype spelling to its C type name for cffi."""
    return _dtype_tvm2cffi[get_tvm_dtype(ty)]


def get_tvm_ptr_type(ty: ir.Type | str | type | torch.dtype = "void",
                     scope: str = "global") -> ir.PointerType:
    """Build a TVM pointer type to *ty* in the given storage scope."""
    pointee = ir.PrimType(get_tvm_dtype(ty))
    return ir.PointerType(pointee, scope)
4 changes: 4 additions & 0 deletions tilelang/language/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tvm.tir import Var, PrimExpr
from tvm.script.ir_builder.tir import buffer, handle, match_buffer
from tilelang.utils import deprecated
from .dtypes import get_tvm_dtype


class BufferProxy:
Expand Down Expand Up @@ -295,11 +296,14 @@ def ptr(dtype: Optional[str] = None,
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
if dtype is not None:
dtype = get_tvm_dtype(dtype)
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)


def make_tensor(ptr: Var,
                shape: tuple[PrimExpr, ...],
                dtype: str = "float32",
                strides: Optional[tuple[PrimExpr, ...]] = None) -> tir.Buffer:
    """Create a ``tir.Buffer`` viewing the memory behind *ptr*.

    Args:
        ptr: Handle variable pointing at the underlying storage.
        shape: Buffer shape as PrimExprs.
        dtype: Element dtype; normalized through ``get_tvm_dtype``, so any
            spelling that helper accepts works here, not only strings.
        strides: Optional explicit strides; ``None`` defers to
            ``Tensor.from_ptr`` — presumably contiguous, TODO confirm.

    Returns:
        tir.Buffer: the buffer constructed by ``Tensor.from_ptr``.
    """
    # Normalize the dtype spelling before handing it to TVM.
    dtype = get_tvm_dtype(dtype)
    return Tensor.from_ptr(ptr, shape, dtype, strides)
Loading
Loading