diff --git a/vllm/envs.py b/vllm/envs.py
index 01d8d8a2d2e0..cb4b54307033 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -133,6 +133,7 @@
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
     VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
+    VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
 
 
 def get_default_cache_root():
@@ -918,6 +919,12 @@ def get_vllm_port() -> Optional[int]:
     # or bad hardware but it may add compute overhead.
     "VLLM_COMPUTE_NANS_IN_LOGITS":
     lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
+
+    # Controls whether or not emulations are used for NVFP4
+    # generations on machines < 100 for compressed-tensors
+    # models
+    "VLLM_USE_NVFP4_CT_EMULATIONS":
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
 }
 
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index e5702c871cc9..d21abb2741a2 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -13,6 +13,7 @@
                                 QuantizationType)
 from pydantic import BaseModel
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -374,7 +375,8 @@ def _get_scheme_from_parts(
 
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
+                ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                     return CompressedTensorsW4A4Fp4()
                 else:
                     logger.warning_once(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 32718972a627..ec1d4a6c0efa 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -4,11 +4,14 @@
 import torch
 from torch.nn.parameter import Parameter
 
+import vllm.envs as envs
 from vllm._custom_ops import (cutlass_scaled_fp4_mm,
                               cutlass_scaled_mm_supports_fp4,
                               scaled_fp4_quant)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+    run_nvfp4_emulations)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -26,6 +29,8 @@ def __init__(self):
 
     @classmethod
     def get_min_capability(cls) -> int:
+        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
+            return 80
         return 100
 
     @classmethod
@@ -129,6 +134,17 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
+        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
+            out = run_nvfp4_emulations(
+                x=x,
+                input_global_scale=layer.input_global_scale,
+                weight=layer.weight,
+                weight_scale_swizzled=layer.weight_scale_swizzled,
+                weight_global_scale=layer.weight_global_scale)
+            if bias is not None:
+                out = out + bias
+            return out
+
         output_dtype = x.dtype
         output_shape = [x.shape[0], layer.weight.shape[0]]
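
For reviewers unfamiliar with the emulation path: on hardware without the CUTLASS FP4 kernel (compute capability < 10.0), `run_nvfp4_emulations` falls back to dequantizing the NVFP4 weights and running a plain high-precision matmul. Below is a minimal, self-contained sketch of that idea; the helper names, the nibble-packing order, and the omission of activation-side fake quantization are illustrative assumptions, not the actual implementation in `nvfp4_emulation_utils.py`.

```python
import torch

# The 8 magnitudes representable in FP4 e2m1; the 4th bit carries the sign.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a uint8 tensor holding two e2m1 codes per byte into floats."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    codes = torch.stack([low, high], dim=-1).flatten(-2)  # [..., 2 * packed_dim]
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).float()
    return sign * E2M1_VALUES.to(codes.device)[(codes & 0x7).long()]


def emulated_nvfp4_linear(
        x: torch.Tensor,
        weight_packed: torch.Tensor,        # [out, in // 2], uint8
        weight_scale: torch.Tensor,         # [out, in // 16], fp8 group scales
        weight_global_scale: torch.Tensor,  # scalar fp32
        group_size: int = 16) -> torch.Tensor:
    # Two-level dequant: per-group FP8 scale divided by the FP32 global
    # scale, broadcast across each 16-element group.
    w = unpack_fp4(weight_packed)  # [out, in]
    scales = weight_scale.to(torch.float32) / weight_global_scale
    w = w * scales.repeat_interleave(group_size, dim=-1)
    # A plain matmul in the activation dtype stands in for the FP4 kernel.
    return x @ w.to(x.dtype).t()
```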
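A hedged usage sketch for trying the flag end to end; the checkpoint ID is a placeholder, not a tested model:

```python
import os

# Opt in before vLLM evaluates its environment variables.
os.environ["VLLM_USE_NVFP4_CT_EMULATIONS"] = "1"

from vllm import LLM

# Hypothetical NVFP4 compressed-tensors checkpoint; substitute a real one.
llm = LLM(model="org/Llama-3.1-8B-NVFP4")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```

With the flag set, `get_min_capability()` reports 80, so the scheme is accepted on Ampere-class GPUs and `apply_weights` routes through the emulation path instead of `cutlass_scaled_fp4_mm`.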