diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index a63297c3579e..f021df56c05b 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -291,3 +291,57 @@ def test_invalid_backend(): ): # Invalid backend name should raise ValueError when creating enum AttentionConfig(backend=AttentionBackendEnum["INVALID"]) + + +@pytest.mark.parametrize( + "backend_name,flash_attn_version,should_succeed", + [ + ("FLASH_ATTN", 3, True), # FA3 supports per-head quant scales + ("FLASH_ATTN", 2, False), # FA2 does not support per-head quant scales + ("FLASHINFER", None, False), # FlashInfer does not support + ("FLEX_ATTENTION", None, False), # Flex does not support + ], +) +def test_per_head_quant_scales_backend_selection( + backend_name: str, flash_attn_version: int | None, should_succeed: bool +): + """Test backend selection when use_per_head_quant_scales=True.""" + # Clear cache to ensure fresh backend selection + _cached_get_attn_backend.cache_clear() + + attention_config = AttentionConfig( + backend=AttentionBackendEnum[backend_name], + flash_attn_version=flash_attn_version, + ) + vllm_config = VllmConfig(attention_config=attention_config) + + with ( + set_current_vllm_config(vllm_config), + patch("vllm.platforms.current_platform", CudaPlatform()), + ): + if backend_name == "FLASH_ATTN" and flash_attn_version == 3: + if not torch.cuda.is_available(): + pytest.skip("FA3 requires CUDA") + capability = torch.cuda.get_device_capability() + if capability[0] != 9: + pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs") + + if should_succeed: + backend = get_attn_backend( + head_size=128, + dtype=torch.float16, + kv_cache_dtype="fp8", + block_size=64, + use_per_head_quant_scales=True, + ) + assert backend.get_name() == backend_name + else: + with pytest.raises(ValueError) as exc_info: + get_attn_backend( + head_size=128, + dtype=torch.float16, + kv_cache_dtype="fp8", + block_size=64, + use_per_head_quant_scales=True, + ) + assert backend_name in str(exc_info.value) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 8c3ff3cc4df7..b5647fe133d5 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -229,13 +229,20 @@ def __init__( calculate_kv_scales = False # llm-compressor mdls need to set cache_dtype to "fp8" manually. - if getattr(quant_config, "kv_cache_scheme", None) is not None: + kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None) + if kv_cache_scheme is not None: kv_cache_dtype = "fp8" calculate_kv_scales = False if cache_config is not None: cache_config.cache_dtype = "fp8" cache_config.calculate_kv_scales = False + # Check if per-head quant scales are required based on kv_cache_scheme + use_per_head_quant_scales = ( + kv_cache_scheme is not None + and kv_cache_scheme.get("strategy") == "attn_head" + ) + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( kv_cache_dtype, vllm_config.model_config ) @@ -272,6 +279,7 @@ def __init__( use_mla=False, has_sink=self.has_sink, use_mm_prefix=self.use_mm_prefix, + use_per_head_quant_scales=use_per_head_quant_scales, attn_type=attn_type, ) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 9b0fb5089c56..00a17596a52a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -985,14 +985,7 @@ def create_weights(self, layer: torch.nn.Module): self.quant_config.kv_cache_scheme["strategy"] ) - if strategy == QuantizationStrategy.ATTN_HEAD: - assert layer.impl.supports_per_head_quant_scales, ( - f"Layer {layer.__class__.__name__} with implementation " - f"{layer.impl.__class__.__name__} does not support per-head scales." - ) - n_scales = int(layer.num_kv_heads) - else: - n_scales = 1 + n_scales = int(layer.num_kv_heads) if strategy == "attn_head" else 1 layer.k_scale = torch.nn.Parameter( torch.ones(n_scales, requires_grad=False, dtype=torch.float32) diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 9c004d7724dd..94f155c31d08 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -187,6 +187,10 @@ def supports_mm_prefix(cls) -> bool: def is_sparse(cls) -> bool: return False + @classmethod + def supports_per_head_quant_scales(cls) -> bool: + return False + @classmethod def supports_attn_type(cls, attn_type: str) -> bool: """Check if backend supports a given attention type. @@ -225,6 +229,7 @@ def validate_configuration( has_sink: bool, use_sparse: bool, use_mm_prefix: bool, + use_per_head_quant_scales: bool, device_capability: "DeviceCapability", attn_type: str, ) -> list[str]: @@ -253,6 +258,8 @@ def validate_configuration( invalid_reasons.append("sparse not supported") else: invalid_reasons.append("non-sparse not supported") + if use_per_head_quant_scales and not cls.supports_per_head_quant_scales(): + invalid_reasons.append("per-head quant scales not supported") if not cls.supports_compute_capability(device_capability): invalid_reasons.append("compute capability not supported") if not cls.supports_attn_type(attn_type): @@ -635,7 +642,6 @@ class AttentionImplBase(ABC, Generic[T]): # TODO add support to more backends: # https://github.com/vllm-project/vllm/issues/25584 supports_quant_query_input: bool = False - supports_per_head_quant_scales: bool = False dcp_world_size: int dcp_rank: int diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ecd1b274c8ce..d903bd89ce15 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -95,6 +95,11 @@ def supports_attn_type(cls, attn_type: str) -> bool: AttentionType.ENCODER_DECODER, ) + @classmethod + def supports_per_head_quant_scales(cls) -> bool: + fa_version = get_flash_attn_version() + return fa_version is not None and fa_version >= 3 + @staticmethod def get_impl_cls() -> type["FlashAttentionImpl"]: return FlashAttentionImpl @@ -595,11 +600,6 @@ def __init__( ) self.supports_quant_query_input = True - self.supports_per_head_quant_scales = ( - self.vllm_flash_attn_version >= 3 - if self.vllm_flash_attn_version is not None - else False - ) def forward( self, diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py index 9580c1d5f355..48a86655cf87 100644 --- a/vllm/v1/attention/selector.py +++ b/vllm/v1/attention/selector.py @@ -27,6 +27,7 @@ class AttentionSelectorConfig(NamedTuple): has_sink: bool = False use_sparse: bool = False use_mm_prefix: bool = False + use_per_head_quant_scales: bool = False attn_type: str = AttentionType.DECODER def __repr__(self): @@ -39,6 +40,7 @@ def __repr__(self): f"has_sink={self.has_sink}, " f"use_sparse={self.use_sparse}, " f"use_mm_prefix={self.use_mm_prefix}, " + f"use_per_head_quant_scales={self.use_per_head_quant_scales}, " f"attn_type={self.attn_type})" ) @@ -52,6 +54,7 @@ def get_attn_backend( has_sink: bool = False, use_sparse: bool = False, use_mm_prefix: bool = False, + use_per_head_quant_scales: bool = False, attn_type: str | None = None, num_heads: int | None = None, ) -> type[AttentionBackend]: @@ -77,6 +80,7 @@ def get_attn_backend( has_sink=has_sink, use_sparse=use_sparse, use_mm_prefix=use_mm_prefix, + use_per_head_quant_scales=use_per_head_quant_scales, attn_type=attn_type or AttentionType.DECODER, )