diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 259343f6d..903e6503e 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -921,8 +921,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
     // fp4_e2_8_t
     if (t.lanes() >= 8)
       os << "." << access[(i % 8) / 4];
-    // fp4_e2_4_t or fp4_e2_2_t
-    os << "." << access[i % 4];
+    // fp4_e2_4_t -> fp4_e2_2_t member
+    if (t.lanes() >= 4)
+      os << "." << access[(i % 4) / 2];
+    // fp4_e2_2_t -> method call x() or y()
+    os << "." << access[i % 2] << "()";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -1040,8 +1043,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
     // fp4_e2_8_t
     if (t.lanes() >= 8)
       stream << "." << access[(i % 8) / 4];
-    // fp4_e2_4_t or fp4_e2_2_t
-    stream << "." << access[i % 4] << " = " << value << ";\n";
+    // fp4_e2_4_t -> fp4_e2_2_t member
+    if (t.lanes() >= 4)
+      stream << "." << access[(i % 4) / 2];
+    // fp4_e2_2_t -> set_x() or set_y()
+    stream << ".set_" << access[i % 2] << "(" << value << ");\n";
   } else {
     stream << vec << "." << access[i] << " = " << value << ";\n";
   }
diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h
index 93c844bb9..75dbc13ce 100644
--- a/src/tl_templates/cuda/cuda_fp4.h
+++ b/src/tl_templates/cuda/cuda_fp4.h
@@ -44,27 +44,36 @@ struct fp4_e2_t {
   TL_DEVICE operator __half() const { return __half(float(*this)); }
 };
 
-using fp4_e2x2_t = __nv_fp4x2_e2m1;
-using fp4_e2x4_t = __nv_fp4x4_e2m1;
+class fp4_e2_2_t {
+public:
+  __nv_fp4x2_storage_t __x;
 
-struct fp4_e2x8_t {
-  fp4_e2_t data[8];
-};
+  TL_DEVICE fp4_e2_2_t() = default;
+  TL_DEVICE fp4_e2_2_t(__nv_fp4x2_storage_t data) : __x(data) {}
+  TL_DEVICE fp4_e2_2_t(__nv_fp4x2_e2m1 data) : __x(data.__x) {}
 
-struct fp4_e2x16_t {
-  fp4_e2_t data[16];
-};
+  // Get low 4 bits (first fp4)
+  TL_DEVICE fp4_e2_t x() const {
+    return fp4_e2_t(__nv_fp4_storage_t(__x & 0x0F));
+  }
 
-struct __CUDA_ALIGN__(1) fp4_e2_2_t {
-  fp4_e2_t x;
-  fp4_e2_t y;
+  // Get high 4 bits (second fp4)
+  TL_DEVICE fp4_e2_t y() const {
+    return fp4_e2_t(__nv_fp4_storage_t((__x >> 4) & 0x0F));
+  }
+
+  // Set low 4 bits (first fp4)
+  TL_DEVICE void set_x(fp4_e2_t val) { __x = (__x & 0xF0) | (val.__x & 0x0F); }
+
+  // Set high 4 bits (second fp4)
+  TL_DEVICE void set_y(fp4_e2_t val) {
+    __x = (__x & 0x0F) | ((val.__x & 0x0F) << 4);
+  }
 };
 
-struct __CUDA_ALIGN__(2) fp4_e2_4_t {
-  fp4_e2_t x;
-  fp4_e2_t y;
-  fp4_e2_t z;
-  fp4_e2_t w;
+struct __CUDA_ALIGN__(4) fp4_e2_4_t {
+  fp4_e2_2_t x;
+  fp4_e2_2_t y;
 };
 
 struct __CUDA_ALIGN__(4) fp4_e2_8_t {
@@ -97,9 +106,9 @@ struct __CUDA_ALIGN__(32) fp4_e2_64_t {
 
 // Pack two fp4_e2_t values.
 TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) {
+  __nv_fp4x2_storage_t packed = (x.__x & 0x0F) | ((y.__x & 0x0F) << 4);
   fp4_e2_2_t result;
-  result.x = x;
-  result.y = y;
+  result.__x = packed;
   return result;
 }
 
@@ -107,10 +116,8 @@ TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) {
 TL_DEVICE fp4_e2_4_t make_fp4_e2_4_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2,
                                      fp4_e2_t x3) {
   fp4_e2_4_t result;
-  result.x = x0;
-  result.y = x1;
-  result.z = x2;
-  result.w = x3;
+  result.x = make_fp4_e2_2_t(x0, x1);
+  result.y = make_fp4_e2_2_t(x2, x3);
   return result;
 }
 
diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py
index d9d6659d1..7b22cc34e 100644
--- a/testing/python/language/test_tilelang_language_copy.py
+++ b/testing/python/language/test_tilelang_language_copy.py
@@ -168,8 +168,10 @@ def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.
     source = kernel.get_kernel_source()
     assert "fp4_e2_t" in source
     # For FP4, use same shape as kernel expects, since int8 is used as storage type
-    dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8)
+    dummy_input = torch.randint(0, 100, (M, N // 2), device="cuda", dtype=torch.int8)
     output = kernel(dummy_input)
+    if src_dtype == dst_dtype:
+        assert torch.allclose(output.view(torch.int8), dummy_input)
     assert output is not None
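Note (illustrative sketch, not part of the patch): the core of the change is that fp4_e2_2_t now stores both E2M1 values in the two nibbles of a single byte, so make_fp4_e2_2_t packs low/high and x()/y()/set_x()/set_y() each touch exactly one nibble. The kernel below is a minimal round-trip check under stated assumptions: the kernel name fp4_pack_roundtrip and the raw nibble constants are ours, TL_DEVICE is assumed to carry a __device__ qualifier, and fp4_e2_t is assumed constructible from __nv_fp4_storage_t, as the header above already does.

    // Hypothetical round-trip check for the packed fp4_e2_2_t layout.
    #include <cstdio>
    #include "cuda_fp4.h" // fp4_e2_t, fp4_e2_2_t, make_fp4_e2_2_t

    __global__ void fp4_pack_roundtrip() {
      fp4_e2_t a(__nv_fp4_storage_t(0x3)); // raw low nibble (illustrative value)
      fp4_e2_t b(__nv_fp4_storage_t(0xA)); // raw high nibble (illustrative value)

      // Both values share one storage byte: low = a, high = b, so p.__x == 0xA3.
      fp4_e2_2_t p = make_fp4_e2_2_t(a, b);

      fp4_e2_t lo = p.x(); // recovers 0x3
      fp4_e2_t hi = p.y(); // recovers 0xA

      // Setters overwrite only their own nibble: p.__x becomes 0x53.
      p.set_y(fp4_e2_t(__nv_fp4_storage_t(0x5)));
      printf("p=0x%02X lo=0x%X hi=0x%X\n", (unsigned)p.__x, (unsigned)lo.__x,
             (unsigned)hi.__x);
    }

The codegen change composes these accessors level by level: for a lanes() == 8 vector v, element 5 is now read as v.y.x.y() and written as v.y.x.set_y(value), while a bare fp4_e2_2_t uses v.x() and v.set_x(value) directly. The test change follows from the same packing: N fp4 lanes occupy N // 2 storage bytes, hence the (M, N // 2) int8 input.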