-
-
Notifications
You must be signed in to change notification settings - Fork 16.7k
[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step #6338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
comaniac
merged 32 commits into
vllm-project:main
from
neuralmagic:prepare_inputs_on_gpu
Jul 17, 2024
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
cea2a11
tmp commit
alexm-redhat 9f0b0f8
sync
alexm-redhat 1e8ecf8
sync
alexm-redhat f7c67ee
start
alexm-redhat 02617c4
tmp
alexm-redhat f2af7c3
sync
alexm-redhat 1d831a2
works
alexm-redhat aa9747e
correct
alexm-redhat baeae6a
add cuda graphs
alexm-redhat de941c7
sync
alexm-redhat 2ea329d
skip logprobs
alexm-redhat 2c9eefc
sync
alexm-redhat 73b1879
sync
alexm-redhat 2d2adf1
cleanups
alexm-redhat 7d426a9
fix rebase bugs
alexm-redhat 50a1edf
Cody's refactor proposal
alexm-redhat e1ef1f4
Cody's refactor proposal
alexm-redhat e8e10f2
sync
alexm-redhat 2231012
format
alexm-redhat 4e97a8a
fix bug
alexm-redhat a90a085
restore test
alexm-redhat 89fd609
sync
alexm-redhat f0bd7ef
Cody review fixes
alexm-redhat 3ab5e9f
Fix acceptance rate to 100% for the fallback case
alexm-redhat c6eacc8
format
alexm-redhat 5626bfc
Cody's review comments
alexm-redhat 2aabb31
restore multistep test
alexm-redhat 5cf3f59
Cade's comments
alexm-redhat 917bea6
Cade and Cody comments
alexm-redhat a4968d3
fix bug
alexm-redhat a0d2384
Cody's review
alexm-redhat e5f4265
add Cody's mock test
alexm-redhat File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| /* | ||
| * The goal of this GPU kernel is to advance input tensors on the GPU directly | ||
| * PR: https://github.com/vllm-project/vllm/pull/6338 | ||
| * Current restrictions: | ||
| * 1. Specialized for DraftModelRunner | ||
| * 2. Supports flash_attn only | ||
| */ | ||
|
|
||
| #include "advance_step.cuh" | ||
|
|
||
| namespace prepare_inputs { | ||
|
|
||
| // | ||
| template <int const num_threads> | ||
| __global__ void advance_step_kernel(int num_seqs, int num_queries, | ||
| int block_size, long* input_tokens_ptr, | ||
| long const* sampled_token_ids_ptr, | ||
| long* input_positions_ptr, | ||
| int* seq_lens_ptr, long* slot_mapping_ptr, | ||
| int const* block_tables_ptr, | ||
| int64_t const block_tables_stride) { | ||
| int num_query_blocks = div_ceil(num_queries, num_threads); | ||
|
|
||
| if (blockIdx.x >= num_query_blocks) { | ||
| return; | ||
| } | ||
|
|
||
| int cur_query_id = blockIdx.x * num_threads + threadIdx.x; | ||
|
|
||
| if (cur_query_id >= num_queries) { | ||
| return; | ||
| } | ||
|
|
||
| // Update input_tokens | ||
| input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; | ||
|
|
||
| int seq_len = seq_lens_ptr[cur_query_id]; | ||
| int next_seq_len = seq_len + 1; | ||
| int next_input_pos = next_seq_len - 1; | ||
|
|
||
| // Update seq_lens | ||
| seq_lens_ptr[cur_query_id] = next_seq_len; | ||
| // Update input_positions | ||
| input_positions_ptr[cur_query_id] = next_input_pos; | ||
|
|
||
| int const* seq_block_tables_ptr = | ||
| block_tables_ptr + block_tables_stride * cur_query_id; | ||
|
|
||
| int block_index = next_input_pos / block_size; | ||
| int block_offset = next_input_pos % block_size; | ||
|
|
||
| int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; | ||
| // Update slot_mapping | ||
| slot_mapping_ptr[cur_query_id] = slot_num; | ||
| } | ||
|
|
||
| inline void verify_tensor(std::string const& name, torch::Tensor& t, | ||
| int64_t const size_0, int64_t const size_1, | ||
| c10::ScalarType const type) { | ||
|
comaniac marked this conversation as resolved.
Outdated
|
||
| bool size_0_cond = true; | ||
| if (size_0 != -1) { | ||
| size_0_cond = t.size(0) == size_0; | ||
| } | ||
|
|
||
| bool size_1_cond = true; | ||
| if (size_1 != -1) { | ||
| size_1_cond = t.size(1) == size_1; | ||
| } | ||
|
|
||
| bool is_contiguous = t.is_contiguous(); | ||
| bool same_type = t.dtype() == type; | ||
|
|
||
| bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; | ||
| if (!pass) { | ||
| TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), | ||
| " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), | ||
| " is not as expected: shape = [", size_0, ", ", size_1, | ||
| "], type = ", type); | ||
| } | ||
| } | ||
|
|
||
| void advance_step(int num_seqs, int num_queries, int block_size, | ||
| torch::Tensor& input_tokens, // type: long | ||
| torch::Tensor& sampled_token_ids, // type: long | ||
| torch::Tensor& input_positions, // type: long | ||
| torch::Tensor& seq_lens, // type: int | ||
| torch::Tensor& slot_mapping, // type: long | ||
| torch::Tensor& block_tables) { // type: int | ||
|
|
||
| if (logging) { | ||
| printf("advance_step:\n"); | ||
| printf(" num_seqs = %d\n", num_seqs); | ||
| printf(" num_queries = %d\n", num_queries); | ||
| printf(" block_size = %d\n", block_size); | ||
| } | ||
| // Verify all tensors | ||
| verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); | ||
| verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, | ||
| at::kLong); | ||
| verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); | ||
| verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); | ||
| verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); | ||
| verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); | ||
|
|
||
| int dev = sampled_token_ids.get_device(); | ||
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); | ||
|
|
||
| int blocks; | ||
| cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); | ||
|
|
||
| advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>( | ||
| num_seqs, num_queries, block_size, | ||
| reinterpret_cast<long*>(input_tokens.data_ptr()), | ||
| reinterpret_cast<long const*>(sampled_token_ids.data_ptr()), | ||
| reinterpret_cast<long*>(input_positions.data_ptr()), | ||
| reinterpret_cast<int*>(seq_lens.data_ptr()), | ||
| reinterpret_cast<long*>(slot_mapping.data_ptr()), | ||
| reinterpret_cast<int const*>(block_tables.data_ptr()), | ||
| block_tables.stride(0)); | ||
| } | ||
|
|
||
| } // namespace prepare_inputs | ||
|
|
||
| void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, | ||
| torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, | ||
| torch::Tensor& input_positions, torch::Tensor& seq_lens, | ||
| torch::Tensor& slot_mapping, torch::Tensor& block_tables) { | ||
| prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, | ||
| sampled_token_ids, input_positions, seq_lens, | ||
| slot_mapping, block_tables); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| #pragma once | ||
|
|
||
| #include <torch/all.h> | ||
|
|
||
| #include <ATen/cuda/CUDAContext.h> | ||
| #include <c10/cuda/CUDAGuard.h> | ||
| #include <cuda.h> | ||
| #include <cuda_fp16.h> | ||
| #include <cuda_runtime.h> | ||
| #include <iostream> | ||
|
|
||
| namespace prepare_inputs { | ||
|
|
||
| static constexpr int max_threads = 256; | ||
| static constexpr bool logging = false; | ||
|
|
||
| constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } | ||
|
|
||
| } // namespace prepare_inputs |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.