10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1293,7 +1293,8 @@ def create_engine_config(
 
         # Set default arguments for V1 Engine.
         self._set_default_args(usage_context, model_config)
-        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+        # Disable chunked prefill and prefix caching for:
+        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
         if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
             CpuArchEnum.POWERPC,
             CpuArchEnum.S390X,
@@ -1306,6 +1307,13 @@
                 "disabling it for V1 backend."
             )
             self.enable_chunked_prefill = False
+            logger.info(
+                "Prefix caching is not supported for ARM, POWER, "
+                "S390X and RISC-V CPUs; "
+                "disabling it for V1 backend."
+            )
+            self.enable_prefix_caching = False
 
         assert self.enable_chunked_prefill is not None
 
         sliding_window: int | None = None
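For readers outside the codebase: this hunk extends an existing architecture gate so that both chunked prefill and prefix caching are forced off for the V1 backend on these CPUs. A minimal sketch of that shape, using stand-in architecture names rather than vLLM's current_platform/CpuArchEnum helpers:

    # Sketch only; the real check goes through current_platform.is_cpu()
    # and current_platform.get_cpu_architecture() against CpuArchEnum.
    UNSUPPORTED_V1_CPU_ARCHS = {"powerpc", "s390x", "arm", "riscv"}  # stand-in names

    def apply_v1_cpu_gates(
        arch: str, enable_chunked_prefill: bool, enable_prefix_caching: bool
    ) -> tuple[bool, bool]:
        # Neither feature is supported on these architectures, so both
        # flags are forced to False regardless of the user's settings.
        if arch in UNSUPPORTED_V1_CPU_ARCHS:
            return False, False
        return enable_chunked_prefill, enable_prefix_caching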
10 changes: 7 additions & 3 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -412,7 +412,7 @@ def build(
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens=seq_lens_cpu.tolist()[num_decodes:],  # prefill
             decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
             decode_max_seq_len=max_decode_seq_len,  # decode
             decode_block_tables=block_table_tensor[:num_decodes],  # decode
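The new slicing relies on the V1 CPU batch layout, in which decode requests are ordered before prefill requests. A small illustration with made-up lengths (two decodes followed by two prefills), showing how the same seq_lens_cpu tensor feeds the prefill and decode fields:

    import torch

    seq_lens_cpu = torch.tensor([9, 14, 128, 64])  # hypothetical: 2 decodes, then 2 prefills
    num_decodes = 2

    seq_lens = seq_lens_cpu.tolist()[num_decodes:]       # [128, 64] -> prefill only
    decode_seq_lens_tensor = seq_lens_cpu[:num_decodes]  # tensor([ 9, 14])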
@@ -617,7 +617,6 @@ def forward(
                 prefill_meta.prefill_block_tables,
                 self.alibi_slopes,
             )
-
         if decode_meta := attn_metadata.decode_metadata:
             assert attn_type != AttentionType.ENCODER_ONLY, (
                 "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@ def _run_sdpa_forward(
         causal_attn = attn_type == AttentionType.DECODER
 
         seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
-        start_q, start_kv = 0, 0
+        # Incoming Q and KV contain the decode tokens as well, so start at
+        # an offset of num_decode_tokens: decode requests come first.
+        start_q, start_kv = (
+            attn_metadata.num_decode_tokens,
+            attn_metadata.num_decode_tokens,
+        )
         for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
             end_q = start_q + seq_len_q
             end_kv = start_kv + seq_len_kv
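Combined with the decode-first layout, the loop then walks contiguous per-sequence slices of the flattened prefill region. A simplified sketch with hypothetical lengths and the attention call reduced to a comment:

    num_decode_tokens = 2    # decode tokens occupy rows [0, 2) of Q and KV
    seq_lens_q = [128, 64]   # prefill query lengths
    seq_lens_kv = [128, 64]  # matching KV lengths (chunked prefill disabled)

    start_q = start_kv = num_decode_tokens
    for seq_len_q, seq_len_kv in zip(seq_lens_q, seq_lens_kv):
        end_q = start_q + seq_len_q
        end_kv = start_kv + seq_len_kv
        # query[start_q:end_q] attends over key/value[start_kv:end_kv] here.
        start_q, start_kv = end_q, end_kv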