Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def get_configs():
return configs


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1])
def flashattn(
batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
Expand Down
57 changes: 52 additions & 5 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
};

// Handle conversion from float16 to float32
if (from_ty.is_float16() && target_ty.is_float()) {
if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32) {
// Use __half22float2 for vectorized conversion (half2 -> float2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__half22float2", "half2", "float2");
Expand All @@ -1004,7 +1004,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float32 to float16
if (from_ty.is_float() && target_ty.is_float16()) {
if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_float16()) {
// Use __float22half2_rn for vectorized conversion (float2 -> half2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__float22half2_rn", "float2", "half2");
Expand All @@ -1013,7 +1013,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from bfloat16 to float32
if (from_ty.is_bfloat16() && target_ty.is_float()) {
if (from_ty.is_bfloat16() && target_ty.is_float() && target_ty.bits() == 32) {
// Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__bfloat1622float2", "__nv_bfloat162", "float2", "",
Expand All @@ -1023,7 +1023,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float32 to bfloat16
if (from_ty.is_float() && target_ty.is_bfloat16()) {
if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_bfloat16()) {
// Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__float22bfloat162_rn", "float2", "__nv_bfloat162",
Expand All @@ -1033,7 +1033,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float32 to float8 (E4M3/E5M2)
if (from_ty.is_float() && tl::IsCudaVectorizableFP8(target_ty)) {
if (from_ty.is_float() && from_ty.bits() == 32 &&
tl::IsCudaVectorizableFP8(target_ty)) {
bool target_type_is_e4m3 =
target_ty.is_float8_e4m3() || target_ty.is_float8_e4m3fn();
std::string type_suffix = target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2";
Expand Down Expand Up @@ -1101,6 +1102,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}

// Handle conversion from double to float4 (E2M1)
if (from_ty.is_float() && from_ty.bits() == 64 &&
target_ty.is_float4_e2m1fn()) {
// Use __tl_cvt_double2_to_fp4x2 for vectorized conversion (double2 ->
// fp4x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_double2_to_fp4x2", "double2", "uint8_t", "",
false, true);
return;
}
}

// Handle conversion from float4 (E2M1) to double
if (from_ty.is_float4_e2m1fn() && target_ty.is_float() &&
target_ty.bits() == 64) {
// Use __tl_cvt_fp4x2_to_double2 for vectorized conversion (fp4x2 ->
// double2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_fp4x2_to_double2", "uint8_t", "double2", "",
true, false);
return;
}
}

// Handle conversion from bfloat16 to float4 (E2M1)
if (from_ty.is_bfloat16() && target_ty.is_float4_e2m1fn()) {
// Use __tl_cvt_bfloat162_to_fp4x2 for vectorized conversion (bfloat162 ->
// fp4x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_bfloat162_to_fp4x2", "__nv_bfloat162",
"uint8_t", "", false, true);
return;
}
}

// Handle conversion from float4 (E2M1) to bfloat16
if (from_ty.is_float4_e2m1fn() && target_ty.is_bfloat16()) {
// Use __tl_cvt_fp4x2_to_bfloat162 for vectorized conversion (fp4x2 ->
// bfloat162)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_fp4x2_to_bfloat162", "uint8_t",
"__nv_bfloat162", "", true, false);
return;
}
}

// Fallback: elementwise cast
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
Expand Down
134 changes: 102 additions & 32 deletions src/tl_templates/cuda/cuda_fp4.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,52 +154,122 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t(
return result;
}

// ============================================================================
// FP4 <-> Half Precision Conversions
// ============================================================================
// https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__FP4__MISC.html

// fp4_e2m1 -> half
TL_DEVICE __half __tl_cvt_fp4_to_half(const __nv_fp4_storage_t src) {
__half_raw raw = __nv_cvt_fp4_to_halfraw(src, __NV_E2M1);
__half result;
result = *reinterpret_cast<__half *>(&raw);
return result;
}

// fp4_e2m1x2 (1 byte) -> half2
// Uses PTX cvt.rn.f16x2.e2m1x2 instruction
TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const uint8_t src) {
half2 out;
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(&out);
uint16_t src_packed = static_cast<uint16_t>(src);
asm volatile("{\n"
".reg .b8 byte0, byte1;\n"
"mov.b16 {byte0, byte1}, %1;\n"
"cvt.rn.f16x2.e2m1x2 %0, byte0;\n"
"}\n"
: "=r"(*out_ptr)
: "h"(src_packed));
return out;
TL_DEVICE half2 __tl_cvt_fp4x2_to_half2(const __nv_fp4x2_storage_t src) {
__half2_raw raw = __nv_cvt_fp4x2_to_halfraw2(src, __NV_E2M1);
half2 result;
result = *reinterpret_cast<half2 *>(&raw);
return result;
}

// half -> fp4_e2m1
TL_DEVICE __nv_fp4_storage_t __tl_cvt_half_to_fp4(const __half src) {
__half_raw raw = *reinterpret_cast<const __half_raw *>(&src);
return __nv_cvt_halfraw_to_fp4(raw, __NV_E2M1, cudaRoundZero);
}

// half2 -> fp4_e2m1x2 (1 byte)
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_half2_to_fp4x2(const half2 src) {
__half2_raw raw = *reinterpret_cast<const __half2_raw *>(&src);
return __nv_cvt_halfraw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero);
}

