Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/kernels/attention/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,57 @@ def test_invalid_backend():
):
# Invalid backend name should raise ValueError when creating enum
AttentionConfig(backend=AttentionBackendEnum["INVALID"])


@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support
        ("FLEX_ATTENTION", None, False),  # Flex does not support
    ],
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Test backend selection when use_per_head_quant_scales=True."""
    # Start from a cold cache so each parametrization selects a backend fresh.
    _cached_get_attn_backend.cache_clear()

    cfg = VllmConfig(
        attention_config=AttentionConfig(
            backend=AttentionBackendEnum[backend_name],
            flash_attn_version=flash_attn_version,
        )
    )

    with (
        set_current_vllm_config(cfg),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        # The FA3 case needs real Hopper hardware; skip everywhere else.
        if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            major, _minor = torch.cuda.get_device_capability()
            if major != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")

        # Same selection arguments either way; only the expectation differs.
        select = lambda: get_attn_backend(
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="fp8",
            block_size=64,
            use_per_head_quant_scales=True,
        )

        if not should_succeed:
            # Unsupported backends must be rejected, naming the backend.
            with pytest.raises(ValueError) as err:
                select()
            assert backend_name in str(err.value)
        else:
            assert select().get_name() == backend_name
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,20 @@ def __init__(
calculate_kv_scales = False

# llm-compressor mdls need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
if kv_cache_scheme is not None:
kv_cache_dtype = "fp8"
calculate_kv_scales = False
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False

# Check if per-head quant scales are required based on kv_cache_scheme
use_per_head_quant_scales = (
kv_cache_scheme is not None
and kv_cache_scheme.get("strategy") == "attn_head"
)

self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
Expand Down Expand Up @@ -272,6 +279,7 @@ def __init__(
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,14 +985,7 @@ def create_weights(self, layer: torch.nn.Module):
self.quant_config.kv_cache_scheme["strategy"]
)

if strategy == QuantizationStrategy.ATTN_HEAD:
assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales."
)
n_scales = int(layer.num_kv_heads)
else:
n_scales = 1
n_scales = int(layer.num_kv_heads) if strategy == "attn_head" else 1

layer.k_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
Expand Down
8 changes: 7 additions & 1 deletion vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def supports_mm_prefix(cls) -> bool:
def is_sparse(cls) -> bool:
return False

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        """Whether this backend supports per-attention-head KV quant scales.

        Base default is False; backends with support (e.g. FlashAttention
        with FA3) override this. Consulted by ``validate_configuration``
        when ``use_per_head_quant_scales`` is requested.
        """
        return False

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
Expand Down Expand Up @@ -225,6 +229,7 @@ def validate_configuration(
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
use_per_head_quant_scales: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
Expand Down Expand Up @@ -253,6 +258,8 @@ def validate_configuration(
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if use_per_head_quant_scales and not cls.supports_per_head_quant_scales():
invalid_reasons.append("per-head quant scales not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
Expand Down Expand Up @@ -635,7 +642,6 @@ class AttentionImplBase(ABC, Generic[T]):
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
supports_per_head_quant_scales: bool = False

dcp_world_size: int
dcp_rank: int
Expand Down
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def supports_attn_type(cls, attn_type: str) -> bool:
AttentionType.ENCODER_DECODER,
)

@classmethod
def supports_per_head_quant_scales(cls) -> bool:
fa_version = get_flash_attn_version()
return fa_version is not None and fa_version >= 3

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashAttentionImpl
Expand Down Expand Up @@ -595,11 +600,6 @@ def __init__(
)

self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)

def forward(
self,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AttentionSelectorConfig(NamedTuple):
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER

def __repr__(self):
Expand All @@ -39,6 +40,7 @@ def __repr__(self):
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type})"
)

Expand All @@ -52,6 +54,7 @@ def get_attn_backend(
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
use_per_head_quant_scales: bool = False,
attn_type: str | None = None,
num_heads: int | None = None,
) -> type[AttentionBackend]:
Expand All @@ -77,6 +80,7 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER,
)

Expand Down