Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 82 additions & 16 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
Expand All @@ -46,6 +47,35 @@
logger = init_logger(__name__)


def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Returns whether the quantization method should load quantized weights.

    A layer loads quantized weights only when it actually has a quant method
    and that method is not the unquantized fallback.
    """
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Sets default quantization scales for the layer.

    All q/k/v/prob scales are reset to 1.0. When ``register_buffer`` is True,
    the scale tensors are created as buffers on ``layer``; otherwise the
    existing tensors are overwritten in place.
    """
    scale_names = ("_k_scale", "_v_scale", "_q_scale", "_prob_scale")
    if register_buffer:
        for name in scale_names:
            layer.register_buffer(name, torch.tensor(1.0, dtype=torch.float32))
    else:
        for name in scale_names:
            getattr(layer, name).fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._prob_scale_float = 1.0


def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
Expand Down Expand Up @@ -74,17 +104,21 @@ def _init_kv_cache_quant(
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
# [Note: Register q/k/v/prob scales in state dict]
# When calling model.to(device), only parameters/buffers in state dict are
# moved. If not registering q/k/v/prob scales in state dict, there would
# be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
# on cpu.
# Registering in state dict means it interacts with weight loading. One edge
# case is when quant_method is None, or quant_method is UnquantizedLinearMethod
# (i.e., should_load_quant_weights(quant_method) == False).
# In this case, the checkpoint does not have the scales. We need to
# initialize the scales to 1.0 and update the scales after weight loading.
This is especially important when we load dummy weights first (providing
wrong scales) and then load real weights (which miss scales and keep the
wrong scales from dummy load).
set_default_quant_scales(layer, register_buffer=True)

# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
Expand All @@ -93,9 +127,9 @@ def _init_kv_cache_quant(
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
):

# See [Note: Register q/k/v/prob scales in state dict]
if should_load_quant_weights(quant_method):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
Expand Down Expand Up @@ -169,10 +203,16 @@ def __init__(
assert num_heads % num_kv_heads == 0, (
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
)
self.quant_config = quant_config
self.layer_name = prefix

# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)

self.num_heads = num_heads
Expand Down Expand Up @@ -249,7 +289,6 @@ def __init__(
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type

if kv_sharing_target_layer_name is not None:
Expand Down Expand Up @@ -378,6 +417,17 @@ def extra_repr(self) -> str:
def process_weights_after_loading(self, act_dtype: torch.dtype):
self.impl.process_weights_after_loading(act_dtype)

# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method = (
self.quant_config.get_quant_method(self, prefix=self.layer_name)
if self.quant_config
else None
)
if not should_load_quant_weights(quant_method):
set_default_quant_scales(self, register_buffer=False)

def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend

Expand Down Expand Up @@ -453,10 +503,15 @@ def __init__(
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config

# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)

dtype = torch.get_default_dtype()
Expand Down Expand Up @@ -586,6 +641,17 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)

# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method = (
self.quant_config.get_quant_method(self, prefix=self.layer_name)
if self.quant_config
else None
)
if not should_load_quant_weights(quant_method):
set_default_quant_scales(self, register_buffer=False)

def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
Expand Down