2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -36,7 +36,7 @@ jobs:
- name: "default"
repo: ""
- name: "vLLM:main"
repo: "git+https://github.com/vllm-project/vllm --branch main"
repo: "git+https://github.com/vllm-project/vllm@02cabff207ca68094a73ba21296c82cdbcb1d1a5"
test_suite:
- name: "static batching"
markers: "cpu and decoder and not cb"
72 changes: 38 additions & 34 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -241,32 +241,31 @@ def update_states(self, scheduler_output: SchedulerOutput):
#
# NOTE: req_state.output_token_ids is being mutated.

for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]

# Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
num_computed_tokens = req_data.num_computed_tokens[i]
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1])
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:])
new_token_ids[-num_new_tokens:])

req_index = self.input_batch.get_req_index(req_id)
# Add new_token_ids to token_ids_cpu.
# TODO: Update for spec decoding in the future
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
req_index, start_token_index:end_token_index] = new_token_ids

if scheduler_output.finished_req_ids:
for req_id in scheduler_output.finished_req_ids:
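Note on the data-model change driving this refactor: `scheduled_cached_reqs` is now a single batched `CachedRequestData` carrying parallel per-request lists, rather than a `list[CachedRequestData]` with one object per request. A minimal sketch of the assumed layout, inferred from how the fields are indexed in this diff (the authoritative definition lives in the pinned vLLM commit and is not reproduced here):

```python
from dataclasses import dataclass, field


@dataclass
class CachedRequestData:
    # Assumed struct-of-arrays layout: index i in every list below refers to
    # the same scheduled request.
    req_ids: list[str] = field(default_factory=list)
    # Passed as a plain bool in the warmup code of this PR; the upstream type
    # is not verified here.
    resumed_from_preemption: bool = False
    # Per request: token ids sampled since the last step.
    new_token_ids: list[list[int]] = field(default_factory=list)
    # Per request: newly allocated block ids (nesting inferred from the
    # `[req.block_ids]` wrapping in the warmup code).
    new_block_ids: list[list[list[int]]] = field(default_factory=list)
    # Per request: number of tokens already computed.
    num_computed_tokens: list[int] = field(default_factory=list)

    @classmethod
    def make_empty(cls) -> "CachedRequestData":
        # Assumed helper used throughout this PR when no cached requests
        # are scheduled.
        return cls()
```

Under this assumption, `req_ids[i]`, `new_token_ids[i]`, and `num_computed_tokens[i]` all describe the same request, which is why the loop above switches to `enumerate(req_data.req_ids)` and indexes each field by `i`.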
@@ -277,8 +276,7 @@ def update_states(self, scheduler_output: SchedulerOutput):
def _prepare_prompt(self, _: list[NewRequestData]) -> ModelForwardInputs:
raise NotImplementedError

def _prepare_decode(self,
_: list[CachedRequestData]) -> ModelForwardInputs:
def _prepare_decode(self, _: CachedRequestData) -> ModelForwardInputs:
raise NotImplementedError

def prepare_model_input(
@@ -291,7 +289,7 @@ def prepare_model_input(
# Prepare input tensors.
if is_prompt:
# Assert no running requests
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs.req_ids) == 0

return self._prepare_prompt(scheduler_output.scheduled_new_reqs)
else:
@@ -455,19 +453,21 @@ def _prepare_prompt(

def _prepare_decode(
self,
cached_requests: list[CachedRequestData],
cached_request_data: CachedRequestData,
) -> ModelForwardInputs:
assert len(cached_requests) > 0
assert len(cached_request_data.req_ids) > 0
input_tokens: list[list[int]] = [
[0] for _ in range(self._position_ids.shape[0])
]

for cached_request in cached_requests:
for i, req_id in enumerate(cached_request_data.req_ids):
# TODO: Will this always just be one token ID if there's no spec
# or jump decoding?
generation_token = cached_request.new_token_ids[-1]
input_tokens[self.input_batch.req_id_to_index[
cached_request.req_id]] = [generation_token]
new_token_ids = cached_request_data.new_token_ids[i]
generation_token = new_token_ids[-1]
input_tokens[self.input_batch.req_id_to_index[req_id]] = [
generation_token
]

# update position ids and attention mask
self._update_position_ids()
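For the static-batching decode path above, only scheduled requests overwrite their row of `input_tokens`; padded warmup rows keep the placeholder `[0]`. A small self-contained illustration of that scatter pattern (request ids, slot assignments, and token values are invented for the example):

```python
# Hypothetical batch of 4 rows with two live requests in slots 0 and 2.
batch_rows = 4
req_ids = ["req-a", "req-b"]                 # parallel to new_token_ids
new_token_ids = [[101], [202]]               # sampled token(s) per request
req_id_to_index = {"req-a": 0, "req-b": 2}   # assumed slot assignment

# Rows default to the placeholder token 0, mirroring the initialization above.
input_tokens = [[0] for _ in range(batch_rows)]
for i, req_id in enumerate(req_ids):
    input_tokens[req_id_to_index[req_id]] = [new_token_ids[i][-1]]

assert input_tokens == [[101], [0], [202], [0]]
```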
@@ -752,48 +752,52 @@ def _prepare_prompt(

def _prepare_decode(
self,
cached_requests: list[CachedRequestData],
cached_request_data: CachedRequestData,
) -> ModelForwardInputs:
assert len(cached_requests) > 0
assert len(cached_request_data.req_ids) > 0

input_tokens = []
input_positions = []
block_table = []
slot_mapping = []
left_padded_prompt_mask = []
self.model.indices = torch.ones(len(cached_requests),
self.model.indices = torch.ones(len(cached_request_data.req_ids),
dtype=torch.bool,
device="cpu")

assert len(self.input_batch.req_id_to_index) == len(cached_requests)
assert len(self.input_batch.req_id_to_index) == len(
cached_request_data.req_ids)
# TODO(wallas): I think we can do better here, without sorting or
# creating an intermediary dictionary
cached_reqs_map = {c.req_id: c for c in cached_requests}
cached_reqs_map = {
req_id: i
for i, req_id in enumerate(cached_request_data.req_ids)
}
req_ids = self.input_batch.sorted_requests_ids

for req_id in req_ids:
# TODO: Will this always just be one token ID if there's no spec
# or jump decoding?
cached_request = cached_reqs_map[req_id]

# adding new blocks if needed
if self.tkv // self.block_size + 1 > len(
self.req_ids2blocks[cached_request.req_id]):
self.req_ids2blocks[cached_request.req_id].append(
self.free_blocks.popleft())
block_table.append(self.req_ids2blocks[cached_request.req_id])
self.req_ids2blocks[req_id]):
self.req_ids2blocks[req_id].append(self.free_blocks.popleft())
block_table.append(self.req_ids2blocks[req_id])
# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
offset = self.tkv % self.block_size
slot = [start_slot + offset]
slot_mapping.append(slot)

generation_token = cached_request.new_token_ids[-1]
new_token_ids = cached_request_data.new_token_ids[
cached_reqs_map[req_id]]
generation_token = new_token_ids[-1]
input_tokens.append([generation_token])
seq_len = cached_request.num_computed_tokens
seq_len = cached_request_data.num_computed_tokens[
cached_reqs_map[req_id]]
input_positions.append([seq_len])

req_state = self.requests[cached_request.req_id]
req_state = self.requests[req_id]
left_padded_prompt_mask.append(req_state.left_padding)

input_tokens = torch.tensor(input_tokens,
Expand All @@ -819,7 +823,7 @@ def _prepare_decode(
dtype=torch.int64)

# add pads for min decode batch size of 2 (Spyre compiler constraint)
if len(cached_requests) == 1:
if len(cached_request_data.req_ids) == 1:
padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
self.model.indices = torch.cat(
(self.model.indices, padd_seq_indices), -1)
122 changes: 68 additions & 54 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -345,7 +345,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
for i, req in enumerate(dummy_requests):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[req],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
@@ -359,45 +359,50 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
logger.info("Warmup prefill %d/%d...", i + 1, batch_size)
self.execute_model(scheduler_output)

# one decode iteration across both sequences
cached_requests = [
CachedRequestData(
req_id=req.req_id,
resumed_from_preemption=False,
new_token_ids=[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
], # placeholder token
new_block_ids=req.block_ids,
num_computed_tokens=prompt_len,
) for req in dummy_requests
]

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_requests,
num_scheduled_tokens={
f"warmup-{i}": 1
for i in range(batch_size)
},
total_num_scheduled_tokens=batch_size,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("Warmup decode 1/1...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=dummy_requests)
# one decode iteration across both sequences
req_ids = []
new_token_ids = []
new_block_ids = []
num_computed_tokens = []
for req in dummy_requests:
req_ids.append(req.req_id)
new_token_ids.append([
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]) # placeholder token
new_block_ids.append([req.block_ids])
num_computed_tokens.append(prompt_len)
cached_request_data = CachedRequestData(
req_ids=req_ids,
resumed_from_preemption=False,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_request_data,
num_scheduled_tokens={f"warmup-{i}": 1
for i in range(batch_size)},
total_num_scheduled_tokens=batch_size,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("Warmup decode 1/1...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=dummy_requests)
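One detail worth calling out in the warmup changes: each request's `block_ids` is wrapped in an extra list (`new_block_ids.append([req.block_ids])`), which suggests `new_block_ids` is indexed first by request and then by KV-cache group. A hypothetical indexing example under that assumption (block ids invented):

```python
# Assumed shape: new_block_ids[request_index][kv_cache_group] -> block ids
new_block_ids = [
    [[0, 1, 2]],  # request 0, single KV-cache group
    [[3, 4, 5]],  # request 1, single KV-cache group
]
request_index, kv_cache_group = 1, 0
assert new_block_ids[request_index][kv_cache_group] == [3, 4, 5]
```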

# doing one additional prefill outside the warmup_context seems to be
# necessary to have reasonable TTFT for the first prefill after warmup
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[add_dummy_request],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={add_dummy_request.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
@@ -430,7 +435,7 @@ def _cleanup_model_runner(self, request) -> None:
# Needed to clean up the data of model runner
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
# NOTE: this means no work to do
total_num_scheduled_tokens=0,
@@ -539,23 +544,31 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
]

# Set up dummy cached_requests for decode steps
cached_requests = [
CachedRequestData(
req_id=req.req_id,
resumed_from_preemption=False,
new_token_ids=[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
], # placeholder token
new_block_ids=req.block_ids,
num_computed_tokens=req.num_computed_tokens,
) for req in dummy_requests
]
req_ids = []
new_token_ids = []
new_block_ids = []
num_computed_tokens = []
for req in dummy_requests:
req_ids.append(req.req_id)
new_token_ids.append([
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]) # placeholder token
new_block_ids.append([req.block_ids])
num_computed_tokens.append(req.num_computed_tokens)

cached_request_data = CachedRequestData(
req_ids=req_ids,
resumed_from_preemption=False,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)

# Set up scheduler_output for execute_model
scheduler_output = SchedulerOutput(
scheduled_new_reqs=dummy_requests,
scheduled_cached_reqs=[],
scheduled_cached_reqs=cached_request_data,
num_scheduled_tokens={i: prompt_len
for i in range(batch_size)},
total_num_scheduled_tokens=sum(prompt_len
@@ -574,7 +587,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
# The fixed size warmup needs to happen only in here
with _maybe_warmup_context():
self._warmup_model_forward_pass(scheduler_output, dummy_requests,
cached_requests, num_decode_tokens)
cached_request_data,
num_decode_tokens)
self.perf_metrics.log("warmup 1 time",
time.time() - warmup_start_t,
batch_size=batch_size,
@@ -585,7 +599,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
logger.info("Warmup forward pass 2/2...")
warmup2_start_t = time.time()
self._warmup_model_forward_pass(scheduler_output, dummy_requests,
cached_requests, num_decode_tokens)
cached_request_data, num_decode_tokens)

warmup_end_t = time.time()
warmup_total_t = warmup_end_t - warmup_start_t
@@ -604,17 +618,17 @@ def _warmup_model_forward_pass(
self,
scheduler_output: SchedulerOutput,
requests: list[NewRequestData],
cached_requests: list[CachedRequestData],
cached_request_data: CachedRequestData,
num_decode_tokens,
):
"""Handle a complete forward pass"""
scheduler_output.scheduled_new_reqs = requests
scheduler_output.scheduled_cached_reqs = []
scheduler_output.scheduled_cached_reqs = CachedRequestData.make_empty()
self.execute_model(scheduler_output) # Prefill

# Switch to cached requests to trigger decoding steps
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = cached_requests
scheduler_output.scheduled_cached_reqs = cached_request_data
for _ in range(num_decode_tokens - 1):
self.execute_model(scheduler_output)
