Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class ncclDataTypeEnum:
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
ncclFloat8e4m3 = 10
ncclNumTypes = 11

@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
Expand All @@ -92,9 +93,12 @@ def from_torch(cls, dtype: torch.dtype) -> int:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
if dtype == torch.float8_e4m3fn:
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
" float8e4m3."
)
Comment on lines +96 to 102
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The codebase (for example, `vllm/model_executor/layers/quantization/utils/fp8_utils.py`) appears to use both `torch.float8_e4m3fn` and `torch.float8_e4m3fnuz`. This function should handle both types to avoid a `ValueError` during collective communication operations on `torch.float8_e4m3fnuz` tensors. The suggested change below also updates the error message for clarity.

For more complete FP8 support, you might also consider adding `torch.float8_e5m2`. This would involve adding `ncclFloat8e5m2` to `ncclDataTypeEnum` and handling `torch.float8_e5m2` in this method.

Suggested change
if dtype == torch.float8_e4m3fn:
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
" float8e4m3."
)
if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "
"int8, uint8, int32, int64, float16, float32, float64, bfloat16, "
"float8_e4m3fn, "
"float8_e4m3fnuz."
)



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def prepare(
a1q_scale = None

if is_nvfp4 and a1q_scale is not None:
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

return a1q, a1q_scale, None, topk_ids, topk_weights
Expand Down