Skip to content

Commit 7b67ba4

Browse files
committed
correctly handle num_past_tokens > sliding_window case
1 parent 09ef5b3 commit 7b67ba4

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

examples/python/run_llama_batched_vllm.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -274,17 +274,18 @@ def _prepare_eval_queries(
274 274
275 275          positions += [num_past_tokens + i for i in range(num_queries)]
276 276
277     -        if sliding_window:
278     -            seq_lens.append(min(num_past_tokens + num_queries, sliding_window))
279     -            num_past = min(num_past_tokens, sliding_window)
280     -            past_slot_mapping += all_slot_mappings[request_id][:num_past]
281     -            slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries]
    277 +        if sliding_window and num_past_tokens + num_queries >= sliding_window:
    278 +            seq_lens.append(sliding_window)
    279 +            past_slot_mapping += all_slot_mappings[request_id][
    280 +                num_past_tokens - (sliding_window - num_queries) : num_past_tokens
    281 +            ]
282 282          else:
283 283              seq_lens.append(num_past_tokens + num_queries)
284 284              past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens]
285     -            slot_mapping += all_slot_mappings[request_id][
286     -                num_past_tokens : num_past_tokens + num_queries
287     -            ]
    285 +
    286 +        slot_mapping += all_slot_mappings[request_id][
    287 +            num_past_tokens : num_past_tokens + num_queries
    288 +        ]
288 289
289 290          permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list(
290 291              range(query_offset, query_offset + num_queries)

0 commit comments

Comments
 (0)