From 86151e96dbf3d7bbcdf085ee936c1feeb703edd7 Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Thu, 16 Oct 2025 15:57:24 +0000 Subject: [PATCH 1/2] [fix][cpu] fix prefill attention in CPU attention backend - Disables prefix caching because prefill attention can't handle paged KV cache - Fixes Q/K/V used during prefill on mixed prefill/decode requests Signed-off-by: Fadi Arafeh --- vllm/engine/arg_utils.py | 9 ++++++++- vllm/v1/attention/backends/cpu_attn.py | 10 +++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 654857315b15..3ce2afd1025c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1281,7 +1281,7 @@ def create_engine_config( # Set default arguments for V1 Engine. self._set_default_args(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 + # Disable chunked prefill and prefix caching for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( CpuArchEnum.POWERPC, CpuArchEnum.S390X, @@ -1294,6 +1294,13 @@ def create_engine_config( "disabling it for V1 backend." ) self.enable_chunked_prefill = False + logger.info( + "Prefix caching is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) + self.enable_prefix_caching = False + assert self.enable_chunked_prefill is not None sliding_window: int | None = None diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 211eefdb6c11..0d3e1729ff20 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -412,7 +412,7 @@ def build( num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, # to ensure inference when chunked_prefill is disabled - seq_lens=seq_lens_cpu.tolist(), + seq_lens=seq_lens_cpu.tolist()[num_decodes:], # prefill decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode decode_max_seq_len=max_decode_seq_len, # decode decode_block_tables=block_table_tensor[:num_decodes], # decode @@ -617,7 +617,6 @@ def forward( prefill_meta.prefill_block_tables, self.alibi_slopes, ) - if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have decode metadata." @@ -686,7 +685,12 @@ def _run_sdpa_forward( causal_attn = attn_type == AttentionType.DECODER seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) - start_q, start_kv = 0, 0 + # Incoming Q and KV contain decoded tokens as well, hence start at an offset + # equal to num_decode_tokens since decode requests appear first + start_q, start_kv = ( + attn_metadata.num_decode_tokens, + attn_metadata.num_decode_tokens, + ) for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv From bf57438dff51b4566d6cc613a84a11c898409f35 Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Sat, 18 Oct 2025 09:33:54 +0000 Subject: [PATCH 2/2] fix formatting Signed-off-by: Fadi Arafeh --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 025c4430c4bc..11d1a74d7af1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1292,7 +1292,8 @@ def create_engine_config( # Set default arguments for V1 Engine. self._set_default_args(usage_context, model_config) - # Disable chunked prefill and prefix caching for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 + # Disable chunked prefill and prefix caching for: + # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( CpuArchEnum.POWERPC, CpuArchEnum.S390X,