diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index c872985f9..a42ba5a67 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -157,8 +157,42 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var def __dtype_as_torch__(self: dtype) -> torch.dtype: """Convert TileLang dtype to PyTorch dtype.""" dtype_str = str(self) - if dtype_str in _STR_TO_TORCH_DTYPE: + + if dtype_str == "float8_e4m3": + # Check if we're on HIP (AMD ROCm) or CUDA + if torch.version.hip is not None: + # HIP backend - use float8_e4m3fnuz + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" + ) + return torch.float8_e4m3fnuz + else: + # CUDA backend - use float8_e4m3fn + assert hasattr(torch, "float8_e4m3fn"), ( + "torch.float8_e4m3fn is not supported in this version of torch. Please upgrade torch >= 2.1.0" + ) + return torch.float8_e4m3fn + elif dtype_str == "float8_e5m2": + assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0" + return torch.float8_e5m2 + elif dtype_str == "e4m3fnuz_float8": + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0" + ) + return torch.float8_e4m3fnuz + elif dtype_str == "float8_e8m0fnu": + assert hasattr(torch, "float8_e8m0fnu"), ( + "torch.float8_e8m0fnu is not supported in this version of torch. Please upgrade torch >= 2.8.0" + ) + return torch.float8_e8m0fnu + elif dtype_str == "float4_e2m1fnx2": + assert hasattr(torch, "float4_e2m1fnx2"), ( + "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0" + ) + return torch.float4_e2m1fnx2 + elif dtype_str in _STR_TO_TORCH_DTYPE: return _STR_TO_TORCH_DTYPE[dtype_str] + raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")