-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 #7756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,8 @@ | |
| #include "pytorch_extension_utils_rocm.h" | ||
| #endif | ||
|
|
||
| typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| // parent_list [bs, topk * (depth - 1) + 1)] | ||
| // selected_index [bs, draft_token_num - 1] | ||
| // verified_seq_len [bs] | ||
|
|
@@ -40,7 +42,8 @@ __global__ void build_tree_efficient( | |
| int64_t* retrive_next_sibling, | ||
| int topk, | ||
| int depth, | ||
| int draft_token_num) { | ||
| int draft_token_num, | ||
| int tree_mask_mode) { | ||
| int bid = blockIdx.x; | ||
| int tid = threadIdx.x; | ||
|
|
||
|
|
@@ -52,7 +55,13 @@ __global__ void build_tree_efficient( | |
| seq_tree_idx += verified_seq_len[i] * draft_token_num; | ||
| } | ||
| int seq_len = verified_seq_len[bid]; | ||
| int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; | ||
| int token_tree_idx; | ||
| if (tree_mask_mode == FULL_MASK) { | ||
| token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; | ||
| } else { | ||
|
Comment on lines
+59
to
+61
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1; | ||
| } | ||
| tree_mask[token_tree_idx - 1] = true; | ||
| for (int i = 0; i < draft_token_num - 1; i++) { | ||
| tree_mask[token_tree_idx + i] = false; | ||
| } | ||
|
|
@@ -124,26 +133,38 @@ void build_tree_kernel_efficient( | |
| at::Tensor retrive_next_sibling, | ||
| int64_t topk, | ||
| int64_t depth, | ||
| int64_t draft_token_num) { | ||
| int64_t draft_token_num, | ||
| int64_t tree_mask_mode) { | ||
| // TODO (ying) check shape | ||
| // TODO (ying) check type | ||
| int bs = parent_list.size(0); | ||
| dim3 grid(bs); | ||
| dim3 block(draft_token_num); | ||
| const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
|
||
| build_tree_efficient<<<grid, block, 0, stream>>>( | ||
| static_cast<int64_t*>(parent_list.data_ptr()), | ||
| static_cast<int64_t*>(selected_index.data_ptr()), | ||
| static_cast<int64_t*>(verified_seq_len.data_ptr()), | ||
| static_cast<bool*>(tree_mask.data_ptr()), | ||
| static_cast<int64_t*>(positions.data_ptr()), | ||
| static_cast<int64_t*>(retrive_index.data_ptr()), | ||
| static_cast<int64_t*>(retrive_next_token.data_ptr()), | ||
| static_cast<int64_t*>(retrive_next_sibling.data_ptr()), | ||
| int32_t(topk), | ||
| int32_t(depth), | ||
| int32_t(draft_token_num)); | ||
| if (tree_mask_mode == QLEN_ONLY_BITPACKING) { | ||
| size_t num_bytes_per_item = 1; | ||
| if (draft_token_num > 16) { | ||
| num_bytes_per_item = 4; | ||
| } else if (draft_token_num > 8) { | ||
| num_bytes_per_item = 2; | ||
| } | ||
| throw std::runtime_error("Not implemented"); | ||
|
Comment on lines
+145
to
+152
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| } else { | ||
| build_tree_efficient<<<grid, block, 0, stream>>>( | ||
| static_cast<int64_t*>(parent_list.data_ptr()), | ||
| static_cast<int64_t*>(selected_index.data_ptr()), | ||
| static_cast<int64_t*>(verified_seq_len.data_ptr()), | ||
| static_cast<bool*>(tree_mask.data_ptr()), | ||
| static_cast<int64_t*>(positions.data_ptr()), | ||
| static_cast<int64_t*>(retrive_index.data_ptr()), | ||
| static_cast<int64_t*>(retrive_next_token.data_ptr()), | ||
| static_cast<int64_t*>(retrive_next_sibling.data_ptr()), | ||
| int32_t(topk), | ||
| int32_t(depth), | ||
| int32_t(draft_token_num), | ||
| int32_t(tree_mask_mode)); | ||
| } | ||
| } | ||
|
|
||
| template <typename IdType, typename IdType2> | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The calculation for
packed_dtype_idxcan result in an index out of bounds forpacked_dtypes. Ifnum_verify_tokensis large,packed_dtype_idxcan be out of bounds, leading to anIndexError. Consider adding a check to ensurepacked_dtype_idxis within the valid range.