-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
fix: Software E2M1 conversion for SM12x NVFP4 activation quantization #35947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a7308dc
473fb13
8ce3605
df799d0
dc818d3
fb63fc9
da23cc5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,56 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; | |
|
|
||
| namespace vllm { | ||
|
|
||
| // Software E2M1 conversion for architectures without hardware | ||
| // cvt.rn.satfinite.e2m1x2.f32 (SM12x lacks this SM100-only instruction). | ||
| // Uses round-to-nearest-even (IEEE 754) to match hardware behavior: | ||
| // at midpoints, ties break to the value with an even integer code. | ||
| // E2M1 representable values and codes: | ||
| // 0.0(0) 0.5(1) 1.0(2) 1.5(3) 2.0(4) 3.0(5) 4.0(6) 6.0(7) | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 | ||
|
|
||
| __device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) { | ||
| uint8_t sign = (__float_as_uint(v) >> 31) & 1; | ||
| float av = fabsf(v); | ||
| uint8_t e2m1; | ||
| // Midpoint tie-breaking: <= rounds to lower (even) code, < rounds to upper. | ||
| if (av <= 0.25f) | ||
| e2m1 = 0; // 0.0; midpoint 0.25 → code 0 (even) | ||
| else if (av < 0.75f) | ||
| e2m1 = 1; // 0.5; midpoint 0.75 → code 2 (even, next branch) | ||
| else if (av <= 1.25f) | ||
| e2m1 = 2; // 1.0; midpoint 1.25 → code 2 (even) | ||
| else if (av < 1.75f) | ||
| e2m1 = 3; // 1.5; midpoint 1.75 → code 4 (even, next branch) | ||
| else if (av <= 2.5f) | ||
| e2m1 = 4; // 2.0; midpoint 2.5 → code 4 (even) | ||
| else if (av < 3.5f) | ||
| e2m1 = 5; // 3.0; midpoint 3.5 → code 6 (even, next branch) | ||
| else if (av <= 5.0f) | ||
| e2m1 = 6; // 4.0; midpoint 5.0 → code 6 (even) | ||
| else | ||
| e2m1 = 7; // 6.0 (satfinite) | ||
| return (sign << 3) | e2m1; | ||
| } | ||
|
Comment on lines
+48
to
+70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The software implementation of
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current code already implements round-to-nearest-even correctly. The |
||
|
|
||
| // Pack two E2M1 values into one byte (matches cvt.rn.satfinite.e2m1x2.f32 | ||
| // layout: hi in upper nibble, lo in lower nibble). | ||
| __device__ __forceinline__ uint8_t sw_e2m1x2_from_f32(float hi, float lo) { | ||
| return (sw_float_to_e2m1(hi) << 4) | sw_float_to_e2m1(lo); | ||
| } | ||
|
|
||
| // Pack 8 float values (as 4 float2) into a uint32_t of E2M1 values. | ||
| __device__ __forceinline__ uint32_t sw_fp32_vec8_to_e2m1(const float2* array) { | ||
| uint8_t b0 = sw_e2m1x2_from_f32(array[0].y, array[0].x); | ||
| uint8_t b1 = sw_e2m1x2_from_f32(array[1].y, array[1].x); | ||
| uint8_t b2 = sw_e2m1x2_from_f32(array[2].y, array[2].x); | ||
| uint8_t b3 = sw_e2m1x2_from_f32(array[3].y, array[3].x); | ||
| return (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | | ||
| ((uint32_t)b3 << 24); | ||
| } | ||
|
|
||
| #endif // SM12x software E2M1 | ||
|
|
||
| template <typename Int> | ||
| __host__ __device__ inline Int round_up(Int x, Int y) { | ||
| static_assert(std::is_integral_v<Int>, | ||
|
|
@@ -70,6 +120,9 @@ inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m, | |
| // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). | ||
| inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { | ||
| uint32_t val; | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be worth redefining this as new def like "SOFTWARE_E2M1_CONVERT" |
||
| val = sw_fp32_vec8_to_e2m1(reinterpret_cast<const float2*>(array)); | ||
| #else | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .b8 byte0;\n" | ||
|
|
@@ -85,12 +138,16 @@ inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { | |
| : "=r"(val) | ||
| : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), | ||
| "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); | ||
| #endif | ||
| return val; | ||
| } | ||
|
|
||
| // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). | ||
| __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { | ||
| uint32_t val; | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 | ||
| val = sw_fp32_vec8_to_e2m1(array); | ||
| #else | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .b8 byte0;\n" | ||
|
|
@@ -106,6 +163,7 @@ __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { | |
| : "=r"(val) | ||
| : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), | ||
| "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); | ||
| #endif | ||
| return val; | ||
| } | ||
|
|
||
|
|
@@ -117,6 +175,10 @@ using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>; | |
|
|
||
| __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { | ||
| u32x2 out; | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 | ||
| out.lo = sw_fp32_vec8_to_e2m1(array); | ||
| out.hi = sw_fp32_vec8_to_e2m1(array + 4); | ||
| #else | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .b8 b0;\n" | ||
|
|
@@ -143,6 +205,7 @@ __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { | |
| "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), | ||
| "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), | ||
| "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); | ||
| #endif | ||
| return out; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there really not a more efficient implementation than this? This seems like it would be quite slow