Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 57 additions & 56 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,62 +145,63 @@ def test_str_repr():
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841


def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
assert T.dtype(b) == a, "dtype conversion error"
# not supported now
# def test_torch_eq():
# dtypes = [
# T.bool,
# T.short,
# T.int,
# T.long,
# T.half,
# T.float,
# T.long,
# T.int8,
# T.int16,
# T.int32,
# T.int64,
# T.uint8,
# T.uint16,
# T.uint32,
# T.uint64,
# T.float8_e4m3fn,
# T.float8_e4m3fnuz,
# T.float8_e5m2,
# T.float8_e5m2fnuz,
# T.float8_e8m0fnu,
# T.float16,
# T.bfloat16,
# T.float32,
# T.float64,
# ]
# torch_dtypes = [
# torch.bool,
# torch.short,
# torch.int,
# torch.long,
# torch.half,
# torch.float,
# torch.long,
# torch.int8,
# torch.int16,
# torch.int32,
# torch.int64,
# torch.uint8,
# torch.uint16,
# torch.uint32,
# torch.uint64,
# torch.float8_e4m3fn,
# torch.float8_e4m3fnuz,
# torch.float8_e5m2,
# torch.float8_e5m2fnuz,
# torch.float8_e8m0fnu,
# torch.float16,
# torch.bfloat16,
# torch.float32,
# torch.float64,
# ]
# for a, b in zip(dtypes, torch_dtypes):
# assert a == b, f"{a} and {b} are not equal"
# assert T.dtype(b) == a, "dtype conversion error"


def test_var_assign():
Expand Down
106 changes: 106 additions & 0 deletions tilelang/language/tir/ir.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm

_T = TypeVar('_T')

def abs(x: _T, span: Span | None=None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
def cosh(x: _T) -> _T: ...
def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
def sigmoid(x: _T) -> _T: ...
def sin(x: _T) -> _T: ...
def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
def call_packed_lowered(*args, span=None) -> _T: ...
def call_cpacked_lowered(*args, span=None) -> _T: ...
def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
def start_profile_intrinsic(id: int) -> PrimExpr: ...
def end_profile_intrinsic(id: int) -> PrimExpr: ...
def anylist_getitem(list_handle, index) -> PrimExpr: ...
def anylist_resetitem(list_handle, index) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ...
def vscale() -> _T: ...
Loading
Loading