[V1] Get input tokens from scheduler #13339
Conversation
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of tests runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
cc @comaniac This PR seems to work correctly when using TP (or a single GPU), but PP still generates gibberish outputs.
This pull request has merge conflicts that must be resolved before it can be merged.
@comaniac @njhill @LiuXiaoxuanPKU I've updated the PR with some simplifications for spec decoding. PTAL.
```python
    request.num_tokens)
if num_scheduled_spec_tokens > 0:
    scheduled_spec_decode_tokens[request.request_id] = (
        request.spec_token_ids[:num_scheduled_spec_tokens])
```
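For readers outside the diff context, here is a minimal, self-contained sketch of what this hunk does (the `FakeRequest` class and `collect_spec_tokens` helper are hypothetical stand-ins, not vLLM's actual classes): only requests that actually received speculative tokens this step get an entry, truncated to the scheduled count.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class FakeRequest:
    # Hypothetical stand-in for a scheduler request.
    request_id: str
    spec_token_ids: List[int] = field(default_factory=list)

def collect_spec_tokens(
        requests: List[FakeRequest],
        scheduled_spec: List[int]) -> Dict[str, List[int]]:
    # Mirror the hunk above: record only requests that actually had
    # speculative tokens scheduled, truncated to the scheduled count.
    scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
    for request, num_scheduled_spec_tokens in zip(requests, scheduled_spec):
        if num_scheduled_spec_tokens > 0:
            scheduled_spec_decode_tokens[request.request_id] = (
                request.spec_token_ids[:num_scheduled_spec_tokens])
    return scheduled_spec_decode_tokens

reqs = [FakeRequest("a", [7, 8, 9]), FakeRequest("b", [])]
print(collect_spec_tokens(reqs, [2, 0]))  # {'a': [7, 8]}
```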
```python
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
#         the req indices by the number of spec tokens. Note
#         for requests that don't perform spec decode, the
#         number of spec tokens is 0 and the req index is
#         repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
#       spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
                             num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1]. TODO: avoid the for loop here.
spec_offsets = np.concatenate(
    [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
#       after repeating: [1, 1, 1, 3, 3, 2]
#       spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
#                       = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
    self.input_batch.num_computed_tokens_cpu[:num_reqs],
    num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3],
# where M = token_ids_cpu.shape[1] (the row stride of the 2D buffer).
cumsums_spec_offsets = (
    spec_seq_offsets +
    spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
    torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
                                  device="cpu",
                                  dtype=self.input_ids_cpu.dtype)
```
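The index arithmetic above is easier to verify in isolation. The following standalone snippet re-runs it with the example values from the comments, substituting plain locals for `self.arange_np` and the input-batch fields (`max_model_len = 10` stands in for `token_ids_cpu.shape[1]`, the "M" in the comment):

```python
import numpy as np

num_reqs = 5
max_model_len = 10  # stands in for token_ids_cpu.shape[1]
arange_np = np.arange(1024)
num_spec_tokens_list = [3, 0, 2, 0, 1]
num_computed_tokens_cpu = np.array([1, 4, 3, 6, 2])

spec_req_indices = np.repeat(arange_np[:num_reqs], num_spec_tokens_list)
print(spec_req_indices)   # [0 0 0 2 2 4]
spec_offsets = np.concatenate(
    [arange_np[1:val + 1] for val in num_spec_tokens_list])
print(spec_offsets)       # [1 2 3 1 2 1]
spec_seq_offsets = np.repeat(num_computed_tokens_cpu[:num_reqs],
                             num_spec_tokens_list) + spec_offsets
print(spec_seq_offsets)   # [2 3 4 4 5 3]
flat_offsets = spec_seq_offsets + spec_req_indices * max_model_len
print(flat_offsets)       # [ 2  3  4 24 25 43]
```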
This part can be skipped as we insert spec_token_ids into token_ids_cpu and treat them as regular input tokens.
```python
# Step 2. Write spec token ids into token_ids_cpu.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
    0, cumsums_spec_offsets, all_spec_token_ids)
```
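A small self-contained demo of the flatten-and-`scatter_` trick, reusing the flat offsets computed in the standalone snippet above (buffer shape and token values are illustrative only): because `flatten()` on a contiguous tensor returns a view, `scatter_` writes straight into the 2D buffer.

```python
import torch

# (num_reqs, max_model_len) padded CPU token buffer; values illustrative.
token_ids_cpu = torch.zeros(5, 10, dtype=torch.int32)
flat_offsets = torch.tensor([2, 3, 4, 24, 25, 43], dtype=torch.int64)
all_spec_token_ids = torch.tensor([101, 102, 103, 201, 202, 301],
                                  dtype=torch.int32)
# flatten() on a contiguous tensor is a view, so this writes in place.
token_ids_cpu.flatten().scatter_(0, flat_offsets, all_spec_token_ids)
print(token_ids_cpu[0, 2:5])  # 101, 102, 103
print(token_ids_cpu[2, 4:6])  # 201, 202
print(token_ids_cpu[4, 3])    # 301
```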
```python
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
#       cu_num_tokens: [4, 104, 107, 207, 209]
#       num_spec_tokens_list: [3, 0, 2, 0, 1]
#       num_sampled_tokens: [4, 1, 3, 1, 2]
#       spec_decode_logits_indices:
#           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three steps compute:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
#         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
#         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
#         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                  cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
```
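Again, the arithmetic can be checked standalone. This sketch reproduces the worked example from the comments (plain locals replace `self.arange_np` and `cu_num_tokens`):

```python
import numpy as np

arange_np = np.arange(1024)
cu_num_tokens = np.array([4, 104, 107, 207, 209])
num_spec_tokens_np = np.array([3, 0, 2, 0, 1], dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1       # [4 1 3 1 2]
starts = cu_num_tokens - num_sampled_tokens       # [0 103 104 206 207]
starts = np.repeat(starts, num_sampled_tokens)
cu_sampled = np.cumsum(num_sampled_tokens)        # [4 5 8 9 11]
offsets = np.repeat(cu_sampled - num_sampled_tokens, num_sampled_tokens)
within = arange_np[:num_sampled_tokens.sum()] - offsets
print(starts + within)
# [  0   1   2   3 103 104 105 106 206 207 208]
```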
This part was moved to a separate method for better readability.
ywang96 left a comment:
Overall logic looks good to me but I left two comments - PTAL!
```python
assert all(
    req_id is not None for req_id in
    self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
```
I thought casting would be much faster? Is there a reason why you did the loop & append for req_ids instead?
Agree.
Also since it's a list I think it's probably better to not preallocate to the max size like we do for tensors. The list can grow/shrink as needed. So IMO we could change that in the input batch (I actually did that in #13244). This way we can also just keep the type as List[str]. (not suggesting for this PR though...)
Regarding this particular line, I think the proposed change would be slightly faster because the original code creates the list three times (two from req_ids[:num_reqs] and another from the list comprehension for all) while the proposed change creates only one.
Agreed with @njhill. All of these are quite hacky and unnecessarily complex. req_ids should be fixed by #13244.
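To make the trade-off concrete, here is a toy comparison of the two approaches being discussed (the buffer contents and names are made up for illustration, not the actual input-batch code):

```python
from typing import List, Optional, cast

req_ids_buf: List[Optional[str]] = ["a", "b", "c", None, None]
num_reqs = 3

# Cast-based version from the diff: one slice, validated once.
slice_ = req_ids_buf[:num_reqs]
assert all(req_id is not None for req_id in slice_), "req_ids contains None"
req_ids = cast(List[str], slice_)

# Loop-and-append alternative: checks and copies element by element.
req_ids_alt: List[str] = []
for req_id in req_ids_buf[:num_reqs]:
    assert req_id is not None
    req_ids_alt.append(req_id)

assert req_ids == req_ids_alt == ["a", "b", "c"]
```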
njhill left a comment:
Looks great to me, it's also a nice simplification! Only minor suggestions.
This PR changes the scheduler and model runner so that the model runner gets the input token IDs from the scheduler. This change is especially useful when the token IDs are not generated by the model runner (e.g., non-last ranks in PP).
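A minimal sketch of the idea, with illustrative types rather than vLLM's actual interfaces: the scheduler output carries the token IDs each request should run, so a rank that never samples (e.g., a non-last PP rank) can still build its input batch directly from the scheduler.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class SchedulerOutputSketch:
    # Token IDs each request should run this step, chosen by the scheduler.
    # Non-last PP ranks can consume these directly instead of relying on
    # locally sampled tokens that they never produce.
    scheduled_token_ids: Dict[str, List[int]]

def run_model_step(output: SchedulerOutputSketch) -> None:
    for req_id, token_ids in output.scheduled_token_ids.items():
        # On a real rank, this is where the IDs would be copied into the
        # CPU token buffer for the input batch.
        print(f"{req_id}: feeding {token_ids}")

run_model_step(SchedulerOutputSketch({"req-0": [11, 12], "req-1": [42]}))
```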