diff --git a/docs/benchmarking/dashboard.md b/docs/benchmarking/dashboard.md index 701fb16ae2cf..826abd64ab62 100644 --- a/docs/benchmarking/dashboard.md +++ b/docs/benchmarking/dashboard.md @@ -13,14 +13,14 @@ For x86 CPU environment, please use the image with "-cpu" postfix. For AArch64 C Here is an example for docker run command for CPU. For GPUs skip setting the `ON_CPU` env var. ```bash -export VLLM_COMMIT=1da94e673c257373280026f75ceb4effac80e892 # use full commit hash from the main branch +export VLLM_COMMIT=7f42dc20bb2800d09faa72b26f25d54e26f1b694 # use full commit hash from the main branch export HF_TOKEN= if [[ "$(uname -m)" == aarch64 || "$(uname -m)" == arm64 ]]; then IMG_SUFFIX="arm64-cpu" else IMG_SUFFIX="cpu" fi -docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN=$HF_TOKEN -e ON_ARM64_CPU=1 --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${VLLM_COMMIT}-${IMG_SUFFIX} +docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN=$HF_TOKEN -e ON_CPU=1 --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${VLLM_COMMIT}-${IMG_SUFFIX} ``` Then, run below command inside the docker instance. diff --git a/docs/models/hardware_supported_models/cpu.md b/docs/models/hardware_supported_models/cpu.md index 811778b2ad52..ff228cb8b76a 100644 --- a/docs/models/hardware_supported_models/cpu.md +++ b/docs/models/hardware_supported_models/cpu.md @@ -7,7 +7,7 @@ | [Intel® Xeon® 6 Processors](https://www.intel.com/content/www/us/en/products/details/processors/xeon.html) | | [Intel® Xeon® 5 Processors](https://www.intel.com/content/www/us/en/products/docs/processors/xeon/5th-gen-xeon-scalable-processors.html) | -## Supported Models +## Recommended Models ### Text-only Language Models diff --git a/docs/models/hardware_supported_models/xpu.md b/docs/models/hardware_supported_models/xpu.md index 7b8dcf5c9af2..6817e0021ffe 100644 --- a/docs/models/hardware_supported_models/xpu.md +++ b/docs/models/hardware_supported_models/xpu.md @@ -6,7 +6,7 @@ | ----------------------------------------- | | [Intel® Arc™ Pro B-Series Graphics](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/workstations/b-series/overview.html) | -## Supported Models +## Recommended Models ### Text-only Language Models diff --git a/tests/v1/entrypoints/openai/serving_responses/test_stateful.py b/tests/v1/entrypoints/openai/serving_responses/test_stateful.py index 6f7edb6bd7e7..da63e92a1e7e 100644 --- a/tests/v1/entrypoints/openai/serving_responses/test_stateful.py +++ b/tests/v1/entrypoints/openai/serving_responses/test_stateful.py @@ -70,15 +70,28 @@ async def test_background_cancel(client: openai.AsyncOpenAI): assert response.status == "queued" # Cancel the response before it is completed. - # FIXME: This test can be flaky. - await asyncio.sleep(0.5) + # Poll until the response is no longer queued (started processing) or timeout + loop = asyncio.get_running_loop() + start_time = loop.time() + max_wait_seconds = 5.0 + poll_interval = 0.1 + while loop.time() - start_time < max_wait_seconds: + response = await client.responses.retrieve(response.id) + if response.status != "queued": + # Started processing or completed - try to cancel + break + await asyncio.sleep(poll_interval) + response = await client.responses.cancel(response.id) assert response.status == "cancelled" - # Make sure the response status remains unchanged. - await asyncio.sleep(5) - response = await client.responses.retrieve(response.id) - assert response.status == "cancelled" + # Make sure the response status remains unchanged after some time. + max_retries = 10 + for _ in range(max_retries): + await asyncio.sleep(0.5) + response = await client.responses.retrieve(response.id) + # Verify status is still cancelled + assert response.status == "cancelled" @pytest.mark.asyncio diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 9fa2d8ae3fcd..f4e834f64060 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -213,7 +213,8 @@ def compute_hash(self) -> str: factors: list[Any] = [ self.mm_encoder_attn_backend.name if self.mm_encoder_attn_backend is not None - else None + else None, + self.mm_encoder_tp_mode, ] hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 69cce72f1dd7..1f8f5e5dbff9 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -263,6 +263,12 @@ def compute_hash(self) -> str: vllm_factors.append(__version__) if self.model_config: vllm_factors.append(self.model_config.compute_hash()) + if ( + self.compilation_config + and getattr(self.compilation_config, "compile_mm_encoder", False) + and self.model_config.multimodal_config + ): + vllm_factors.append(self.model_config.multimodal_config.compute_hash()) else: vllm_factors.append("None") if self.cache_config: diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 55c4c0e187bd..9371a977f96f 100755 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -202,13 +202,16 @@ from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, get_and_maybe_dequant_weights, ) from vllm.platforms import current_platform @@ -287,6 +290,37 @@ def dynamic_per_batched_tensor_quant( logger = init_logger(__name__) + +@CustomOp.register("mla_decode_concat_quant_fp8") +class _DecodeConcatQuantFP8(QuantFP8): + """ + QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before + quantization. When disabled, forward_native is compiled via torch.compile, + fusing cat/reshape/quant/view together. + """ + + def _make_forward(quant_fn): # noqa: N805 + """Factory to create forward methods that concat before quantization.""" + + def forward( + self, + decode_ql_nope: torch.Tensor, + decode_q_pe: torch.Tensor, + scale: torch.Tensor, + scale_ub: torch.Tensor | None = None, + ) -> torch.Tensor: + decode_q0 = torch.cat((decode_ql_nope, decode_q_pe), dim=-1) + decode_q_flat = decode_q0.reshape(decode_q0.shape[0], -1) + decode_q, _ = quant_fn(self, decode_q_flat, scale, scale_ub) + return decode_q.view(decode_q0.shape) + + return forward + + forward_native = _make_forward(QuantFP8.forward_native) # type: ignore[arg-type] + forward_cuda = _make_forward(QuantFP8.forward_cuda) # type: ignore[arg-type] + forward_hip = _make_forward(QuantFP8.forward_hip) # type: ignore[arg-type] + + CUDNN_WORKSPACE_SIZE = 12800 @@ -1398,6 +1432,11 @@ def __init__(self, *args, **kwargs) -> None: self.cp_kv_cache_interleave_size: int = ( get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size ) + self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + compile_native=True, + ) def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs @@ -2048,29 +2087,11 @@ def forward( decode_ql_nope = decode_ql_nope.transpose(0, 1) if fp8_attention: - ql_nope_shape = decode_ql_nope.shape - q_pe_shape = decode_q_pe.shape assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] - decode_q_shape = ( - ql_nope_shape[0], - ql_nope_shape[1], - ql_nope_shape[2] + q_pe_shape[2], - ) - # Using empty and copy since torch.cat introduces significant overhead. - decode_q0 = torch.empty( - decode_q_shape, - device=decode_ql_nope.device, - dtype=decode_ql_nope.dtype, - ) - decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope) - decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe) - - decode_q, _ = ops.scaled_fp8_quant( - decode_q0.view(decode_q_shape[0], -1), - layer._q_scale, + decode_q = self._decode_concat_quant_fp8_op( + decode_ql_nope, decode_q_pe, layer._q_scale ) - decode_q = decode_q.view(decode_q_shape) else: decode_q = (decode_ql_nope, decode_q_pe) if self.dcp_world_size > 1: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c8eef850c497..4f210b8e5f0c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -511,7 +511,9 @@ def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): ) def is_patch_merger(weight: tuple[str, torch.Tensor]): - return weight[0].startswith("patch_merger") + return weight[0].startswith( + ("patch_merger", "multi_modal_projector.patch_merger") + ) def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): return weight[0].startswith("pre_mm_projector_norm") @@ -554,18 +556,23 @@ def llm_weights_generator(): if self.patch_merger is None: continue # Load vision patch merger weights directly - trimmed_name = ".".join(name.split(".")[1:]) - param = patch_merger_dict[trimmed_name] - with torch.no_grad(): - default_weight_loader(param, w) + if name.startswith("multi_modal_projector.patch_merger"): + trimmed_name = ".".join(name.split(".")[2:]) + else: + trimmed_name = ".".join(name.split(".")[1:]) + param = patch_merger_dict.get(trimmed_name) + if param is not None: + with torch.no_grad(): + default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): if self.pre_mm_projector_norm is None: continue # Load vision pre_mm_projector_norm weights directly trimmed_name = ".".join(name.split(".")[1:]) - param = pre_mm_projector_norm_dict[trimmed_name] - with torch.no_grad(): - default_weight_loader(param, w) + param = pre_mm_projector_norm_dict.get(trimmed_name) + if param is not None: + with torch.no_grad(): + default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): if self.vision_language_adapter is None: continue diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 43b44fdaf665..6b605bf2ff7e 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -237,7 +237,7 @@ def free(self, request: Request) -> None: Typically called when a request is finished, cancelled, or aborted. """ - input_ids = self.get_cached_input_ids(request).copy() + input_ids = self.get_cached_input_ids(request) for input_id in input_ids: self.free_encoder_input(request, input_id)