diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ef4d28f220c7..7f132098e86e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -202,7 +202,8 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t block_stride, const int64_t page_stride, const int64_t head_stride, const int64_t key_stride, const int64_t value_stride, const int num_heads, const int head_size, - const int block_size, const float* k_scale, const float* v_scale) { + const int block_size, const float* k_scale, const float* v_scale, + const int kv_scale_stride) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -226,21 +227,23 @@ __global__ void reshape_and_cache_flash_kernel( // this is true for the NHD layout where `head_stride == head_size` const bool is_contiguous_heads = (head_stride == head_size); - float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale; - float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale; constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4; - CopyWithScaleOp k_op{k_scale_val}; - CopyWithScaleOp v_op{v_scale_val}; - if (is_contiguous_heads) { - // NHD layout + + if (is_contiguous_heads && kv_scale_stride == 0) { + // NHD layout and k/v_scales are [1] (i.e. single scale for all heads) // kv cache: [num_blocks, block_size, num_heads, head_size] + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale; + + CopyWithScaleOp k_op{k_scale_val}; + CopyWithScaleOp v_op{v_scale_val}; + vectorize_with_alignment(key_src, key_dst, n_elems, threadIdx.x, blockDim.x, k_op); - vectorize_with_alignment(value_src, value_dst, n_elems, threadIdx.x, blockDim.x, v_op); - } else { + // HND layout OR k/v_scales are [num_heads] (i.e. 
per-attn-head) // HND layout: heads are strided, but each head_size segment is contiguous // kv cache: [num_blocks, num_heads, block_size, head_size] const int lane = threadIdx.x & 31; // 0..31 within warp @@ -256,6 +259,16 @@ __global__ void reshape_and_cache_flash_kernel( cache_t* __restrict__ v_dst_h = value_dst + static_cast(head) * head_stride; + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) + ? 0.f + : k_scale[head * kv_scale_stride]; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) + ? 0.f + : v_scale[head * kv_scale_stride]; + + CopyWithScaleOp k_op{k_scale_val}; + CopyWithScaleOp v_op{v_scale_val}; + // within each head, let the 32 threads of the warp perform the vector // copy vectorize_with_alignment(k_src_h, k_dst_h, head_size, lane, 32, @@ -605,7 +618,8 @@ void reshape_and_cache( slot_mapping.data_ptr(), block_stride, page_stride, \ head_stride, key_stride, value_stride, num_heads, head_size, \ block_size, reinterpret_cast(k_scale.data_ptr()), \ - reinterpret_cast(v_scale.data_ptr())); + reinterpret_cast(v_scale.data_ptr()), \ + kv_scale_stride); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -614,8 +628,9 @@ void reshape_and_cache_flash( torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, // [1] or [num_heads] + torch::Tensor& v_scale) { // [1] or [num_heads] // NOTE(woosuk): In vLLM V1, key.size(0) can be different from // slot_mapping.size(0) because of padding for CUDA graphs. 
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because @@ -638,6 +653,12 @@ void reshape_and_cache_flash( int64_t head_stride = key_cache.stride(2); TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + TORCH_CHECK(k_scale.sizes() == v_scale.sizes(), + "k_scale and v_scale must have the same shape"); + TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads, + "k_scale and v_scale must be of shape [1] or [num_heads]"); + int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0; + dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); diff --git a/docs/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md index 586117272d3b..2c5bfd643946 100644 --- a/docs/features/quantization/quantized_kvcache.md +++ b/docs/features/quantization/quantized_kvcache.md @@ -1,162 +1,187 @@ # Quantized KV Cache -## FP8 KV Cache +## FP8 KV Cache Overview -Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput. +Efficient memory usage is crucial for working with large language models. Quantizing the KV (Key-Value) cache to FP8 format can significantly reduce its memory footprint. This optimization enables you to store more tokens in memory, leading to improved throughput and support for longer context windows. -### FP8 Formats +> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values. 
-[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats: +### Supported FP8 KV-Cache Quantization Schemes -- E5M2 (5 exponent bits and 2 mantissa bits) -- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3) +vLLM supports two main quantization strategies for the FP8 KV-cache: -The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor. +- **Per-tensor quantization:** + A single scale is applied for each Q, K, and V tensor individually. (`q/k/v_scale = [1]`) +- **Per-attention-head quantization:** + Each scale corresponds to an attention head: `q_scale = [num_heads]`, `k/v_scale = [num_kv_heads]`. -### Current Limitations +> **Note:** +> Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**. -For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel). +### Scale Calibration Approaches -### How FP8 KV Cache Works +You can configure how the quantization scales are computed in vLLM using three different approaches: -The FP8 KV cache implementation follows this workflow: +1. **No calibration (default scales):** + All quantization scales are set to `1.0`. + _Configure with:_ + ```python + kv_cache_dtype="fp8" + calculate_kv_scales=False + ``` -1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache -2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16) -3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor +2. 
**Random token calibration (on-the-fly):** + Scales are automatically estimated from a single batch of random tokens during warmup and then fixed. + _Configure with:_ + ```python + kv_cache_dtype="fp8" + calculate_kv_scales=True + ``` -This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations. +3. **[Recommended] Calibration with a dataset (via llm-compressor):** + Scales are estimated using a curated calibration dataset for maximum accuracy. + This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library. + _See example below!_ -### Performance Impact +#### Additional `kv_cache_dtype` Options -The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either: +- `kv_cache_dtype="auto"`: Use the model's default data type +- `kv_cache_dtype="fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPUs) +- `kv_cache_dtype="fp8_e5m2"`: Supported on CUDA 11.8+ -- Processing longer context lengths for individual requests, or -- Handling more concurrent request batches +--- -However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized. +## Examples -Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization. +### 1. 
No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`) -## Usage Example +All quantization scales are set to 1.0. -Here is an example of how to enable FP8 quantization: +```python +from vllm import LLM, SamplingParams -??? code +sampling_params = SamplingParams(temperature=0.7, top_p=0.8) +llm = LLM( + model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=False, +) +prompt = "London is the capital of" +out = llm.generate(prompt, sampling_params)[0].outputs[0].text +print(out) +``` - ```python - # To calculate kv cache scales on the fly enable the calculate_kv_scales - # parameter +--- - from vllm import LLM, SamplingParams +### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`) - sampling_params = SamplingParams(temperature=0.7, top_p=0.8) - llm = LLM( - model="meta-llama/Llama-2-7b-chat-hf", - kv_cache_dtype="fp8", - calculate_kv_scales=True, - ) - prompt = "London is the capital of" - out = llm.generate(prompt, sampling_params)[0].outputs[0].text - print(out) - ``` +Scales are automatically estimated from a single batch of tokens during warmup. -The `kv_cache_dtype` argument specifies the data type for KV cache storage: +```python +from vllm import LLM, SamplingParams -- `"auto"`: Uses the model's default "unquantized" data type -- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU) -- `"fp8_e5m2"`: Supported on CUDA 11.8+ +sampling_params = SamplingParams(temperature=0.7, top_p=0.8) +llm = LLM( + model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True, +) +prompt = "London is the capital of" +out = llm.generate(prompt, sampling_params)[0].outputs[0].text +print(out) +``` -## Calibrated Scales for Better Accuracy +--- -For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. 
[LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process. +### 3. **[Recommended] Calibration Using a Dataset (with `llm-compressor`)** -### Installation +For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization. -First, install the required dependencies: +#### Install the required package ```bash pip install llmcompressor ``` -### Example Usage +#### Example: Quantize Llama Attention & KV Cache to FP8 -Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern): - -??? code +```python +""" +Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy) +using llm-compressor one-shot calibration. +""" + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs + +# ----------------------------- +# Config +# ----------------------------- +MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +STRATEGY = "tensor" # or "attn_head" +NUM_CALIB_SAMPLES = 512 # Good starting value +MAX_SEQ_LEN = 2048 + +# ----------------------------- +# Helpers +# ----------------------------- +def process_and_tokenize(example, tokenizer: AutoTokenizer): + """Convert chat messages to tokens.""" + text = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer( + text, + padding=False, + max_length=MAX_SEQ_LEN, + truncation=True, + add_special_tokens=False, + ) - ```python - from datasets import load_dataset - from transformers import AutoModelForCausalLM, AutoTokenizer - from llmcompressor import oneshot +def 
build_recipe(strategy: str) -> QuantizationModifier: + fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy) + return QuantizationModifier( + config_groups={ + "attention": QuantizationScheme( + targets=["LlamaAttention"], # Quantize queries: q_scale + input_activations=fp8_args, + ) + }, + kv_cache_scheme=fp8_args, # Quantize KV cache: k/v_scale + ) - # Select model and load it - MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" - model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto") +# ----------------------------- +# Main +# ----------------------------- +def main(): + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]") + ds = ds.shuffle(seed=42) + ds = ds.map( + lambda ex: process_and_tokenize(ex, tokenizer), + remove_columns=ds.column_names, + ) - # Select calibration dataset - DATASET_ID = "HuggingFaceH4/ultrachat_200k" - DATASET_SPLIT = "train_sft" - - # Configure calibration parameters - NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point - MAX_SEQUENCE_LENGTH = 2048 - - # Load and preprocess dataset - ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) - ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - def process_and_tokenize(example): - text = tokenizer.apply_chat_template(example["messages"], tokenize=False) - return tokenizer( - text, - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - ds = ds.map(process_and_tokenize, remove_columns=ds.column_names) - - # Configure quantization settings - recipe = """ - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - """ - - # Apply quantization + recipe = build_recipe(STRATEGY) oneshot( model=model, dataset=ds, 
recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, + max_seq_length=MAX_SEQ_LEN, + num_calibration_samples=NUM_CALIB_SAMPLES, ) - # Save quantized model: Llama-3.1-8B-Instruct-FP8-KV - SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV" - model.save_pretrained(SAVE_DIR, save_compressed=True) - tokenizer.save_pretrained(SAVE_DIR) - ``` - -The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales. + save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}" + model.save_pretrained(save_dir, save_compressed=True) + tokenizer.save_pretrained(save_dir) -When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales. - -```python -from vllm import LLM, SamplingParams - -sampling_params = SamplingParams(temperature=0.7, top_p=0.8) -llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8") -prompt = "London is the capital of" -out = llm.generate(prompt, sampling_params)[0].outputs[0].text -print(out) +if __name__ == "__main__": + main() ``` + +For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache). 
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 85be49e896c4..a130f9acbe8d 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -8,6 +8,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -19,6 +20,7 @@ HEAD_SIZES = [64, 80, 256] BLOCK_SIZES = [8, 16, 32] CACHE_LAYOUTS = ["NHD", "HND"] +KV_SCALE_TYPES = ["tensor", "attn_head"] # Parameters for MLA tests. KV_LORA_RANKS = [512] @@ -170,6 +172,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) +@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES) @pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS) @torch.inference_mode() def test_reshape_and_cache_flash( @@ -184,6 +187,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, kv_cache_layout: str, + kv_scale_type: str, implementation: str, ) -> None: set_random_seed(seed) @@ -193,6 +197,9 @@ def test_reshape_and_cache_flash( if implementation == "triton" and kv_cache_layout == "HND": pytest.skip("Triton implementation only supports NHD layout.") + if kv_scale_type == "attn_head" and implementation != "cuda": + pytest.skip("Only CUDA implementation supports attn_head scaling.") + # fp8 conversion requires continugous memory buffer. Reduce the number of # blocks and tokens to consume less memory. 
num_tokens = num_tokens // 2 @@ -220,8 +227,12 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = (key.amax() / 64.0).to(torch.float32) - v_scale = (value.amax() / 64.0).to(torch.float32) + if kv_scale_type == "tensor": + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) + else: # "attn_head" + k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32) + v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32) def permute_and_compact(x): y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3) @@ -230,15 +241,27 @@ def permute_and_compact(x): key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) + def convert_fp8_local(output, input, scale, kv_dtype): + fp8_input = input.view(torch.float8_e4m3fn) + if scale.numel() == 1: # per-tensor + result = scaled_dequantize( + fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype + ).reshape(*input.shape) + else: # per-head: broadcast scale along the head dimension + # Original code uses dim 2 for NHD, dim 1 for HND + if kv_cache_layout == "NHD": + result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1) + else: + result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1) + output.copy_(result) + # Clone the KV caches. 
if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) - ops.convert_fp8( - cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype - ) + convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype) cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) - ops.convert_fp8( - cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype + convert_fp8_local( + cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype ) else: cloned_key_cache = key_cache_compact.clone() @@ -289,15 +312,13 @@ def permute_and_compact(x): if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) - ops.convert_fp8( - result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype - ) + convert_fp8_local(result_key_cache, key_cache_compact, k_scale, kv_cache_dtype) result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) - ops.convert_fp8( + convert_fp8_local( result_value_cache, value_cache_compact, - v_scale.item(), - kv_dtype=kv_cache_dtype, + v_scale, + kv_cache_dtype, ) # Run the reference implementation. diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 085c2a703e4c..f447777d5926 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -32,6 +32,7 @@ sparse_cutlass_supported, ) from vllm.platforms import current_platform +from vllm.v1.attention.backends.fa_utils import get_flash_attn_version # AITER only supports per-channel-per-channel INT8 gemm # and per-tensor-per-tensor INT8 GEMM. @@ -360,9 +361,26 @@ def check_model(model): @pytest.mark.skipif( not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
) -def test_compressed_tensors_kv_cache(vllm_runner): - model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" - with vllm_runner(model_path, enforce_eager=True, kv_cache_dtype="fp8") as llm: +def test_compressed_tensors_kv_cache_fp8_per_tensor(vllm_runner): + model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-tensor" + with vllm_runner(model_path) as llm: + output = llm.generate_greedy("Hello world!", max_tokens=4) + assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner): + model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-attn_head" + try: + fa_version = get_flash_attn_version() + except Exception: + pytest.skip("This test requires FlashAttention backend.") + if fa_version is None or fa_version < 3: + pytest.skip("This test requires FlashAttention version >= 3.") + + with vllm_runner(model_path, attention_config={"backend": "FLASH_ATTN"}) as llm: output = llm.generate_greedy("Hello world!", max_tokens=4) assert output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 8087dc708444..a1d796d50356 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -75,13 +75,16 @@ def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> layer._v_scale_float = 1.0 layer._prob_scale_float = 1.0 + # Initialize q/k/v range constants used by calc_kv_scales + layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + def _init_kv_cache_quant( layer: nn.Module, quant_config: QuantizationConfig | None, prefix: str, - kv_cache_dtype: str, - calculate_kv_scales: bool, ) -> None: """Initializes KV cache scaling factors and quantization method. 
@@ -94,16 +97,10 @@ def _init_kv_cache_quant( layer: The attention layer instance to initialize. quant_config: Optional quantization configuration. prefix: Layer name prefix for quantization method lookup. - kv_cache_dtype: The KV cache data type string. - calculate_kv_scales: Whether to calculate KV scales dynamically. """ - # The default k/v_scale is set to 1.0. This is ignored - # when kv-cache is not fp8, and should be used with - # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized k/v_scale to be loaded along - # with the model weights. - layer.kv_cache_dtype = kv_cache_dtype - layer.calculate_kv_scales = calculate_kv_scales + quant_method = ( + quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None + ) # Note [Register q/k/v/prob scales in state dict] # When calling model.to(device), only parameters/buffers in state dict are @@ -133,7 +130,7 @@ def _init_kv_cache_quant( assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior - if kv_cache_dtype == "fp8_e5m2": + if layer.kv_cache_dtype == "fp8_e5m2": raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.") # If quantization is enabled, we make "k_scale" and "v_scale" # parameters so that it can be loaded from the model checkpoint. @@ -197,9 +194,20 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False + + # llm-compressor models need to set cache_dtype to "fp8" manually. 
+ if getattr(quant_config, "kv_cache_scheme", None) is not None: + kv_cache_dtype = "fp8" + calculate_kv_scales = False + if cache_config is not None: + cache_config.cache_dtype = "fp8" + cache_config.calculate_kv_scales = False + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( kv_cache_dtype, vllm_config.model_config ) + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales if num_kv_heads is None: num_kv_heads = num_heads assert num_heads % num_kv_heads == 0, ( @@ -208,15 +216,6 @@ def __init__( self.quant_config = quant_config self.layer_name = prefix - # Initialize KV cache quantization attributes - _init_kv_cache_quant( - self, - self.quant_config, - self.layer_name, - kv_cache_dtype, - calculate_kv_scales, - ) - self.num_heads = num_heads self.head_size = head_size self.head_size_v = self.head_size if head_size_v is None else head_size_v @@ -318,18 +317,24 @@ def __init__( for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] - # Initialize q/k/v range constants. 
- self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + # Initialize KV cache quantization attributes + _init_kv_cache_quant(self, quant_config, prefix) # for attn backends supporting query quantization self.query_quant = None - if ( - self.kv_cache_dtype.startswith("fp8") - and self.impl.supports_quant_query_input + if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith( + "fp8" ): - self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + is_per_head = ( + hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads + ) + block_size = self.head_size * self.num_heads // self.num_kv_heads + self.query_quant = QuantFP8( + static=True, + group_shape=GroupShape(-1, block_size) + if is_per_head + else GroupShape.PER_TENSOR, + ) def forward( self, @@ -524,13 +529,9 @@ def __init__( self.quant_config = quant_config # Initialize KV cache quantization attributes - _init_kv_cache_quant( - self, - self.quant_config, - self.layer_name, - kv_cache_dtype, - calculate_kv_scales, - ) + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + _init_kv_cache_quant(self, quant_config, prefix) dtype = torch.get_default_dtype() self.attn_backend = get_attn_backend( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index de50c9e8f588..8b1d564e27a5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import suppress +from functools import partial from typing import TYPE_CHECKING, Any, 
Literal, Optional, cast import torch @@ -19,6 +20,10 @@ import vllm.envs as envs from vllm.attention.layer import Attention +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( @@ -87,6 +92,8 @@ def __init__( kv_cache_scheme: dict[str, Any] | None = None, config: dict[str, Any] | None = None, transform_config: dict[str, Any] | None = None, + total_num_heads: int | None = None, + total_num_kv_heads: int | None = None, ): super().__init__() self.ignore = ignore @@ -97,6 +104,8 @@ def __init__( self.sparsity_scheme_map = sparsity_scheme_map self.sparsity_ignore_list = sparsity_ignore_list self.config = config + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads if transform_config: self.transform_config = TransformConfig.model_validate(transform_config) @@ -200,13 +209,29 @@ def _add_fused_moe_to_target_scheme_map(self): @classmethod def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": + # We keep only config groups which are not doing Attention quantization + # because Attention quantization on its own is not supported by vLLM. + # It is coupled with KV-cache quantization, and if scales are present in the + # checkpoint, they will be used properly. + grps_without_attn_quant = {} + for k, v in config["config_groups"].items(): + # e.g. LlamaAttention, Qwen3Attention, etc. + if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"): + logger.warning( + "Skipping CompressedTensors config group for %s. 
Attention quant " + "is coupled with KV-cache quantization in vLLM.", + v["targets"][0], + ) + continue + grps_without_attn_quant[k] = v + config["config_groups"] = grps_without_attn_quant + ignore: list[str] = cast(list[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) target_scheme_map = cls._quantization_scheme_map_from_config(config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( config=config ) - transform_config = config.get("transform_config") return cls( target_scheme_map=target_scheme_map, @@ -215,7 +240,10 @@ def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, - transform_config=transform_config, + transform_config=config.get("transform_config"), + kv_cache_scheme=config.get("kv_cache_scheme"), + total_num_heads=config.get("total_num_heads"), + total_num_kv_heads=config.get("total_num_kv_heads"), ) @classmethod @@ -791,22 +819,6 @@ def get_scheme_dict( return None - def get_cache_scale(self, name: str) -> str | None: - """ - Check whether the param name matches the format for k/v cache scales - in compressed-tensors. 
If this is the case, return its equivalent - param name expected by vLLM - - :param name: param name - :return: matching param name for KV cache scale in vLLM - """ - if name.endswith(".output_scale") and ".k_proj" in name: - return name.replace(".k_proj.output_scale", ".attn.k_scale") - if name.endswith(".output_scale") and ".v_proj" in name: - return name.replace(".v_proj.output_scale", ".attn.v_scale") - # If no matches, return None - return None - def has_blocked_weights(self) -> bool: for scheme in self.target_scheme_map.values(): weight_quant = scheme.get("weights") @@ -965,12 +977,16 @@ def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None): f"received num_bits={num_bits}, type={type_}" ) + # TODO: delegate validation to compressed-tensors library so that we have a + # single source of truth. Right now this is not possible until the next release + # of compressed-tensors. strategy = kv_cache_scheme.get("strategy") - if strategy != "tensor": + supported_strategies = ("tensor", "attn_head") + if strategy not in supported_strategies: raise NotImplementedError( - "Only support per-tensor scaling factor " - "for compressed-tensors KV cache. " - f"Expected strategy: tensor, found strategy: {strategy}" + "Invalid strategy for compressed-tensors KV cache. " + f"Expected strategies: {supported_strategies}, found strategy:" + f" {strategy}" ) is_symmetric = kv_cache_scheme.get("symmetric") @@ -980,3 +996,133 @@ def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None): "for compressed-tensors KV cache. " f"However found symmetric: {is_symmetric}" ) + + def create_weights(self, layer: torch.nn.Module): + """ + Initialize placeholder scales and zero points to enable loading of + quantized params from compressed-tensors checkpoints. 
+ """ + strategy = None # for backward compatibility + if ( + hasattr(self.quant_config, "kv_cache_scheme") + and self.quant_config.kv_cache_scheme is not None + ): + strategy = self.quant_config.kv_cache_scheme["strategy"] + + if strategy == "attn_head": + assert layer.impl.supports_per_head_quant_scales, ( + f"Layer {layer.__class__.__name__} with implementation " + f"{layer.impl.__class__.__name__} does not support per-head scales." + ) + n_scales = int(layer.num_kv_heads) + else: + n_scales = 1 + + layer.k_scale = torch.nn.Parameter( + torch.ones(n_scales, requires_grad=False, dtype=torch.float32) + ) + layer.v_scale = torch.nn.Parameter( + torch.ones(n_scales, requires_grad=False, dtype=torch.float32) + ) + layer.q_scale = torch.nn.Parameter( + torch.ones(n_scales, requires_grad=False, dtype=torch.float32) + ) + + # Zero points are not used in vLLM as currently only symmetric quantization is + # supported. We need to create them here to enable loading of llm-compressor + # checkpoints which contain them irrespective of the symmetric/asymmetric + # scheme used during quantization. + layer.k_zero_point = torch.nn.Parameter( + torch.zeros(n_scales, requires_grad=False) + ) + layer.v_zero_point = torch.nn.Parameter( + torch.zeros(n_scales, requires_grad=False) + ) + layer.q_zero_point = torch.nn.Parameter( + torch.zeros(n_scales, requires_grad=False) + ) + + # TP-aware loading for attn_head strategy follows attention head partitioning: + # - q_scale is partitioned over query heads. + # - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size, + # and replicated when total_kv_heads < tp_size. 
+ if strategy == "attn_head": + + def _tp_aware_loader( + param: torch.Tensor, + loaded_weight: torch.Tensor, + kind: Literal["q", "k", "v"], + param_type: Literal["scale", "zero_point"], + ): + # Zero-points are not used as vLLM only supports symmetric quantization + if param_type == "zero_point": + return + + # LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1] + loaded_weight = loaded_weight.flatten() + + # FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale. + # We reduce by taking the max scale in each attention head group. + if kind == "q": + reduction_factor = ( + self.quant_config.total_num_heads # type: ignore[attr-defined] + // self.quant_config.total_num_kv_heads # type: ignore[attr-defined] + ) + loaded_weight = torch.amax( + loaded_weight.view(-1, reduction_factor), dim=1 + ) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads: # type: ignore[attr-defined] + # heads evenly distributed + loaded_weight = loaded_weight[ + tp_rank * layer.num_kv_heads : (tp_rank + 1) + * layer.num_kv_heads + ] + else: + # heads replicated to match TP size + assert layer.num_kv_heads == 1 + replicas = tp_size // self.quant_config.total_num_kv_heads # type: ignore[attr-defined] + shard_rank = tp_rank // replicas + loaded_weight = loaded_weight[shard_rank : shard_rank + 1] + + param.data.copy_(loaded_weight.to(dtype=param.dtype)) + + layer.q_scale.weight_loader = partial( + _tp_aware_loader, kind="q", param_type="scale" + ) + layer.k_scale.weight_loader = partial( + _tp_aware_loader, kind="k", param_type="scale" + ) + layer.v_scale.weight_loader = partial( + _tp_aware_loader, kind="v", param_type="scale" + ) + + layer.q_zero_point.weight_loader = partial( + _tp_aware_loader, kind="q", param_type="zero_point" + ) + layer.k_zero_point.weight_loader = partial( + _tp_aware_loader, kind="k", param_type="zero_point" + ) + 
layer.v_zero_point.weight_loader = partial( + _tp_aware_loader, kind="v", param_type="zero_point" + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """ + Override the default vLLM placeholder scales with the llm-compressor loaded + scales. Zero points are not used as only symmetric quantization is supported. + """ + layer._k_scale = layer.k_scale + layer._v_scale = layer.v_scale + layer._q_scale = layer.q_scale + + # Discard all placeholders. + del layer.k_scale + del layer.v_scale + del layer.q_scale + del layer.k_zero_point + del layer.v_zero_point + del layer.q_zero_point diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index c1a901c37a0b..5cd2fd18c350 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -11,6 +11,7 @@ GroupShape, get_fp8_min_max, group_broadcast, + prep_scale_for_group_broadcast, ) from vllm.platforms import current_platform @@ -40,7 +41,7 @@ def __init__( """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, - or arbitrary block size) + PER_CHANNEL, or arbitrary block size) :param num_token_padding: Pad the token dimension of output to this size :param column_major_scales: For group quantization, output scales in @@ -157,6 +158,8 @@ def forward_native( x_max = x.abs().max().unsqueeze(-1).to(torch.float32) scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + else: + scale = prep_scale_for_group_broadcast(scale, x, self.group_shape) # Even for dynamic per-token scales, # reciprocal performs slightly better than division diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index bc7458444412..727315106d0a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -191,6 +191,51 @@ def group_broadcast(t, shape): return t +def prep_scale_for_group_broadcast( + scale: torch.Tensor, + x: torch.Tensor, + group_shape: GroupShape | None, +) -> torch.Tensor: + """ + Prepare the input quantization scale for group broadcasting. + + Args: + scale: The scale tensor (scalar or 1D). + x: Target tensor whose shape determines broadcast dimensions. + group_shape: GroupShape to broadcast over. + + Returns: + scale reshaped for correct broadcasting. + """ + if scale.numel() == 1: + # For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)). + # This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply, + # where the "per_tensor_activations" check relies on "x_scale.dim() < 2": + # per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 + # For all other cases, reshape scalar scales to (1, 1) for broadcasting. + return ( + scale + if group_shape is not None and group_shape.is_per_tensor() + else scale.reshape(1, 1) + ) + if scale.ndim == 1: + assert group_shape is not None, ( + "group_shape must be provided to correctly broadcast 1D scale" + ) + rows, cols = _normalize_quant_group_shape(x, group_shape) + # Determine broadcasting dimension: either rows or columns match group size + if rows == x.shape[-2]: + scale = scale.unsqueeze(-2) + elif cols == x.shape[-1]: + scale = scale.unsqueeze(-1) + else: + raise ValueError( + f"1D scale with shape {scale.shape} cannot be broadcast to x with shape" + f" {x.shape}, group_shape={(rows, cols)}" + ) + return scale + + # Quantize assuming once scale per group of elements with shape group_shape, # example group shapes: # * (-1, -1) for per-tensor quantization @@ -241,7 +286,7 @@ def scaled_quantize( _, fp8_max = get_fp8_min_max() scale = fp8_max / amax - # Apply scale and convert form: + # Apply scale and convert from: # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) x_scl_sat = ( 
(x_blkd_permd * scale.unsqueeze(-1)) @@ -261,29 +306,7 @@ def scaled_dequantize( group_shape: GroupShape | None = None, out_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - if group_shape is not None: - group_shape = _normalize_quant_group_shape(x_q, group_shape) - - if x_s.numel() == 1: # scalar - x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1) - if x_s.ndim == 1: - if group_shape is None: - raise AssertionError( - "if x_s is 1D tensor, group_shape must be provided otherwise " - "its ambiguous which dimension to broadcast x_s to" - ) - # unsqueeze the scales for the dimension where we want to broadcast - # across the full extent - if group_shape[0] == x_q.shape[-2]: - x_s = x_s.unsqueeze(-2) - elif group_shape[1] == x_q.shape[-1]: - x_s = x_s.unsqueeze(-1) - else: - raise AssertionError( - "if x_s is a vector we should be broadcasting it to the full " - "extent of one of the dimensions" - ) - + x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape) if group_shape is not None: assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0] diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index b20638c7eb28..7ea3bb2ebd19 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -246,6 +246,23 @@ def get_quant_config( # compressed-tensors uses a compressions_config hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + # Pipe information about heads to enable TP-aware loading of attn_head scales + if ( + hf_quant_config is not None + and hf_quant_config.get("quant_method") == "compressed-tensors" + ): + if hf_text_config is not None: + n_heads = getattr(hf_text_config, "num_attention_heads", None) + n_kv_heads = getattr(hf_text_config, "num_key_value_heads", None) + else: + n_heads = getattr(model_config.hf_config, "num_attention_heads", 
None) + n_kv_heads = getattr(model_config.hf_config, "num_key_value_heads", None) + + hf_quant_config["total_num_heads"] = n_heads + hf_quant_config["total_num_kv_heads"] = ( + n_kv_heads if n_kv_heads is not None else n_heads + ) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) @@ -1157,11 +1174,21 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: # .mixer.attn.{k,v}_scale (r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"), # Default format: .{k,v}_scale -> .attn.{k,v}_scale - (r"\.([kv])_scale$", r".attn.\1_scale"), + (r"\.([qkv])_scale$", r".attn.\1_scale"), + (r"\.([qkv])_zero_point$", r".attn.\1_zero_point"), ] # Check if name ends with k_scale or v_scale - if name.endswith((".k_scale", ".v_scale")): + if name.endswith( + ( + ".k_scale", + ".v_scale", + ".q_scale", + ".k_zero_point", + ".v_zero_point", + ".q_zero_point", + ) + ): import regex as re for pattern, replacement in scale_mapping_patterns: diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 7d43735c0053..b1b6cdd814fa 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -437,7 +437,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - if "scale" in name: + if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index b3887b16f4d7..5616ffee682b 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -303,7 +303,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params.add(scale_name) continue - if "scale" in name: + if "scale" in name or "zero_point" in name: remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is None: continue diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e265308937d4..979cd9fdd16b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -465,8 +465,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - if "scale" in name: - # Remapping the name of FP8 kv-scale. + if "scale" in name or "zero_point" in name: + # Remapping the name of FP8 kv-scale or zero point. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 05cb456e7776..99a69adf1fc3 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -140,8 +140,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - # Remapping the name FP8 kv-scale - if "scale" in name: + # Remapping the name FP8 kv-scale or zero point. 
+ if "scale" in name or "zero_point" in name: name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7a57644db1b1..e47a3ee74c6b 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -238,8 +238,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - # Remapping the name FP8 kv-scale - if "scale" in name: + # Remapping the name FP8 kv-scale or zero point. + if "scale" in name or "zero_point" in name: name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8655cf66d209..ef34c6daddfd 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -661,7 +661,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "scale" in name: + if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index da0688f71958..6ff4f0e84320 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -342,7 +342,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - if "scale" in name: + if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index b4dcea1052d8..16cac42e54ab 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]): # TODO add support to more backends: # https://github.com/vllm-project/vllm/issues/25584 supports_quant_query_input: bool = False + supports_per_head_quant_scales: bool = False dcp_world_size: int dcp_rank: int diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 6fec5001badd..ccb51e2159cb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -576,6 +576,11 @@ def __init__( ) self.supports_quant_query_input = True + self.supports_per_head_quant_scales = ( + self.vllm_flash_attn_version >= 3 + if self.vllm_flash_attn_version is not None + else False + ) def forward( self, @@ -691,6 +696,10 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + q_descale = layer._q_scale.expand(descale_shape) + k_descale = layer._k_scale.expand(descale_shape) + v_descale = layer._v_scale.expand(descale_shape) + if self.dcp_world_size > 1: self._forward_with_dcp( query[:num_actual_tokens], @@ -700,9 +709,9 @@ def forward( value_cache, output[:num_actual_tokens], attn_metadata, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, ) return output else: @@ -728,9 +737,9 @@ def forward( softcap=self.logits_soft_cap, scheduler_metadata=scheduler_metadata, fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + q_descale=q_descale, + k_descale=k_descale, + 
v_descale=v_descale, num_splits=attn_metadata.max_num_splits, s_aux=self.sinks, )