// ============================================================================
// FP4 <-> Float Conversions
// ============================================================================

// fp4_e2m1 -> float
TL_DEVICE float __tl_cvt_fp4_to_float(const __nv_fp4_storage_t src) {
return __half2float(__tl_cvt_fp4_to_half(src));
}

// fp4_e2m1x2 (1 byte) -> float2
TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const uint8_t src) {
TL_DEVICE float2 __tl_cvt_fp4x2_to_float2(const __nv_fp4x2_storage_t src) {
half2 tmp = __tl_cvt_fp4x2_to_half2(src);
float2 result;
result.x = __half2float(tmp.x);
result.y = __half2float(tmp.y);
return result;
}

// half2 -> fp4_e2m1x2 (1 byte)
// Uses PTX cvt.rn.satfinite.e2m1x2.f16x2 instruction
TL_DEVICE uint8_t __tl_cvt_half2_to_fp4x2(const half2 src) {
uint16_t out;
uint32_t const *src_ptr = reinterpret_cast<uint32_t const *>(&src);
asm volatile("{\n"
".reg .b8 result_byte;\n"
"cvt.rn.satfinite.e2m1x2.f16x2 result_byte, %1;\n"
"mov.b16 %0, {result_byte, 0};\n"
"}\n"
: "=h"(out)
: "r"(*src_ptr));
return static_cast<uint8_t>(out);
// float -> fp4_e2m1
TL_DEVICE __nv_fp4_storage_t __tl_cvt_float_to_fp4(const float src) {
return __nv_cvt_float_to_fp4(src, __NV_E2M1, cudaRoundZero);
}

// float2 -> fp4_e2m1x2 (1 byte)
TL_DEVICE uint8_t __tl_cvt_float2_to_fp4x2(const float2 src) {
half2 tmp;
tmp.x = __float2half(src.x);
tmp.y = __float2half(src.y);
return __tl_cvt_half2_to_fp4x2(tmp);
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_float2_to_fp4x2(const float2 src) {
return __nv_cvt_float2_to_fp4x2(src, __NV_E2M1, cudaRoundZero);
}

// ============================================================================
// FP4 <-> Double Conversions
// ============================================================================

// fp4_e2m1 -> double
TL_DEVICE double __tl_cvt_fp4_to_double(const __nv_fp4_storage_t src) {
return static_cast<double>(__tl_cvt_fp4_to_float(src));
}

// fp4_e2m1x2 -> double2
TL_DEVICE double2 __tl_cvt_fp4x2_to_double2(const __nv_fp4x2_storage_t src) {
float2 tmp = __tl_cvt_fp4x2_to_float2(src);
double2 result;
result.x = static_cast<double>(tmp.x);
result.y = static_cast<double>(tmp.y);
return result;
}

// double -> fp4_e2m1
TL_DEVICE __nv_fp4_storage_t __tl_cvt_double_to_fp4(const double src) {
return __nv_cvt_double_to_fp4(src, __NV_E2M1, cudaRoundZero);
}

// double2 -> fp4_e2m1x2
TL_DEVICE __nv_fp4x2_storage_t __tl_cvt_double2_to_fp4x2(const double2 src) {
return __nv_cvt_double2_to_fp4x2(src, __NV_E2M1, cudaRoundZero);
}

// ============================================================================
// FP4 <-> BFloat16 Conversions
// ============================================================================

// fp4_e2m1 -> bfloat16
TL_DEVICE __nv_bfloat16 __tl_cvt_fp4_to_bfloat16(const __nv_fp4_storage_t src) {
return __float2bfloat16(__tl_cvt_fp4_to_float(src));
}

// fp4_e2m1x2 -> bfloat162
TL_DEVICE __nv_bfloat162
__tl_cvt_fp4x2_to_bfloat162(const __nv_fp4x2_storage_t src) {
float2 tmp = __tl_cvt_fp4x2_to_float2(src);
return __floats2bfloat162_rn(tmp.x, tmp.y);
}

// bfloat16 -> fp4_e2m1
TL_DEVICE __nv_fp4_storage_t __tl_cvt_bfloat16_to_fp4(const __nv_bfloat16 src) {
__nv_bfloat16_raw raw = *reinterpret_cast<const __nv_bfloat16_raw *>(&src);
return __nv_cvt_bfloat16raw_to_fp4(raw, __NV_E2M1, cudaRoundZero);
}

// bfloat162 -> fp4_e2m1x2
TL_DEVICE __nv_fp4x2_storage_t
__tl_cvt_bfloat162_to_fp4x2(const __nv_bfloat162 src) {
__nv_bfloat162_raw raw = *reinterpret_cast<const __nv_bfloat162_raw *>(&src);
return __nv_cvt_bfloat16raw2_to_fp4x2(raw, __NV_E2M1, cudaRoundZero);
}

#endif
57 changes: 36 additions & 21 deletions src/transform/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,7 @@ void ArgBinder::BindDLTensors(

// Scan buffer shape for symbolic variables
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
if (buffer->dtype.bits() < 8) {
break;
}

Expand Down Expand Up @@ -524,21 +522,40 @@ void ArgBinder::BindDLTensors(
cond =
cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}
// Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
// FP4).
if (buffer->dtype.is_float4()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
// For FP4, we pack 2 elements per byte, but we still use same lanes at
// storage level Accept int8 with same lanes as the fp4 type
PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
cond = cond || int8_ok;
// Allow with bits < 8 to match any type with the same total bit count at
// runtime (PyTorch uses int8 as storage for FP4).
bool data_is_subtype = buffer->dtype.bits() < 8;
if (data_is_subtype) {
// Get the pre-created shape buffer for reading runtime shape
Buffer buf_shape = shape_buffer_map[arg_name];

// Calculate expected total bits using compile-time buffer->shape
PrimExpr expect_total_bits =
cast(DataType::UInt(64), expect_bits) *
cast(DataType::UInt(64), expect_lanes) *
cast(DataType::UInt(64),
buffer->shape.empty()
? make_const(DataType::UInt(64), 1)
: foldl([](PrimExpr a, PrimExpr b, Span) { return a * b; },
make_const(DataType::UInt(64), 1), buffer->shape));

// Calculate actual total bits using runtime shape from DLTensor
PrimExpr actual_total_bits = cast(DataType::UInt(64), v_type_bits) *
cast(DataType::UInt(64), v_type_lanes);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
PrimExpr dim_val =
cast(DataType::UInt(64),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
actual_total_bits = actual_total_bits * dim_val;
}

PrimExpr bits_match = (actual_total_bits == expect_total_bits);
BinderAddAssert(&analyzer_, bits_match,
arg_name + " is a subtype, but total bits mismatch",
&asserts_, is_null);
}
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
if (!data_is_subtype) {
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch
// occurs. Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
Expand Down Expand Up @@ -578,9 +595,7 @@ void ArgBinder::BindDLTensors(
for (size_t k = 0; k < buffer->shape.size(); ++k) {
// These packed-bit dtype shapes were not bound in the original
// implementation, so we just use them as is.
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
if (data_is_subtype) {
break;
}

Expand Down Expand Up @@ -925,4 +940,4 @@ void ArgBinder::BindDLTensors(
}

} // namespace tl
} // namespace tvm
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,18 @@ def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes):
@pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes",
[
# FP4 <-> Half
(T.float4_e2m1fn, T.float16, "__tl_cvt_fp4x2_to_half2", 2),
(T.float16, T.float4_e2m1fn, "__tl_cvt_half2_to_fp4x2", 2),
# FP4 <-> Float
(T.float4_e2m1fn, T.float32, "__tl_cvt_fp4x2_to_float2", 2),
(T.float32, T.float4_e2m1fn, "__tl_cvt_float2_to_fp4x2", 2),
# FP4 <-> Double
(T.float4_e2m1fn, T.float64, "__tl_cvt_fp4x2_to_double2", 2),
(T.float64, T.float4_e2m1fn, "__tl_cvt_double2_to_fp4x2", 2),
# FP4 <-> BFloat16
(T.float4_e2m1fn, T.bfloat16, "__tl_cvt_fp4x2_to_bfloat162", 2),
(T.bfloat16, T.float4_e2m1fn, "__tl_cvt_bfloat162_to_fp4x2", 2),
],
)
def test_vectorized_cast_fp4(src_dtype, dst_dtype, check_str, lanes):
Expand Down
9 changes: 9 additions & 0 deletions tilelang/engine/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def torch_dtype(self) -> torch.dtype:
"""
return T.dtype(self.dtype).as_torch()

def tilelang_dtype(self) -> T.dtype:
"""
Converts the TVM DataType to TileLang dtype.

Returns:
T.dtype: Corresponding TileLang dtype
"""
return T.dtype(self.dtype)


@dataclass
class CompiledArtifact:
Expand Down
Loading
Loading