Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/core/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
Expand All @@ -122,7 +122,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
Expand All @@ -137,7 +137,7 @@ def allocate(self, seq_group: SequenceGroup) -> None:
block_table.append(block)

# Assign the block table for each sequence.
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()

def can_append_slot(self, seq_group: SequenceGroup) -> bool:
Expand Down
12 changes: 7 additions & 5 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,17 @@ def _schedule(self) -> SchedulerOutputs:
while self.waiting:
seq_group = self.waiting[0]

assert seq_group.num_seqs() == 1, (
waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
Expand All @@ -161,7 +163,7 @@ def _schedule(self) -> SchedulerOutputs:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
Expand Down Expand Up @@ -317,7 +319,7 @@ def free_finished_seq_groups(self) -> None:

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING

def _append_slot(
Expand Down