diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 915574c3e..bad009e66 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -145,62 +145,63 @@ def test_str_repr(): buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 -def test_torch_eq(): - dtypes = [ - T.bool, - T.short, - T.int, - T.long, - T.half, - T.float, - T.long, - T.int8, - T.int16, - T.int32, - T.int64, - T.uint8, - T.uint16, - T.uint32, - T.uint64, - T.float8_e4m3fn, - T.float8_e4m3fnuz, - T.float8_e5m2, - T.float8_e5m2fnuz, - T.float8_e8m0fnu, - T.float16, - T.bfloat16, - T.float32, - T.float64, - ] - torch_dtypes = [ - torch.bool, - torch.short, - torch.int, - torch.long, - torch.half, - torch.float, - torch.long, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.uint16, - torch.uint32, - torch.uint64, - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - torch.float8_e8m0fnu, - torch.float16, - torch.bfloat16, - torch.float32, - torch.float64, - ] - for a, b in zip(dtypes, torch_dtypes): - assert a == b, f"{a} and {b} are not equal" - assert T.dtype(b) == a, "dtype conversion error" +# not supported now +# def test_torch_eq(): +# dtypes = [ +# T.bool, +# T.short, +# T.int, +# T.long, +# T.half, +# T.float, +# T.long, +# T.int8, +# T.int16, +# T.int32, +# T.int64, +# T.uint8, +# T.uint16, +# T.uint32, +# T.uint64, +# T.float8_e4m3fn, +# T.float8_e4m3fnuz, +# T.float8_e5m2, +# T.float8_e5m2fnuz, +# T.float8_e8m0fnu, +# T.float16, +# T.bfloat16, +# T.float32, +# T.float64, +# ] +# torch_dtypes = [ +# torch.bool, +# torch.short, +# torch.int, +# torch.long, +# torch.half, +# torch.float, +# torch.long, +# torch.int8, +# torch.int16, +# torch.int32, +# torch.int64, +# torch.uint8, +# torch.uint16, +# torch.uint32, +# 
torch.uint64, +# torch.float8_e4m3fn, +# torch.float8_e4m3fnuz, +# torch.float8_e5m2, +# torch.float8_e5m2fnuz, +# torch.float8_e8m0fnu, +# torch.float16, +# torch.bfloat16, +# torch.float32, +# torch.float64, +# ] +# for a, b in zip(dtypes, torch_dtypes): +# assert a == b, f"{a} and {b} are not equal" +# assert T.dtype(b) == a, "dtype conversion error" def test_var_assign(): diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi new file mode 100644 index 000000000..fe25b58f8 --- /dev/null +++ b/tilelang/language/tir/ir.pyi @@ -0,0 +1,106 @@ +from typing import TypeVar, Literal +from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm + +_T = TypeVar('_T') + +def abs(x: _T, span: Span | None=None) -> _T: ... +def acos(x: _T) -> _T: ... +def acosh(x: _T) -> _T: ... +def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ... +def asin(x: _T) -> _T: ... +def asinh(x: _T) -> _T: ... +def atan(x: _T) -> _T: ... +def atan2(x1: _T, x2: _T) -> _T: ... +def atanh(x: _T) -> _T: ... +def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ... +def bitwise_not(x: _T, span: Span | None=None) -> _T: ... +def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ... +def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ... +def ceil(x: _T, span: Span | None=None) -> _T: ... +def clz(x: _T) -> _T: ... +def copysign(x1: _T, x2: _T) -> _T: ... +def cos(x: _T) -> _T: ... +def cosh(x: _T) -> _T: ... +def erf(x: _T) -> _T: ... +def exp(x: _T) -> _T: ... +def exp2(x: _T) -> _T: ... +def exp10(x: _T) -> _T: ... +def floor(x: _T, span: Span | None=None) -> _T: ... +def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ... +def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ... +def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def fmod(x: _T, y: _T) -> _T: ... +def hypot(x1: _T, x2: _T) -> _T: ... +def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... 
+def infinity(dtype: str, span: Span | None=None) -> PrimExpr: ...  # dtype is a type-name string, same shape as max_value/min_value
+def isfinite(x: _T, span: Span | None=None) -> _T: ...
+def isinf(x: _T, span: Span | None=None) -> _T: ...
+def isnan(x: _T, span: Span | None=None) -> _T: ...
+def isnullptr(x: _T, span: Span | None=None) -> _T: ...
+def ldexp(x1: _T, x2: _T) -> _T: ...
+def likely(cond: _T, span: Span | None=None) -> _T: ...
+def log(x: _T) -> _T: ...
+def log1p(x: _T) -> _T: ...
+def log2(x: _T) -> _T: ...
+def log10(x: _T) -> _T: ...
+def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
+def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
+def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
+def nearbyint(x: _T, span: Span | None=None) -> _T: ...
+def nextafter(x1: _T, x2: _T) -> _T: ...
+def popcount(x: _T) -> _T: ...
+def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
+def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
+def ret(val: _T) -> _T: ...
+def round(x: _T, span: Span | None=None) -> _T: ...
+def rsqrt(x: _T) -> _T: ...
+def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def sigmoid(x: _T) -> _T: ...
+def sin(x: _T) -> _T: ...
+def sinh(x: _T) -> _T: ...
+def sqrt(x: _T) -> _T: ...
+def tan(x: _T) -> _T: ...
+def tanh(x: _T) -> _T: ...
+def trunc(x: _T, span: Span | None=None) -> _T: ...
+def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
+def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
+def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
+def tvm_throw_last_error() -> PrimExpr: ...  # concrete return: a TypeVar used only in the return position cannot be solved by a type checker
+def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
+def tvm_stack_make_shape(*args) -> PrimExpr: ...
+def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
+def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
+def call_packed(*args, span: Span | None=None) -> PrimExpr: ...
+def call_cpacked(*args, span: Span | None=None) -> PrimExpr: ...
+def call_packed_lowered(*args, span: Span | None=None) -> PrimExpr: ...
+def call_cpacked_lowered(*args, span: Span | None=None) -> PrimExpr: ...
+def tvm_tuple(*value) -> PrimExpr: ...
+def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
+def tvm_thread_invariant(cond: _T) -> _T: ...
+def tvm_thread_allreduce(*freduce_args) -> PrimExpr: ...
+def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
+def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
+def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
+def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
+def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
+def ptx_wait_group(num: int) -> PrimExpr: ...
+def ptx_commit_group() -> PrimExpr: ...
+def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
+def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
+def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
+def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
+def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
+def create_barriers(barrier_count: int) -> PrimExpr: ... +def assume(cond: _T=None) -> _T: ... +def undef() -> _T: ... +def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... +def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... +def start_profile_intrinsic(id: int) -> PrimExpr: ... +def end_profile_intrinsic(id: int) -> PrimExpr: ... +def anylist_getitem(list_handle, index) -> PrimExpr: ... +def anylist_resetitem(list_handle, index) -> PrimExpr: ... +def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ... +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ... +def vscale() -> _T: ... diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 2161e3770..0702635a0 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -1,95 +1,98 @@ from tilelang import tvm from tvm import ir import torch -import ctypes from typing import TYPE_CHECKING, Union from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi +import numpy as np dtype = tvm.DataType # Python 3.9 compatibility: avoid PEP 604 unions at runtime AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] -# Base dtype conversion list -_dtype_cvt_base = [ - (None, 'handle', ctypes.c_long, 'long', None), # use long to repr void* - (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), - (int, 'int32', ctypes.c_int32, 'int', 'Int32'), - (float, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'), - (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'), - (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'), - (torch.half, 'float16', None, None, 'Float16'), - (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'), - - # (pytype, 'tvm dtype str', 
'ctypes dtype', 'cffi dtype') - (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), - (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'), - (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'), - (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'), - (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'), - (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'), - (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'), - (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'), - (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'), - (torch.float16, 'float16', None, None, 'Float16'), - (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'), - (None, 'float8_e4m3', None, None, 'Float8E4M3'), - (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'), -] - -# Dynamically add fp8-related types if they exist in torch -_fp8_dtype_mappings = [ - ('float8_e4m3fn', 'Float8E4M3FN'), - ('float8_e4m3fnuz', 'Float8E4M3FNUZ'), - ('float8_e5m2', 'Float8E5M2'), - ('float8_e5m2fnuz', 'Float8E5M2FNUZ'), - ('float8_e8m0fnu', 'Float8E8M0FNU'), -] - -_dtype_cvt = list(_dtype_cvt_base) -for torch_attr_name, tvm_name in _fp8_dtype_mappings: - if hasattr(torch, torch_attr_name): - torch_dtype = getattr(torch, torch_attr_name) - _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name)) - +_PYTHON_DTYPE_TO_STR = { + bool: 'bool', + int: 'int32', + float: 'float32', +} -def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x): - return { - smapper(item[sidx]): dmapper(item[didx]) - for item in _dtype_cvt - if item[didx] is not None and item[sidx] is not None - } +_NUMPY_DTYPE_TO_STR = { + np.bool_: 'bool', + np.short: 'int16', + np.int_: 'int64', + np.longlong: 'int64', + np.half: 'float16', + np.double: 'float64', + np.int8: 'int8', + np.int16: 'int16', + np.int32: 'int32', + np.int64: 'int64', + np.uint8: 
'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + np.float16: 'float16', + np.float32: 'float32', + np.float64: 'float64', +} +_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) -_dtype_py2tvmstr = _create_type_mapper(0, 1) -_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x)) -_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x)) -_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x)) -_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x)) +_TORCH_DTYPE_TO_STR = { + torch.bool: 'bool', + torch.short: 'int16', + torch.int: 'int32', + torch.long: 'int64', + torch.half: 'float16', + torch.float: 'float32', + torch.double: 'float64', + torch.int8: 'int8', + torch.int16: 'int16', + torch.int32: 'int32', + torch.int64: 'int64', + torch.uint8: 'uint8', + torch.uint16: 'uint16', + torch.uint32: 'uint32', + torch.uint64: 'uint64', + torch.float16: 'float16', + torch.float32: 'float32', + torch.float64: 'float64', + torch.bfloat16: 'bfloat16', +} +# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} -def __dtype_eq__(self: dtype, other: AnyDType): - if isinstance(other, str): - return str.__eq__(self, other) - if other in _dtype_py2tvmstr: - return str.__eq__(self, _dtype_py2tvmstr[other]) - return NotImplemented +# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} +_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR} -def __dtype_ne__(self: dtype, other: AnyDType): - if isinstance(other, str): - return str.__ne__(self, other) - if other in _dtype_py2tvmstr: - return str.__ne__(self, _dtype_py2tvmstr[other]) - return NotImplemented +_STR_TO_TVM_DTYPE_CALL = { + 'bool': 'Boolean', + 'int8': 'Int8', + 'int32': 'Int32', + 'int64': 'Int64', + 'uint8': 'UInt8', + 'uint16': 'UInt16', + 'uint32': 'UInt32', + 'uint64': 'UInt64', + 'float16': 'Float16', + 'float32': 'Float32', + 'float64': 
'Float64', + 'bfloat16': 'BFloat16', + 'float8_e4m3': 'Float8E4M3', + 'float8_e4m3fn': 'Float8E4M3FN', + 'float8_e4m3fnuz': 'Float8E4M3FNUZ', + 'float8_e5m2': 'Float8E5M2', + 'float8_e5m2fnuz': 'Float8E5M2FNUZ', + 'float8_e8m0fnu': 'Float8E8M0FNU' +} def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: - if self in _dtype_tvmstr2fficall: - return _dtype_tvmstr2fficall[self](expr, is_size_var) + if self in _STR_TO_TVM_DTYPE_CALL: + attr = _STR_TO_TVM_DTYPE_CALL[self] + call = getattr(tb_ffi, attr, None) + return call(expr, is_size_var) # try to construct the ffi call if self.startswith('uint'): val = 'UInt' + self[4:] @@ -117,17 +120,13 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var def __dtype_new__(cls, value: AnyDType) -> dtype: if isinstance(value, str): return __orig_dtype_new(cls, value) - elif value in _dtype_py2tvmstr: - return __orig_dtype_new(cls, _dtype_py2tvmstr[value]) + elif value in _DTYPE_TO_STR: + return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) else: - expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) + expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values())) raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") -dtype.__eq__ = __dtype_eq__ -dtype.__req__ = __dtype_eq__ -dtype.__ne__ = __dtype_ne__ -dtype.__rne__ = __dtype_ne__ dtype.__call__ = __dtype_call__ dtype.__new__ = __dtype_new__