Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tilelang/language/v2/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,42 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
def __dtype_as_torch__(self: dtype) -> torch.dtype:
"""Convert TileLang dtype to PyTorch dtype."""
dtype_str = str(self)
if dtype_str in _STR_TO_TORCH_DTYPE:

if dtype_str == "float8_e4m3":
# Check if we're on HIP (AMD ROCm) or CUDA
if torch.version.hip is not None:
# HIP backend - use float8_e4m3fnuz
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0"
)
return torch.float8_e4m3fnuz
else:
# CUDA backend - use float8_e4m3fn
assert hasattr(torch, "float8_e4m3fn"), (
"torch.float8_e4m3fn is not supported in this version of torch. Please upgrade torch >= 2.1.0"
)
return torch.float8_e4m3fn
elif dtype_str == "float8_e5m2":
assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0"
return torch.float8_e5m2
elif dtype_str == "e4m3fnuz_float8":
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0"
)
return torch.float8_e4m3fnuz
elif dtype_str == "float8_e8m0fnu":
assert hasattr(torch, "float8_e8m0fnu"), (
"torch.float8_e8m0fnu is not supported in this version of torch. Please upgrade torch >= 2.8.0"
)
return torch.float8_e8m0fnu
elif dtype_str == "float4_e2m1fnx2":
assert hasattr(torch, "float4_e2m1fnx2"), (
"torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
)
return torch.float4_e2m1fnx2
elif dtype_str in _STR_TO_TORCH_DTYPE:
return _STR_TO_TORCH_DTYPE[dtype_str]

raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")


Expand Down
Loading