346 changes: 103 additions & 243 deletions src/target/codegen_cuda.cc

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions src/target/utils.cc
@@ -134,6 +134,47 @@ int TargetGetWarpSize(Target target) {
return res;
}

bool IsCudaVectorizableFP8(DataType dtype) {
return dtype.is_float8_e4m3() || dtype.is_float8_e4m3fn() ||
dtype.is_float8_e5m2();
}

bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) {
// float16 -> float32
if (from_ty.is_float16() && target_ty.is_float())
return true;

// float32 -> float16
if (from_ty.is_float() && target_ty.is_float16())
return true;

// bfloat16 -> float32
if (from_ty.is_bfloat16() && target_ty.is_float())
return true;

// float32 -> bfloat16
if (from_ty.is_float() && target_ty.is_bfloat16())
return true;

// float32 -> float8 (E4M3/E5M2)
if (from_ty.is_float() && IsCudaVectorizableFP8(target_ty))
return true;

// float8 (E4M3/E5M2) -> float32
if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float())
return true;

// float4_e2m1fn -> float32
if (from_ty.is_float4_e2m1fn() && target_ty.is_float())
return true;

// float32 -> float4_e2m1fn
if (from_ty.is_float() && target_ty.is_float4_e2m1fn())
return true;

return false;
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
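For orientation, a minimal usage sketch (not part of the diff) of the new helper. The CanEmitPackedCast wrapper and the even-lanes condition are assumptions for illustration only; the real call site added by this PR is in layout_inference.cc further down.

// Illustrative sketch only, not part of the PR. Mirrors the namespaces and
// DataType usage of src/target/utils.h; CanEmitPackedCast is hypothetical.
#include <tvm/runtime/data_type.h>
#include "utils.h"

namespace tvm {
namespace tl {

// Hypothetical gate: only vectorize a cast when the CUDA backend has a packed
// conversion for this dtype pair and the elements come in pairs.
inline bool CanEmitPackedCast(DataType from_ty, DataType target_ty, int lanes) {
  return (lanes % 2 == 0) && IsCudaVectorizableCast(from_ty, target_ty);
}

}  // namespace tl
}  // namespace tvm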
3 changes: 3 additions & 0 deletions src/target/utils.h
@@ -30,6 +30,9 @@ bool TargetHasTmem(Target target);
bool TargetHasBulkCopy(Target target);
int TargetGetWarpSize(Target target);

bool IsCudaVectorizableFP8(DataType dtype);
bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty);

} // namespace tl
} // namespace tvm

48 changes: 48 additions & 0 deletions src/tl_templates/cuda/cuda_fp4.h
@@ -154,4 +154,52 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t(
return result;
}

// fp4_e2m1x2 (1 byte) -> half2
// Uses PTX cvt.rn.f16x2.e2m1x2 instruction
TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const uint8_t src) {
half2 out;
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(&out);
uint16_t src_packed = static_cast<uint16_t>(src);
asm volatile("{\n"
".reg .b8 byte0, byte1;\n"
"mov.b16 {byte0, byte1}, %1;\n"
"cvt.rn.f16x2.e2m1x2 %0, byte0;\n"
"}\n"
: "=r"(*out_ptr)
: "h"(src_packed));
return out;
}

// fp4_e2m1x2 (1 byte) -> float2
TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const uint8_t src) {
half2 tmp = __tl_cvt_fp4x2_to_half2(src);
float2 result;
result.x = __half2float(tmp.x);
result.y = __half2float(tmp.y);
return result;
}

// half2 -> fp4_e2m1x2 (1 byte)
// Uses PTX cvt.rn.satfinite.e2m1x2.f16x2 instruction
TL_DEVICE uint8_t __tl_cvt_half2_to_fp4x2(const half2 src) {
uint16_t out;
uint32_t const *src_ptr = reinterpret_cast<uint32_t const *>(&src);
asm volatile("{\n"
".reg .b8 result_byte;\n"
"cvt.rn.satfinite.e2m1x2.f16x2 result_byte, %1;\n"
"mov.b16 %0, {result_byte, 0};\n"
"}\n"
: "=h"(out)
: "r"(*src_ptr));
return static_cast<uint8_t>(out);
}

// float2 -> fp4_e2m1x2 (1 byte)
TL_DEVICE uint8_t __tl_cvt_float2_to_fp4x2(const float2 src) {
half2 tmp;
tmp.x = __float2half(src.x);
tmp.y = __float2half(src.y);
return __tl_cvt_half2_to_fp4x2(tmp);
}

#endif
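As a quick illustration (again not part of the diff), a toy kernel that round-trips packed e2m1 pairs through the two fp16 helpers above. It assumes cuda_fp4.h and its includes are in scope and an architecture that supports the e2m1 cvt instructions (the tests added below gate on compute capability 10.0).

// Illustrative sketch only, not part of the PR.
__global__ void fp4x2_roundtrip(const uint8_t *in, uint8_t *out, int n_pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs) {
    half2 h = __tl_cvt_fp4x2_to_half2(in[i]); // unpack two e2m1 values to fp16
    out[i] = __tl_cvt_half2_to_fp4x2(h);      // repack (round-to-nearest, satfinite)
  }
}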
11 changes: 4 additions & 7 deletions src/transform/layout_inference.cc
@@ -1177,13 +1177,10 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *cast = obj.as<CastNode>()) {
// Check if this is a non-reducer store with Cast operation
DataType src_type = cast->value.dtype();
DataType dst_type = cast->dtype;
bool src_ok =
src_type.is_float() || src_type.is_bfloat() || src_type.is_float8();
bool dst_ok =
dst_type.is_float() || dst_type.is_bfloat() || dst_type.is_float8();
if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
DataType from_ty = cast->value.dtype();
DataType target_ty = cast->dtype;
if (IsCudaVectorizableCast(from_ty, target_ty) &&
TargetIsCuda(Target::Current())) {
has_cast_operations = true;
}
}
14 changes: 10 additions & 4 deletions testing/python/debug/test_tilelang_debug_print.py
@@ -12,13 +12,12 @@ def program(Q: T.Tensor((M, N), dtype)):
shared_buf = T.alloc_shared([M, N], dtype)
T.print(shared_buf)

jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
jit_kernel = tilelang.compile(program)
profiler = jit_kernel.get_profiler()
profiler.run_once()


def test_debug_print_buffer():
debug_print_buffer(dtype=T.bool)
debug_print_buffer(dtype=T.int8)
debug_print_buffer(dtype=T.int16)
debug_print_buffer(dtype=T.int32)
@@ -31,10 +30,17 @@ def test_debug_print_buffer():
debug_print_buffer(dtype=T.float32)
debug_print_buffer(dtype=T.float64)
debug_print_buffer(dtype=T.bfloat16)


@tilelang.testing.requires_cuda
def test_debug_print_buffer_cuda_fp8():
debug_print_buffer(dtype=T.float8_e4m3fn)
debug_print_buffer(dtype=T.float8_e4m3fn)
debug_print_buffer(dtype=T.float8_e4m3fnuz)
debug_print_buffer(dtype=T.float8_e5m2)


@tilelang.testing.requires_rocm
def test_debug_print_buffer_rocm_fp8():
debug_print_buffer(dtype=T.float8_e4m3fnuz)
debug_print_buffer(dtype=T.float8_e5m2fnuz)


71 changes: 47 additions & 24 deletions testing/python/language/test_tilelang_language_vectorized_cast.py
@@ -3,14 +3,6 @@
import tilelang.testing
import tilelang.language as T

str2dtype = {
T.float32: torch.float32,
T.float16: torch.float16,
T.bfloat16: torch.bfloat16,
T.float8_e4m3fn: torch.float8_e4m3fn,
T.float8_e5m2: torch.float8_e5m2,
}


@tilelang.jit
def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@@ -48,34 +40,39 @@ def main(
return main


def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str, lanes: int = 2):
"""Run the vectorized cast kernel and check the correctness.
Args:
src_dtype_str: The source data type string.
dst_dtype_str: The destination data type string.
src_dtype: The source data type.
dst_dtype: The destination data type.
check_str: Used to ensure vectorized cast is used.
lanes: The number of lanes of the source and destination data types.
"""

M = 128 * lanes
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel = vectorized_cast_kernel(M, src_dtype, dst_dtype)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype, dst_dtype)

code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()
print(code)
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!"

if src_dtype == T.float4_e2m1fn or dst_dtype == T.float4_e2m1fn:
return

A_float = torch.randn(M, dtype=torch.float32, device="cuda")
A = A_float.to(str2dtype[src_dtype_str])
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
A = A_float.to(src_dtype.as_torch())

A = A_float.to(src_dtype.as_torch())
B = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")
C = torch.zeros(M, dtype=dst_dtype.as_torch(), device="cuda")

kernel(A, B)
kernel_parallel(A, C)

torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)

code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()

assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
torch.testing.assert_close(A.to(dst_dtype.as_torch()), B)
torch.testing.assert_close(A.to(dst_dtype.as_torch()), C)


@pytest.mark.parametrize(
Expand All @@ -93,13 +90,39 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
(T.float32, T.bfloat16, "__float22bfloat162_rn", 4),
(T.bfloat16, T.float32, "__bfloat1622float2", 2),
(T.bfloat16, T.float32, "__bfloat1622float2", 4),
],
)
def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 9)
@pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes",
[
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
],
)
def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes):
run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
@pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes",
[
(T.float4_e2m1fn, T.float16, "__tl_cvt_fp4x2_to_half2", 2),
(T.float16, T.float4_e2m1fn, "__tl_cvt_half2_to_fp4x2", 2),
(T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2),
(T.float32, T.float4_e2m1fn, "__tl_cvt_float2_to_fp4x2", 2),
],
)
def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes):
run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)


8 changes: 6 additions & 2 deletions tilelang/language/v2/dtypes.py
@@ -5,6 +5,7 @@
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
import numpy as np
from tilelang import logger

_T = TypeVar("_T")

@@ -175,7 +176,7 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype:
elif dtype_str == "float8_e5m2":
assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0"
return torch.float8_e5m2
elif dtype_str == "e4m3fnuz_float8":
elif dtype_str == "float8_e4m3fnuz":
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0"
)
@@ -189,7 +190,10 @@ def __dtype_as_torch__(self: dtype) -> torch.dtype:
assert hasattr(torch, "float4_e2m1fnx2"), (
"torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
)
return torch.float4_e2m1fnx2
return torch.float4_e2m1fn_x2
elif dtype_str == "float4_e2m1fn":
logger.info("torch doesn't support float4_e2m1fn, using float4_e2m1fnx2 as storage dtype.")
return torch.float4_e2m1fn_x2
Review comment on lines 190 to 196:

P2: Return the correct torch dtype attribute name

The new float4 mapping checks hasattr(torch, "float4_e2m1fnx2") but then returns torch.float4_e2m1fn_x2 (note the extra underscore). Because the checked and returned names disagree, a passing hasattr check can still be followed by an AttributeError at runtime on whichever spelling the installed torch does not provide; the inconsistency affects both the float4_e2m1fnx2 path and the fallback for float4_e2m1fn, so code paths requesting those dtypes can break on otherwise supported torch versions.

elif dtype_str in _STR_TO_TORCH_DTYPE:
return _STR_TO_TORCH_DTYPE[dtype_str]

Expand Down