diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp
similarity index 99%
rename from python/sglang/srt/speculative/cpp_ngram/ngram.cpp
rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp
index 904782774916..0b90fa812f34 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp
@@ -1,11 +1,10 @@
 #include "ngram.h"
+#include "trie.h"
 
 #include
 #include
 #include
 
-#include "trie.h"
-
 namespace ngram {
 
 Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h
similarity index 99%
rename from python/sglang/srt/speculative/cpp_ngram/ngram.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h
index 377b481ae3fe..fb1461d9ac92 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h
@@ -1,5 +1,9 @@
 #pragma once
 
+#include "param.h"
+#include "queue.h"
+#include "result.h"
+#include "trie.h"
 #include
 #include
 #include
@@ -8,11 +12,6 @@
 #include
 #include
 
-#include "param.h"
-#include "queue.h"
-#include "result.h"
-#include "trie.h"
-
 namespace ngram {
 
 class Ngram {
diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
new file mode 100644
index 000000000000..e1797e1fc0f3
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/reflection/registry.h>
+
+#include <tvm/ffi/function.h>
+
+#include "ngram.h"
+#include <cstring>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+struct NgramCorpusObj : public tvm::ffi::Object {
+ public:
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object);
+  static constexpr bool _type_mutable = true;
+
+  NgramCorpusObj(
+      int64_t capacity,
+      int64_t max_trie_depth,
+      int64_t min_bfs_breadth,
+      int64_t max_bfs_breadth,
+      int64_t draft_token_num,
+      int64_t match_type) {
+    ngram::Param param;
+    param.enable = true;
+    param.enable_router_mode = false;
+    param.max_trie_depth = static_cast<int32_t>(max_trie_depth);
+    param.min_bfs_breadth = static_cast<int32_t>(min_bfs_breadth);
+    param.max_bfs_breadth = static_cast<int32_t>(max_bfs_breadth);
+    param.draft_token_num = static_cast<int32_t>(draft_token_num);
+    param.match_type = (match_type == 0) ? "BFS" : "PROB";
+    ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
+  }
+
+  void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) {
+    auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
+    auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
+    int64_t batch_size = offsets.size(0) - 1;
+
+    std::vector<std::vector<int32_t>> tokens(batch_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+      tokens[i].assign(data + offs[i], data + offs[i + 1]);
+    }
+    ngram_->asyncInsert(std::move(tokens));
+  }
+
+  void batch_match(
+      const tvm::ffi::TensorView tokens_flat,
+      const tvm::ffi::TensorView offsets,
+      const tvm::ffi::TensorView out_tokens,
+      const tvm::ffi::TensorView out_mask) {
+    auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
+    auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
+    int64_t batch_size = offsets.size(0) - 1;
+
+    std::vector<std::vector<int32_t>> tokens(batch_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+      tokens[i].assign(data + offs[i], data + offs[i + 1]);
+    }
+
+    auto result = ngram_->batchMatch(tokens);
+
+    auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr());
+    auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr());
+    if (result.token.size() > static_cast<size_t>(out_tokens.size(0))) {
+      throw std::runtime_error(
+          "out_tokens buffer too small: " + std::to_string(out_tokens.size(0)) + " < " +
+          std::to_string(result.token.size()));
+    }
+    if (result.mask.size() > static_cast<size_t>(out_mask.size(0))) {
+      throw std::runtime_error(
+          "out_mask buffer too small: " + std::to_string(out_mask.size(0)) + " < " +
+          std::to_string(result.mask.size()));
+    }
+    std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t));
+    std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t));
+  }
+
+  void synchronize() {
+    ngram_->synchronize();
+  }
+
+  void reset() {
+    ngram_->reset();
+  }
+
+ private:
+  std::unique_ptr<ngram::Ngram> ngram_;
+};
+
+void register_ngram_corpus() {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<NgramCorpusObj>()
+      .def(refl::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>(), "__init__")
+      .def("async_insert", &NgramCorpusObj::async_insert)
+      .def("batch_match", &NgramCorpusObj::batch_match)
+      .def("synchronize", &NgramCorpusObj::synchronize)
+      .def("reset", &NgramCorpusObj::reset);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(register_once, register_ngram_corpus);
diff --git a/python/sglang/srt/speculative/cpp_ngram/param.h b/python/sglang/jit_kernel/csrc/ngram_corpus/param.h
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/param.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/param.h
diff --git a/python/sglang/srt/speculative/cpp_ngram/queue.h b/python/sglang/jit_kernel/csrc/ngram_corpus/queue.h
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/queue.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/queue.h
diff --git a/python/sglang/srt/speculative/cpp_ngram/result.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/result.cpp
rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp
diff --git a/python/sglang/srt/speculative/cpp_ngram/result.h b/python/sglang/jit_kernel/csrc/ngram_corpus/result.h
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/result.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.h
diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/trie.cpp
rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp
diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h
similarity index 99%
rename from python/sglang/srt/speculative/cpp_ngram/trie.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.h
index 41fd6e54ceb2..bd555597dd46 100644
--- a/python/sglang/srt/speculative/cpp_ngram/trie.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include "param.h"
+#include "result.h"
 #include
 #include
 #include
@@ -10,9 +12,6 @@
 #include
 #include
 
-#include "param.h"
-#include "result.h"
-
 namespace ngram {
 
 struct TrieNode {
diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py
new file mode 100644
index 000000000000..42b6babab0dd
--- /dev/null
+++ b/python/sglang/jit_kernel/ngram_corpus.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import tvm_ffi
+
+from sglang.jit_kernel.utils import cache_once, load_jit
+
+_MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1}
+
+
+def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
+    flat = []
+    offsets = [0]
+    for seq in batch_tokens:
+        flat.extend(seq)
+        offsets.append(len(flat))
+    tokens_flat = torch.tensor(flat, dtype=torch.int32)
+    offsets_t = torch.tensor(offsets, dtype=torch.int64)
+    return tokens_flat, offsets_t
+
+
+@cache_once
+def get_ngram_corpus_cls():
+    module = load_jit(
+        "ngram_corpus",
+        cpp_files=[
+            "ngram_corpus/result.cpp",
+            "ngram_corpus/trie.cpp",
+            "ngram_corpus/ngram.cpp",
+            "ngram_corpus/ngram_corpus_ffi.cpp",
+        ],
+        header_only=False,
+    )
+    module.register_once()
+
+    @tvm_ffi.register_object("sgl.NgramCorpus")
+    class NgramCorpusFFI(tvm_ffi.Object):
+        __slots__ = ("__dict__",)
+
+        def __init__(
+            self,
+            capacity: int,
+            max_trie_depth: int,
+            min_bfs_breadth: int,
+            max_bfs_breadth: int,
+            draft_token_num: int,
+            match_type: str,
+        ) -> None:
+            mt = _MATCH_TYPE_MAP.get(match_type)
+            if mt is None:
+                raise ValueError(
+                    f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'."
+                )
+            self.__ffi_init__(
+                capacity,
+                max_trie_depth,
+                min_bfs_breadth,
+                max_bfs_breadth,
+                draft_token_num,
+                mt,
+            )
+            self._draft_token_num = draft_token_num
+
+        def insert(self, batch_tokens: List[List[int]]) -> None:
+            tokens_flat, offsets = _to_csr(batch_tokens)
+            self.async_insert(tokens_flat, offsets)  # type: ignore
+
+        def match(
+            self,
+            batch_tokens: List[List[int]],
+        ) -> Tuple[np.ndarray, np.ndarray]:
+            tokens_flat, offsets = _to_csr(batch_tokens)
+            batch_size = len(batch_tokens)
+            d = self._draft_token_num
+
+            out_tokens = torch.zeros(batch_size * d, dtype=torch.int32)
+            out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8)
+
+            self.batch_match(tokens_flat, offsets, out_tokens, out_mask)  # type: ignore
+
+            return (
+                out_tokens.numpy().astype(np.int64),
+                out_mask.numpy().astype(np.int64),
+            )
+
+    return NgramCorpusFFI
diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py
index 63a0ba99d041..ec3ebd5abc80 100644
--- a/python/sglang/jit_kernel/utils.py
+++ b/python/sglang/jit_kernel/utils.py
@@ -140,6 +140,7 @@ def load_jit(
     extra_include_paths: List[str] | None = None,
     extra_dependencies: List[str] | None = None,
     build_directory: str | None = None,
+    header_only: bool = True,
 ) -> Module:
     """
     Loading a JIT module from C++/CUDA source files.
@@ -169,47 +170,64 @@ def load_jit(
     :type extra_dependencies: List[str] | None
     :param build_directory: The build directory for JIT compilation.
     :type build_directory: str | None
+    :param header_only: Whether the module is header-only.
+        If True, wrappers are generated around the given classes/functions to
+        export them. Otherwise, symbols must be exported from the C++/CUDA
+        side (e.g. via TVM_FFI_DLL_EXPORT_TYPED_FUNC).
+    :type header_only: bool
     :return: A just-in-time(JIT) compiled module.
     :rtype: Module
     """
-    from tvm_ffi.cpp import load_inline
+    from tvm_ffi.cpp import load, load_inline
 
     cpp_files = cpp_files or []
     cuda_files = cuda_files or []
-    cpp_wrappers = cpp_wrappers or []
-    cuda_wrappers = cuda_wrappers or []
     extra_cflags = extra_cflags or []
     extra_cuda_cflags = extra_cuda_cflags or []
     extra_ldflags = extra_ldflags or []
    extra_include_paths = extra_include_paths or []
 
+    cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files]
+    cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files]
+
     for dep in set(extra_dependencies or []):
         if dep not in _REGISTERED_DEPENDENCIES:
             raise ValueError(f"Dependency {dep} is not registered.")
         extra_include_paths += _REGISTERED_DEPENDENCIES[dep]()
 
-    # include cpp files
-    cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files]
-    cpp_sources = [f'#include "{path}"' for path in cpp_paths]
-    cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]
-
-    # include cuda files
-    cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files]
-    cuda_sources = [f'#include "{path}"' for path in cuda_paths]
-    cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]
-
-    with _jit_compile_context():
-        return load_inline(
-            "sgl_kernel_jit_" + "_".join(str(arg) for arg in args),
-            cpp_sources=cpp_sources,
-            cuda_sources=cuda_sources,
-            extra_cflags=DEFAULT_CFLAGS + extra_cflags,
-            extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
-            extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
-            extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
-            build_directory=build_directory,
-        )
+    module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args)
+    if header_only:
+        cpp_wrappers = cpp_wrappers or []
+        cuda_wrappers = cuda_wrappers or []
+
+        # include cpp files
+        cpp_sources = [f'#include "{path}"' for path in cpp_files]
+        cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]
+
+        # include cuda files
+        cuda_sources = [f'#include "{path}"' for path in cuda_files]
+        cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]
+        with _jit_compile_context():
+            return load_inline(
+                module_name,
+                cpp_sources=cpp_sources,
+                cuda_sources=cuda_sources,
+                extra_cflags=DEFAULT_CFLAGS + extra_cflags,
+                extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
+                extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
+                extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
+                build_directory=build_directory,
+            )
+    else:
+        assert cpp_wrappers is None and cuda_wrappers is None
+        with _jit_compile_context():
+            return load(
+                module_name,
+                cpp_files=cpp_files,
+                cuda_files=cuda_files,
+                extra_cflags=DEFAULT_CFLAGS + extra_cflags,
+                extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
+                extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
+                extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
+                build_directory=build_directory,
+            )
 
 
 @dataclass
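Note: with `header_only=False`, `load_jit` delegates to `tvm_ffi.cpp.load` and the compiled translation units must export their own symbols, as `ngram_corpus_ffi.cpp` does via `TVM_FFI_DLL_EXPORT_TYPED_FUNC`. A hypothetical end-to-end driver, assuming the module builds locally and using only what this diff defines (`get_ngram_corpus_cls`, the Python-side `insert`/`match` helpers, and the reflection-registered `synchronize`); the constructor values are illustrative:

```python
from sglang.jit_kernel.ngram_corpus import get_ngram_corpus_cls

NgramCorpus = get_ngram_corpus_cls()  # first call JIT-compiles; @cache_once reuses it
corpus = NgramCorpus(
    capacity=1 << 20,   # illustrative sizing
    max_trie_depth=8,
    min_bfs_breadth=1,
    max_bfs_breadth=8,
    draft_token_num=8,
    match_type="BFS",   # anything other than "BFS"/"PROB" raises ValueError
)

corpus.insert([[1, 2, 3, 4, 5], [1, 2, 3, 9, 9]])
corpus.synchronize()  # drain the asynchronous insert queue

ids, mask = corpus.match([[1, 2, 3]])
# match() pre-sizes the output buffers, so for one request with
# draft_token_num=8 the shapes are fixed: a (1*8,) id vector and a
# (1*8*8,) flattened tree mask, both widened to int64.
assert ids.shape == (8,) and mask.shape == (64,)
tree_mask = mask.reshape(8, 8)
```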
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
index f35e9acf95fe..0eb6bd71cbf7 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
+++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
@@ -1,25 +1,13 @@
 # -*- coding: utf-8 -*-
 
 import logging
-import os
 from typing import List, Tuple
 
 import numpy as np
-from torch.utils.cpp_extension import load
 
-logger = logging.getLogger(__name__)
+from sglang.jit_kernel.ngram_corpus import get_ngram_corpus_cls
 
-_abs_path = os.path.dirname(os.path.abspath(__file__))
-ngram_corpus_cpp = load(
-    name="ngram_corpus_cpp",
-    sources=[
-        f"{_abs_path}/ngram_corpus_binding.cpp",
-        f"{_abs_path}/ngram.cpp",
-        f"{_abs_path}/trie.cpp",
-        f"{_abs_path}/result.cpp",
-    ],
-    extra_cflags=["-O3", "-std=c++20"],
-)
+logger = logging.getLogger(__name__)
 
 
 class NgramCorpus:
@@ -31,30 +19,30 @@ def __init__(
         draft_token_num=8,
         match_type="BFS",
         capacity=1000000,
-    ):
-        param = ngram_corpus_cpp.Param()
-        param.max_trie_depth = max_trie_depth
-        param.min_bfs_breadth = min_bfs_breadth
-        param.max_bfs_breadth = max_bfs_breadth
-        param.draft_token_num = draft_token_num
-        param.match_type = match_type
-        self._ngram = ngram_corpus_cpp.Ngram(capacity, param)
-
+    ) -> None:
+        cls = get_ngram_corpus_cls()
+        self._obj = cls(
+            capacity=capacity,
+            max_trie_depth=max_trie_depth,
+            min_bfs_breadth=min_bfs_breadth,
+            max_bfs_breadth=max_bfs_breadth,
+            draft_token_num=draft_token_num,
+            match_type=match_type,
+        )
         self.default_mask = np.ones((1, 1), dtype=np.int64)
         self.draft_token_num = draft_token_num
 
     def batch_put(self, batch_tokens: List[List[int]]):
-        self._ngram.asyncInsert(batch_tokens)
+        self._obj.insert(batch_tokens)
 
     def synchronize(self):
-        self._ngram.synchronize()
+        self._obj.synchronize()  # type: ignore
 
     def reset(self):
-        self._ngram.reset()
+        self._obj.reset()  # type: ignore
 
     def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
-        result = self._ngram.batchMatch(batch_tokens)
-        return np.array(result.token), np.array(result.mask)
+        return self._obj.match(batch_tokens)
 
     def leaf_paths_from_mask(
         self, tokens: List[int], tree_mask: List[List[int]]
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp
deleted file mode 100644
index e632dfb3de59..000000000000
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include "ngram.h"
-
-PYBIND11_MODULE(ngram_corpus_cpp, m) {
-  using namespace ngram;
-  namespace py = pybind11;
-  m.doc() = "";
-
-  py::class_<Ngram>(m, "Ngram")
-      .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
-      .def("asyncInsert", &Ngram::asyncInsert, "")
-      .def("batchMatch", &Ngram::batchMatch, "")
-      .def("reset", &Ngram::reset, "")
-      .def("synchronize", &Ngram::synchronize, "");
-
-  py::class_<Param>(m, "Param")
-      .def(py::init<>())
-      .def_readwrite("enable", &Param::enable)
-      .def_readwrite("enable_router_mode", &Param::enable_router_mode)
-      .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
-      .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
-      .def_readwrite("max_trie_depth", &Param::max_trie_depth)
-      .def_readwrite("draft_token_num", &Param::draft_token_num)
-      .def_readwrite("match_type", &Param::match_type)
-      .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
-      .def("get_draft_token_num", &Param::get_draft_token_num, "")
-      .def("parse", &Param::parse, "")
-      .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
-      .def("detail", &Param::detail, "");
-
-  py::class_<Result>(m, "Result")
-      .def(py::init<>())
-      .def_readwrite("token", &Result::token)
-      .def_readwrite("mask", &Result::mask)
-      .def("truncate", &Result::truncate);
-}
diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py
index 93b2d77b5ac9..e8d9fc026beb 100644
--- a/test/registered/spec/utils/test_ngram_corpus.py
+++ b/test/registered/spec/utils/test_ngram_corpus.py
@@ -480,41 +480,39 @@ def test_full_budget_used(self):
 
 
 class TestTruncate(CustomTestCase):
-    """Verify the Result.truncate method via the Python binding."""
+    """Verify truncation logic on batch_get output."""
 
     def test_truncate_reduces_output(self):
         corpus = _make_corpus("BFS", draft_token_num=8)
         corpus.batch_put(SEED_SEQUENCES)
         corpus.synchronize()
 
-        result = corpus._ngram.batchMatch([[1, 2, 3]])
-        original_len = len(result.token)
-        self.assertEqual(original_len, 8)
+        ids, masks = corpus.batch_get([[1, 2, 3]])
+        ids = ids.reshape(8)
+        self.assertEqual(len(ids), 8)
 
-        result.truncate(4)
-        self.assertEqual(len(result.token), 4)
-        self.assertEqual(len(result.mask), 4 * 4)
+        # Simulate truncate to 4
+        trunc_n = 4
+        trunc_ids = ids[:trunc_n]
+        self.assertEqual(len(trunc_ids), trunc_n)
 
     def test_truncate_preserves_mask_structure(self):
         corpus = _make_corpus("BFS", draft_token_num=8)
         corpus.batch_put(SEED_SEQUENCES)
         corpus.synchronize()
 
-        result = corpus._ngram.batchMatch([[1, 2, 3]])
-        full_ids = list(result.token)
-        full_mask = list(result.mask)
-        n = len(full_ids)
+        ids, masks = corpus.batch_get([[1, 2, 3]])
+        n = 8
+        full_mask = masks.reshape(n, n)
 
-        result_copy = corpus._ngram.batchMatch([[1, 2, 3]])
         trunc_n = 4
-        result_copy.truncate(trunc_n)
-        trunc_mask = list(result_copy.mask)
+        trunc_mask = full_mask[:trunc_n, :trunc_n]
 
         for i in range(trunc_n):
             for j in range(trunc_n):
                 self.assertEqual(
-                    trunc_mask[i * trunc_n + j],
-                    full_mask[i * n + j],
+                    trunc_mask[i, j],
+                    full_mask[i, j],
                     f"Mask mismatch at ({i},{j})",
                 )