-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Migrate ngram corpus from torch cpp_extension to TVM FFI jit_kernel #21920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
aa73a40
migrate ngram corpus from torch cpp_extension to TVM FFI jit_kernel
hnyls2002 d70b766
add mutex to get_instance and bounds check to memcpy
hnyls2002 6e58f43
fix: remove false thread safety in get_instance, use output tensor fo…
hnyls2002 61f239b
fix: remove min/max_match_window_size not present in main's Param struct
hnyls2002 fa99c5d
fix TestTruncate: use batch_get API instead of internal _ngram binding
hnyls2002 7737efe
misc: use tvm-ffi object
DarkSharpness File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
3 changes: 1 addition & 2 deletions
3
...glang/srt/speculative/cpp_ngram/ngram.cpp → ...ng/jit_kernel/csrc/ngram_corpus/ngram.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| #pragma once | ||
|
|
||
| #include <sgl_kernel/ffi.h> | ||
| #include <sgl_kernel/tensor.h> | ||
|
|
||
| #include <tvm/ffi/reflection/registry.h> | ||
|
|
||
| #include "ngram.h" | ||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <memory> | ||
| #include <stdexcept> | ||
| #include <vector> | ||
|
|
||
| struct NgramCorpusObj : public tvm::ffi::Object { | ||
| public: | ||
| TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object); | ||
| static constexpr bool _type_mutable = true; | ||
|
|
||
| NgramCorpusObj( | ||
| int64_t capacity, | ||
| int64_t max_trie_depth, | ||
| int64_t min_bfs_breadth, | ||
| int64_t max_bfs_breadth, | ||
| int64_t draft_token_num, | ||
| int64_t match_type) { | ||
| ngram::Param param; | ||
| param.enable = true; | ||
| param.enable_router_mode = false; | ||
| param.max_trie_depth = static_cast<size_t>(max_trie_depth); | ||
| param.min_bfs_breadth = static_cast<size_t>(min_bfs_breadth); | ||
| param.max_bfs_breadth = static_cast<size_t>(max_bfs_breadth); | ||
| param.draft_token_num = static_cast<size_t>(draft_token_num); | ||
| param.match_type = (match_type == 0) ? "BFS" : "PROB"; | ||
| ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param); | ||
| } | ||
|
|
||
| void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) { | ||
| auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr()); | ||
| auto* offs = static_cast<const int64_t*>(offsets.data_ptr()); | ||
| int64_t batch_size = offsets.size(0) - 1; | ||
|
|
||
| std::vector<std::vector<int32_t>> tokens(batch_size); | ||
| for (int64_t i = 0; i < batch_size; ++i) { | ||
| tokens[i].assign(data + offs[i], data + offs[i + 1]); | ||
| } | ||
| ngram_->asyncInsert(std::move(tokens)); | ||
| } | ||
|
|
||
| void batch_match( | ||
| const tvm::ffi::TensorView tokens_flat, | ||
| const tvm::ffi::TensorView offsets, | ||
| const tvm::ffi::TensorView out_tokens, | ||
| const tvm::ffi::TensorView out_mask) { | ||
| auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr()); | ||
| auto* offs = static_cast<const int64_t*>(offsets.data_ptr()); | ||
| int64_t batch_size = offsets.size(0) - 1; | ||
|
|
||
| std::vector<std::vector<int32_t>> tokens(batch_size); | ||
| for (int64_t i = 0; i < batch_size; ++i) { | ||
| tokens[i].assign(data + offs[i], data + offs[i + 1]); | ||
| } | ||
|
|
||
| auto result = ngram_->batchMatch(tokens); | ||
|
|
||
| auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr()); | ||
| auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr()); | ||
| if (result.token.size() > static_cast<size_t>(out_tokens.size(0))) { | ||
| throw std::runtime_error( | ||
| "out_tokens buffer too small: " + std::to_string(out_tokens.size(0)) + " < " + | ||
| std::to_string(result.token.size())); | ||
| } | ||
| if (result.mask.size() > static_cast<size_t>(out_mask.size(0))) { | ||
| throw std::runtime_error( | ||
| "out_mask buffer too small: " + std::to_string(out_mask.size(0)) + " < " + | ||
| std::to_string(result.mask.size())); | ||
| } | ||
| std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t)); | ||
| std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t)); | ||
| } | ||
|
|
||
| void synchronize() { | ||
| ngram_->synchronize(); | ||
| } | ||
|
|
||
| void reset() { | ||
| ngram_->reset(); | ||
| } | ||
|
|
||
| private: | ||
| std::unique_ptr<ngram::Ngram> ngram_; | ||
| }; | ||
|
|
||
| void register_ngram_corpus() { | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::ObjectDef<NgramCorpusObj>() | ||
| .def(refl::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>(), "__init__") | ||
| .def("async_insert", &NgramCorpusObj::async_insert) | ||
| .def("batch_match", &NgramCorpusObj::batch_match) | ||
| .def("synchronize", &NgramCorpusObj::synchronize) | ||
| .def("reset", &NgramCorpusObj::reset); | ||
| } | ||
|
|
||
| TVM_FFI_DLL_EXPORT_TYPED_FUNC(register_once, register_ngram_corpus); | ||
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import List, Tuple | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import tvm_ffi | ||
|
|
||
| from sglang.jit_kernel.utils import cache_once, load_jit | ||
|
|
||
| _MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1} | ||
|
|
||
|
|
||
| def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| flat = [] | ||
| offsets = [0] | ||
| for seq in batch_tokens: | ||
| flat.extend(seq) | ||
| offsets.append(len(flat)) | ||
| tokens_flat = torch.tensor(flat, dtype=torch.int32) | ||
| offsets_t = torch.tensor(offsets, dtype=torch.int64) | ||
| return tokens_flat, offsets_t | ||
|
|
||
|
|
||
| @cache_once | ||
| def get_ngram_corpus_cls(): | ||
| module = load_jit( | ||
| "ngram_corpus", | ||
| cpp_files=[ | ||
| "ngram_corpus/result.cpp", | ||
| "ngram_corpus/trie.cpp", | ||
| "ngram_corpus/ngram.cpp", | ||
| "ngram_corpus/ngram_corpus_ffi.cpp", | ||
| ], | ||
| header_only=False, | ||
| ) | ||
| module.register_once() | ||
|
|
||
| @tvm_ffi.register_object("sgl.NgramCorpus") | ||
| class NgramCorpusFFI(tvm_ffi.Object): | ||
| __slots__ = ("__dict__",) | ||
|
|
||
| def __init__( | ||
| self, | ||
| capacity: int, | ||
| max_trie_depth: int, | ||
| min_bfs_breadth: int, | ||
| max_bfs_breadth: int, | ||
| draft_token_num: int, | ||
| match_type: str, | ||
| ) -> None: | ||
| mt = _MATCH_TYPE_MAP.get(match_type) | ||
| if mt is None: | ||
| raise ValueError( | ||
| f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'." | ||
| ) | ||
| self.__ffi_init__( | ||
| capacity, | ||
| max_trie_depth, | ||
| min_bfs_breadth, | ||
| max_bfs_breadth, | ||
| draft_token_num, | ||
| mt, | ||
| ) | ||
| self._draft_token_num = draft_token_num | ||
|
|
||
| def insert(self, batch_tokens: List[List[int]]) -> None: | ||
| tokens_flat, offsets = _to_csr(batch_tokens) | ||
| self.async_insert(tokens_flat, offsets) # type: ignore | ||
|
|
||
| def match( | ||
| self, | ||
| batch_tokens: List[List[int]], | ||
| ) -> Tuple[np.ndarray, np.ndarray]: | ||
| tokens_flat, offsets = _to_csr(batch_tokens) | ||
| batch_size = len(batch_tokens) | ||
| d = self._draft_token_num | ||
|
|
||
| out_tokens = torch.zeros(batch_size * d, dtype=torch.int32) | ||
| out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8) | ||
|
|
||
| self.batch_match(tokens_flat, offsets, out_tokens, out_mask) # type: ignore | ||
|
|
||
| return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype( | ||
| np.int64 | ||
| ) | ||
|
|
||
| return NgramCorpusFFI |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.