Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_argument(


def test_expr():
from tilelang.language.eager.dtypes import _all_dtypes
from tilelang.language.dtypes import _all_dtypes

errors = []
for name in _all_dtypes:
Expand Down
2 changes: 1 addition & 1 deletion tilelang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _load_tile_lang_lib():
engine, # noqa: F401
tools, # noqa: F401
)
from .language.eager import dtypes # noqa: F401
from .language import dtypes # noqa: F401
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401
from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401
Expand Down
3 changes: 3 additions & 0 deletions tilelang/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export from language.dtypes for convenient access via `from tilelang.dtypes import ...`
from tilelang.language.dtypes import * # noqa: F401, F403
from tilelang.language.dtypes import dtype, AnyDType, get_tvm_dtype # noqa: F401
2 changes: 1 addition & 1 deletion tilelang/jit/adapter/tvm_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.utils.language import retrieve_func_from_module
from tilelang.engine.param import KernelParam
from tilelang.language.eager.dtypes import dtype
from tilelang.language.dtypes import dtype


class TVMFFIKernelAdapter(BaseKernelAdapter):
Expand Down
4 changes: 2 additions & 2 deletions tilelang/language/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
from .eager import dtypes as _dtypes
from .eager.dtypes import dtype as tl_dtype
from . import dtypes as _dtypes
from .dtypes import dtype as tl_dtype
from .eager.builder import OutTensor

_Shapes = TypeVarTuple("_Shapes")
Expand Down
10 changes: 10 additions & 0 deletions tilelang/language/eager/dtypes.py → tilelang/language/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
if TYPE_CHECKING:

class dtype(Generic[_T]):
@property
def bits(self) -> int: ...
@property
def bytes(self) -> int: ...
def as_torch(self) -> torch.dtype: ...
else:
dtype = tvm.DataType
Expand Down Expand Up @@ -218,9 +222,15 @@ def __dtype_new__(cls, value: AnyDType) -> dtype:
raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")


def __dtype_bytes__(self: dtype) -> int:
"""Return the number of bytes for this dtype."""
return self.itemsize


dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
dtype.as_torch = __dtype_as_torch__
dtype.bytes = property(__dtype_bytes__)


def get_tvm_dtype(value: AnyDType) -> dtype:
Expand Down
2 changes: 1 addition & 1 deletion tilelang/language/eager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .builder import prim_func, macro, PrimFunc, JITFunc, Ref, const # noqa: F401
from .dtypes import *
from ..dtypes import *
2 changes: 1 addition & 1 deletion tilelang/language/eager/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# from .utils import get_ast, get_compiled_object
from . import utils
from . import dtypes
from .. import dtypes

_span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"]

Expand Down
2 changes: 1 addition & 1 deletion tilelang/language/eager/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import ParamSpec, Self
except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec
from typing_extensions import ParamSpec, Self
from . import dtypes as dt
from .. import dtypes as dt
from . import utils
from tilelang.jit.exceptions import JITNoBuilderError, EagerJITBuildError
import threading
Expand Down
Loading