diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 010ddbe8e..8966b36ed 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -40,7 +40,7 @@ def test_argument( def test_expr(): - from tilelang.language.eager.dtypes import _all_dtypes + from tilelang.language.dtypes import _all_dtypes errors = [] for name in _all_dtypes: diff --git a/tilelang/__init__.py b/tilelang/__init__.py index a41c40a5b..03582b743 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -160,7 +160,7 @@ def _load_tile_lang_lib(): engine, # noqa: F401 tools, # noqa: F401 ) - from .language.eager import dtypes # noqa: F401 + from .language import dtypes # noqa: F401 from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401 diff --git a/tilelang/dtypes.py b/tilelang/dtypes.py new file mode 100644 index 000000000..3e12d285d --- /dev/null +++ b/tilelang/dtypes.py @@ -0,0 +1,3 @@ +# Re-export from language.dtypes for convenient access via `from tilelang.dtypes import ...` +from tilelang.language.dtypes import * # noqa: F401, F403 +from tilelang.language.dtypes import dtype, AnyDType, get_tvm_dtype # noqa: F401 diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 73ff779bf..755dc8dbf 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -19,7 +19,7 @@ from tilelang.jit.adapter.base import BaseKernelAdapter from tilelang.utils.language import retrieve_func_from_module from tilelang.engine.param import KernelParam -from tilelang.language.eager.dtypes import dtype +from tilelang.language.dtypes import dtype class TVMFFIKernelAdapter(BaseKernelAdapter): diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 6100316dd..c14a733b7 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -28,8 +28,8 @@ from tvm.script.parser.tir import block_attr from tvm.tir.buffer import Buffer from tvm.tir.expr import FloatImm, IntImm -from .eager import dtypes as _dtypes -from .eager.dtypes import dtype as tl_dtype +from . import dtypes as _dtypes +from .dtypes import dtype as tl_dtype from .eager.builder import OutTensor _Shapes = TypeVarTuple("_Shapes") diff --git a/tilelang/language/eager/dtypes.py b/tilelang/language/dtypes.py similarity index 98% rename from tilelang/language/eager/dtypes.py rename to tilelang/language/dtypes.py index a29c57ff9..bd69c118a 100644 --- a/tilelang/language/eager/dtypes.py +++ b/tilelang/language/dtypes.py @@ -12,6 +12,10 @@ if TYPE_CHECKING: class dtype(Generic[_T]): + @property + def bits(self) -> int: ... + @property + def bytes(self) -> int: ... def as_torch(self) -> torch.dtype: ... else: dtype = tvm.DataType @@ -218,9 +222,15 @@ def __dtype_new__(cls, value: AnyDType) -> dtype: raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") +def __dtype_bytes__(self: dtype) -> int: + """Return the number of bytes for this dtype.""" + return self.itemsize + + dtype.__call__ = __dtype_call__ dtype.__new__ = __dtype_new__ dtype.as_torch = __dtype_as_torch__ +dtype.bytes = property(__dtype_bytes__) def get_tvm_dtype(value: AnyDType) -> dtype: diff --git a/tilelang/language/eager/__init__.py b/tilelang/language/eager/__init__.py index cf760fed9..171068126 100644 --- a/tilelang/language/eager/__init__.py +++ b/tilelang/language/eager/__init__.py @@ -1,2 +1,2 @@ from .builder import prim_func, macro, PrimFunc, JITFunc, Ref, const # noqa: F401 -from .dtypes import * +from ..dtypes import * diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py index 18d071b13..230925a1d 100644 --- a/tilelang/language/eager/ast.py +++ b/tilelang/language/eager/ast.py @@ -15,7 +15,7 @@ # from .utils import get_ast, get_compiled_object from . import utils -from . import dtypes +from .. import dtypes _span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"] diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index b3efb645c..f7377a222 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -27,7 +27,7 @@ from typing import ParamSpec, Self except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec from typing_extensions import ParamSpec, Self -from . import dtypes as dt +from .. import dtypes as dt from . import utils from tilelang.jit.exceptions import JITNoBuilderError, EagerJITBuildError import threading