14 changes: 10 additions & 4 deletions src/target/codegen_cuda.cc
@@ -921,8 +921,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
     // fp4_e2_8_t
     if (t.lanes() >= 8)
       os << "." << access[(i % 8) / 4];
-    // fp4_e2_4_t or fp4_e2_2_t
-    os << "." << access[i % 4];
+    // fp4_e2_4_t -> fp4_e2_2_t member
+    if (t.lanes() >= 4)
+      os << "." << access[(i % 4) / 2];
+    // fp4_e2_2_t -> method call x() or y()
+    os << "." << access[i % 2] << "()";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -1040,8 +1043,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
     // fp4_e2_8_t
     if (t.lanes() >= 8)
       stream << "." << access[(i % 8) / 4];
-    // fp4_e2_4_t or fp4_e2_2_t
-    stream << "." << access[i % 4] << " = " << value << ";\n";
+    // fp4_e2_4_t -> fp4_e2_2_t member
+    if (t.lanes() >= 4)
+      stream << "." << access[(i % 4) / 2];
+    // fp4_e2_2_t -> set_x() or set_y()
+    stream << ".set_" << access[i % 2] << "(" << value << ");\n";
   } else {
     stream << vec << "." << access[i] << " = " << value << ";\n";
   }
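For context, a minimal hand-written sketch (not part of the diff) of the accessor chain the updated PrintVecElemLoad / PrintVecElemStore are expected to emit. It assumes access[] is {x, y, z, w} as elsewhere in this codegen and that fp4_e2_8_t holds two fp4_e2_4_t members named x and y (its body is collapsed in the header diff below); the function name is illustrative.

```cpp
// Sketch only: hand-written equivalent of the generated access for element
// i = 5 of an 8-lane fp4 vector, mirroring the index math above:
//   (5 % 8) / 4 = 1  ->  .y      select the fp4_e2_4_t half
//   (5 % 4) / 2 = 0  ->  .x      select the fp4_e2_2_t pair
//    5 % 2      = 1  ->  high nibble: .y() on load, .set_y() on store
#include "cuda_fp4.h"

__device__ fp4_e2_t elem5_store_then_load(fp4_e2_8_t &v, fp4_e2_t val) {
  v.y.x.set_y(val); // what PrintVecElemStore now emits for i = 5
  return v.y.x.y(); // what PrintVecElemLoad now emits for i = 5
}
```

The extra (i % 4) / 2 level reflects that fp4_e2_4_t is now a pair of fp4_e2_2_t rather than four byte-sized members, and the last selector becomes a method call because two fp4 values now share a single byte.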
51 changes: 29 additions & 22 deletions src/tl_templates/cuda/cuda_fp4.h
@@ -44,27 +44,36 @@ struct fp4_e2_t {
   TL_DEVICE operator __half() const { return __half(float(*this)); }
 };
 
-using fp4_e2x2_t = __nv_fp4x2_e2m1;
-using fp4_e2x4_t = __nv_fp4x4_e2m1;
+class fp4_e2_2_t {
+public:
+  __nv_fp4x2_storage_t __x;
 
-struct fp4_e2x8_t {
-  fp4_e2_t data[8];
-};
+  TL_DEVICE fp4_e2_2_t() = default;
+  TL_DEVICE fp4_e2_2_t(__nv_fp4x2_storage_t data) : __x(data) {}
+  TL_DEVICE fp4_e2_2_t(__nv_fp4x2_e2m1 data) : __x(data.__x) {}
 
-struct fp4_e2x16_t {
-  fp4_e2_t data[16];
-};
+  // Get low 4 bits (first fp4)
+  TL_DEVICE fp4_e2_t x() const {
+    return fp4_e2_t(__nv_fp4_storage_t(__x & 0x0F));
+  }
 
-struct __CUDA_ALIGN__(1) fp4_e2_2_t {
-  fp4_e2_t x;
-  fp4_e2_t y;
+  // Get high 4 bits (second fp4)
+  TL_DEVICE fp4_e2_t y() const {
+    return fp4_e2_t(__nv_fp4_storage_t((__x >> 4) & 0x0F));
+  }
+
+  // Set low 4 bits (first fp4)
+  TL_DEVICE void set_x(fp4_e2_t val) { __x = (__x & 0xF0) | (val.__x & 0x0F); }
+
+  // Set high 4 bits (second fp4)
+  TL_DEVICE void set_y(fp4_e2_t val) {
+    __x = (__x & 0x0F) | ((val.__x & 0x0F) << 4);
+  }
 };
 
-struct __CUDA_ALIGN__(2) fp4_e2_4_t {
-  fp4_e2_t x;
-  fp4_e2_t y;
-  fp4_e2_t z;
-  fp4_e2_t w;
+struct __CUDA_ALIGN__(4) fp4_e2_4_t {
+  fp4_e2_2_t x;
+  fp4_e2_2_t y;
 };
 
 struct __CUDA_ALIGN__(4) fp4_e2_8_t {
@@ -97,20 +106,18 @@ struct __CUDA_ALIGN__(32) fp4_e2_64_t {
 
 // Pack two fp4_e2_t values.
 TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) {
+  __nv_fp4x2_storage_t packed = (x.__x & 0x0F) | ((y.__x & 0x0F) << 4);
   fp4_e2_2_t result;
-  result.x = x;
-  result.y = y;
+  result.__x = packed;
   return result;
 }
 
 // Pack four fp4_e2_t values.
 TL_DEVICE fp4_e2_4_t make_fp4_e2_4_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2,
                                      fp4_e2_t x3) {
   fp4_e2_4_t result;
-  result.x = x0;
-  result.y = x1;
-  result.z = x2;
-  result.w = x3;
+  result.x = make_fp4_e2_2_t(x0, x1);
+  result.y = make_fp4_e2_2_t(x2, x3);
   return result;
 }

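A short usage sketch, again not part of the change, of the packed-nibble layout the new accessors assume (x in the low 4 bits, y in the high 4 bits); function and variable names are illustrative.

```cpp
// Sketch only: round-trip values through the byte-packed fp4_e2_2_t and the
// reworked make_fp4_e2_4_t. Per the header above, x() reads the low nibble
// and y() the high nibble; fp4_e2_4_t.x holds elements 0-1, .y holds 2-3.
#include "cuda_fp4.h"

__device__ fp4_e2_t fp4_pack_example(fp4_e2_t a, fp4_e2_t b, fp4_e2_t c,
                                     fp4_e2_t d) {
  fp4_e2_2_t p = make_fp4_e2_2_t(a, b);
  fp4_e2_t lo = p.x(); // low nibble  -> a
  fp4_e2_t hi = p.y(); // high nibble -> b
  p.set_x(hi);         // rewrites the low nibble only; high nibble preserved
  p.set_y(lo);         // rewrites the high nibble only; low nibble preserved

  fp4_e2_4_t q = make_fp4_e2_4_t(a, b, c, d);
  return q.y.x();      // element 2 (c): second pair, low nibble
}
```

With the previous struct-of-fp4_e2_t layout each logical fp4 value occupied a full byte of storage, so plain member access was enough; keeping two values per byte is why the codegen above switches to x()/y() and set_x()/set_y() calls.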
4 changes: 3 additions & 1 deletion testing/python/language/test_tilelang_language_copy.py
@@ -168,8 +168,10 @@ def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.
     source = kernel.get_kernel_source()
     assert "fp4_e2_t" in source
     # For FP4, use same shape as kernel expects, since int8 is used as storage type
-    dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8)
+    dummy_input = torch.randint(0, 100, (M, N // 2), device="cuda", dtype=torch.int8)
     output = kernel(dummy_input)
+    if src_dtype == dst_dtype:
+        assert torch.allclose(output.view(torch.int8), dummy_input)
     assert output is not None

