Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/moe/marlin_moe_wna16/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
Expand Down
4 changes: 2 additions & 2 deletions csrc/moe/marlin_moe_wna16/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
major_capability == 12,
"Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than "
"Marlin W4A16 on other devices).");
}

Expand Down
63 changes: 63 additions & 0 deletions csrc/quantization/fp4/nvfp4_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,56 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;

namespace vllm {

// Software E2M1 conversion for architectures without hardware
// cvt.rn.satfinite.e2m1x2.f32 (SM12x lacks this SM100-only instruction).
// Uses round-to-nearest-even (IEEE 754) to match hardware behavior:
// at midpoints, ties break to the value with an even integer code.
// E2M1 representable values and codes:
// 0.0(0) 0.5(1) 1.0(2) 1.5(3) 2.0(4) 3.0(5) 4.0(6) 6.0(7)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300

__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) {
  // Software E2M1 (FP4) encoder for one float, used on SM12x where the
  // SM100-only cvt.rn.satfinite.e2m1x2.f32 instruction is unavailable.
  //
  // E2M1 codes and the magnitudes they represent:
  //   0:0.0  1:0.5  2:1.0  3:1.5  4:2.0  5:3.0  6:4.0  7:6.0
  //
  // Rounding is round-to-nearest-even on the integer code, matching the
  // hardware instruction's behavior at midpoints: a `<=` threshold keeps
  // the tie on the lower (even) code, while a `<` threshold pushes the
  // tie to the upper code (which is even in the following branch).
  // Magnitudes above 6.0 saturate to code 7 (satfinite semantics).
  const uint8_t sign_bit = (__float_as_uint(v) >> 31) & 1u;
  const float mag = fabsf(v);
  uint8_t code;
  if (mag <= 0.25f) {
    code = 0;  // 0.0; tie at 0.25 stays on even code 0
  } else if (mag < 0.75f) {
    code = 1;  // 0.5; tie at 0.75 rounds up to even code 2
  } else if (mag <= 1.25f) {
    code = 2;  // 1.0; tie at 1.25 stays on even code 2
  } else if (mag < 1.75f) {
    code = 3;  // 1.5; tie at 1.75 rounds up to even code 4
  } else if (mag <= 2.5f) {
    code = 4;  // 2.0; tie at 2.5 stays on even code 4
  } else if (mag < 3.5f) {
    code = 5;  // 3.0; tie at 3.5 rounds up to even code 6
  } else if (mag <= 5.0f) {
    code = 6;  // 4.0; tie at 5.0 stays on even code 6
  } else {
    code = 7;  // 6.0; anything larger saturates here
  }
  return (sign_bit << 3) | code;
}
Comment on lines +48 to +70
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The software implementation of sw_float_to_e2m1 does not correctly emulate the round-to-nearest-even behavior of the cvt.rn.satfinite.e2m1x2.f32 PTX instruction for midpoint values. The current implementation uses round-half-up, which will lead to correctness issues. For example, for a value of 0.25, which is a midpoint between 0.0 (rep 0) and 0.5 (rep 1), it should round to 0.0 because its integer representation 0 is even. The current code rounds it to 0.5. This is a critical issue that will cause divergence from the hardware implementation.

__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) {
  uint8_t sign = (__float_as_uint(v) >> 31) & 1;
  float av = fabsf(v);
  uint8_t e2m1;
  // E2M1 representable values (integer representation):
  // 0.0 (0), 0.5 (1), 1.0 (2), 1.5 (3), 2.0 (4), 3.0 (5), 4.0 (6), 6.0 (7)
  if (av <= 0.25f) e2m1 = 0;      // Midpoint 0.25 rounds to 0 (even)
  else if (av < 0.75f) e2m1 = 1;
  else if (av <= 1.25f) e2m1 = 2; // Midpoints 0.75, 1.25 round to 2 (even)
  else if (av < 1.75f) e2m1 = 3;
  else if (av <= 2.5f) e2m1 = 4;  // Midpoints 1.75, 2.5 round to 4 (even)
  else if (av < 3.5f) e2m1 = 5;
  else if (av <= 5.0f) e2m1 = 6;  // Midpoints 3.5, 5.0 round to 6 (even)
  else e2m1 = 7;
  return (sign << 3) | e2m1;
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current code already implements round-to-nearest-even correctly. The <= comparisons at midpoints (0.25, 1.25, 2.5, 5.0) direct ties to the lower even codes (0, 2, 4, 6), while the < comparisons at midpoints (0.75, 1.75, 3.5) direct ties to the higher even codes (2, 4, 6 in the next branch). This matches the hardware cvt.rn.satfinite.e2m1x2.f32 behavior — verified empirically on SM121a by comparing hardware vs software outputs for all E2M1 midpoint values.


// Pack two E2M1 nibbles into one byte, matching the layout produced by
// cvt.rn.satfinite.e2m1x2.f32: `hi` fills the upper nibble, `lo` the
// lower nibble.
__device__ __forceinline__ uint8_t sw_e2m1x2_from_f32(float hi, float lo) {
  const uint8_t upper = sw_float_to_e2m1(hi);
  const uint8_t lower = sw_float_to_e2m1(lo);
  return (uint8_t)((upper << 4) | lower);
}

// Pack 8 float values (as 4 float2) into a uint32_t of E2M1 values.
// Byte i of the result holds the converted pair array[i], with .y in the
// upper nibble and .x in the lower nibble.
__device__ __forceinline__ uint32_t sw_fp32_vec8_to_e2m1(const float2* array) {
  uint32_t packed = 0;
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    const uint32_t pair_bits = sw_e2m1x2_from_f32(array[i].y, array[i].x);
    packed |= pair_bits << (8 * i);
  }
  return packed;
}

#endif // SM12x software E2M1

template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
Expand Down Expand Up @@ -70,6 +120,9 @@ inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth redefining this as new def like "SOFTWARE_E2M1_CONVERT"

val = sw_fp32_vec8_to_e2m1(reinterpret_cast<const float2*>(array));
#else
asm volatile(
"{\n"
".reg .b8 byte0;\n"
Expand All @@ -85,12 +138,16 @@ inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
#endif
return val;
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
uint32_t val;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300
val = sw_fp32_vec8_to_e2m1(array);
#else
asm volatile(
"{\n"
".reg .b8 byte0;\n"
Expand All @@ -106,6 +163,7 @@ __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
#endif
return val;
}

Expand All @@ -117,6 +175,10 @@ using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>;

__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
u32x2 out;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300
out.lo = sw_fp32_vec8_to_e2m1(array);
out.hi = sw_fp32_vec8_to_e2m1(array + 4);
#else
asm volatile(
"{\n"
".reg .b8 b0;\n"
Expand All @@ -143,6 +205,7 @@ __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y),
"f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y),
"f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y));
#endif
return out;
}

Expand Down
6 changes: 3 additions & 3 deletions csrc/quantization/marlin/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct cutlass_3x_gemm_sm120_custom {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule, void>::CollectiveOp;

using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};

Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,8 @@ def is_invalid(
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
args = sub_case + (m, n, k) + case[4:]
Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/quantization/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ def is_invalid(
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,9 @@ def get_marlin_input_dtype(prefix: str | None = None):
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
) and not current_platform.is_device_capability_family(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"Marlin W4A8-FP8 only support SM89 or SM12x device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
Expand Down
Loading