From aa73a4001d159af4b84971ec2fd53f5d470d67e4 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Thu, 2 Apr 2026 00:08:07 -0700 Subject: [PATCH 1/6] migrate ngram corpus from torch cpp_extension to TVM FFI jit_kernel Move C++ source files from speculative/cpp_ngram/ to jit_kernel/csrc/ngram_corpus/ and replace pybind11 binding with TVM FFI function wrappers using an opaque handle pattern. This eliminates torch.utils.cpp_extension.load() JIT cache issues (hash collisions, stale caches, multi-process lock contention) by using the more reliable tvm_ffi.cpp.load_inline() compilation. --- .../csrc/ngram_corpus}/ngram.cpp | 3 +- .../csrc/ngram_corpus}/ngram.h | 9 +- .../csrc/ngram_corpus/ngram_corpus_ffi.h | 111 ++++++++++++++++++ .../csrc/ngram_corpus}/param.h | 0 .../csrc/ngram_corpus}/queue.h | 0 .../csrc/ngram_corpus}/result.cpp | 0 .../csrc/ngram_corpus}/result.h | 0 .../csrc/ngram_corpus}/trie.cpp | 0 .../csrc/ngram_corpus}/trie.h | 5 +- python/sglang/jit_kernel/ngram_corpus.py | 111 ++++++++++++++++++ .../srt/speculative/cpp_ngram/ngram_corpus.py | 55 +++++---- .../cpp_ngram/ngram_corpus_binding.cpp | 38 ------ 12 files changed, 258 insertions(+), 74 deletions(-) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/ngram.cpp (99%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/ngram.h (99%) create mode 100644 python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/param.h (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/queue.h (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/result.cpp (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/result.h (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/trie.cpp (100%) rename python/sglang/{srt/speculative/cpp_ngram => jit_kernel/csrc/ngram_corpus}/trie.h (99%) create mode 100644 python/sglang/jit_kernel/ngram_corpus.py delete mode 100644 python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp similarity index 99% rename from python/sglang/srt/speculative/cpp_ngram/ngram.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp index 904782774916..0b90fa812f34 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.cpp @@ -1,11 +1,10 @@ #include "ngram.h" +#include "trie.h" #include #include #include -#include "trie.h" - namespace ngram { Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h similarity index 99% rename from python/sglang/srt/speculative/cpp_ngram/ngram.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h index 377b481ae3fe..fb1461d9ac92 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram.h @@ -1,5 +1,9 @@ #pragma once +#include "param.h" +#include "queue.h" +#include "result.h" +#include "trie.h" #include #include #include @@ -8,11 +12,6 @@ #include #include -#include "param.h" -#include "queue.h" -#include "result.h" -#include "trie.h" - namespace ngram { class Ngram { diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h 
b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
new file mode 100644
index 000000000000..a40cd99b09f0
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
@@ -0,0 +1,111 @@
+#pragma once
+
+#include <tvm/ffi/container/tensor.h>
+
+#include "ngram.h"
+#include <atomic>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <unordered_map>
+
+namespace {
+
+static std::unordered_map<int64_t, std::unique_ptr<ngram::Ngram>> g_instances;
+static std::atomic<int64_t> g_next_id{0};
+static std::mutex g_map_mutex;
+
+inline ngram::Ngram& get_instance(int64_t handle) {
+  auto it = g_instances.find(handle);
+  if (it == g_instances.end()) {
+    throw std::runtime_error("Invalid ngram handle: " + std::to_string(handle));
+  }
+  return *it->second;
+}
+
+struct NgramCorpusFfi {
+  static int64_t create(
+      int64_t capacity,
+      int64_t max_trie_depth,
+      int64_t min_match_window_size,
+      int64_t max_match_window_size,
+      int64_t min_bfs_breadth,
+      int64_t max_bfs_breadth,
+      int64_t draft_token_num,
+      int64_t match_type) {
+    ngram::Param param;
+    param.enable = true;
+    param.enable_router_mode = false;
+    param.max_trie_depth = static_cast<size_t>(max_trie_depth);
+    param.min_match_window_size = static_cast<size_t>(min_match_window_size);
+    param.max_match_window_size = static_cast<size_t>(max_match_window_size);
+    param.min_bfs_breadth = static_cast<size_t>(min_bfs_breadth);
+    param.max_bfs_breadth = static_cast<size_t>(max_bfs_breadth);
+    param.draft_token_num = static_cast<size_t>(draft_token_num);
+    param.match_type = (match_type == 0) ? "BFS" : "PROB";
+
+    auto id = g_next_id.fetch_add(1);
+    auto instance = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
+
+    std::lock_guard<std::mutex> lock(g_map_mutex);
+    g_instances[id] = std::move(instance);
+    return id;
+  }
+
+  static void destroy(int64_t handle) {
+    std::lock_guard<std::mutex> lock(g_map_mutex);
+    g_instances.erase(handle);
+  }
+
+  // tokens_flat: 1D int32 CPU tensor (all sequences concatenated)
+  // offsets: 1D int64 CPU tensor of length batch_size+1 (CSR format)
+  static void async_insert(int64_t handle, const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) {
+    auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
+    auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
+    int64_t batch_size = offsets.size(0) - 1;
+
+    std::vector<std::vector<int32_t>> tokens(batch_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+      tokens[i].assign(data + offs[i], data + offs[i + 1]);
+    }
+    get_instance(handle).asyncInsert(std::move(tokens));
+  }
+
+  // tokens_flat: 1D int32 CPU tensor (all query sequences concatenated)
+  // offsets: 1D int64 CPU tensor of length batch_size+1 (CSR format)
+  // out_tokens: 1D int32 CPU tensor of length batch_size * draft_token_num
+  // out_mask: 1D uint8 CPU tensor of length batch_size * draft_token_num^2
+  static void batch_match(
+      int64_t handle,
+      const tvm::ffi::TensorView tokens_flat,
+      const tvm::ffi::TensorView offsets,
+      const tvm::ffi::TensorView out_tokens,
+      const tvm::ffi::TensorView out_mask) {
+    auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
+    auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
+    int64_t batch_size = offsets.size(0) - 1;
+
+    std::vector<std::vector<int32_t>> tokens(batch_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+      tokens[i].assign(data + offs[i], data + offs[i + 1]);
+    }
+
+    auto result = get_instance(handle).batchMatch(tokens);
+
+    auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr());
+    auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr());
+    std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t));
+    std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t));
+  }
+
+  static void synchronize(int64_t handle) {
get_instance(handle).synchronize(); + } + + static void reset(int64_t handle) { + get_instance(handle).reset(); + } +}; + +} // namespace diff --git a/python/sglang/srt/speculative/cpp_ngram/param.h b/python/sglang/jit_kernel/csrc/ngram_corpus/param.h similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/param.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/param.h diff --git a/python/sglang/srt/speculative/cpp_ngram/queue.h b/python/sglang/jit_kernel/csrc/ngram_corpus/queue.h similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/queue.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/queue.h diff --git a/python/sglang/srt/speculative/cpp_ngram/result.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/result.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.cpp diff --git a/python/sglang/srt/speculative/cpp_ngram/result.h b/python/sglang/jit_kernel/csrc/ngram_corpus/result.h similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/result.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/result.h diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp similarity index 100% rename from python/sglang/srt/speculative/cpp_ngram/trie.cpp rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h similarity index 99% rename from python/sglang/srt/speculative/cpp_ngram/trie.h rename to python/sglang/jit_kernel/csrc/ngram_corpus/trie.h index 41fd6e54ceb2..bd555597dd46 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h @@ -1,5 +1,7 @@ #pragma once +#include "param.h" +#include "result.h" #include #include #include @@ -10,9 +12,6 @@ #include #include -#include "param.h" -#include "result.h" - namespace ngram { struct TrieNode { diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py new file mode 100644 index 000000000000..00ceae70100c --- /dev/null +++ b/python/sglang/jit_kernel/ngram_corpus.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Tuple + +import numpy as np +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_ngram_corpus_module() -> Module: + return load_jit( + "ngram_corpus", + cpp_files=[ + "ngram_corpus/result.cpp", + "ngram_corpus/trie.cpp", + "ngram_corpus/ngram.cpp", + "ngram_corpus/ngram_corpus_ffi.h", + ], + cpp_wrappers=[ + ("ngram_create", "&NgramCorpusFfi::create"), + ("ngram_destroy", "&NgramCorpusFfi::destroy"), + ("ngram_async_insert", "&NgramCorpusFfi::async_insert"), + ("ngram_batch_match", "&NgramCorpusFfi::batch_match"), + ("ngram_synchronize", "&NgramCorpusFfi::synchronize"), + ("ngram_reset", "&NgramCorpusFfi::reset"), + ], + ) + + +_MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1} + + +def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]: + flat = [] + offsets = [0] + for seq in batch_tokens: + flat.extend(seq) + offsets.append(len(flat)) + tokens_flat = torch.tensor(flat, dtype=torch.int32) + offsets_t = torch.tensor(offsets, dtype=torch.int64) + return tokens_flat, offsets_t + + +def ngram_create( + capacity: int, + max_trie_depth: int, + 
min_match_window_size: int, + max_match_window_size: int, + min_bfs_breadth: int, + max_bfs_breadth: int, + draft_token_num: int, + match_type: str, +) -> int: + mt = _MATCH_TYPE_MAP.get(match_type) + if mt is None: + raise ValueError( + f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'." + ) + module = _jit_ngram_corpus_module() + return module.ngram_create( + capacity, + max_trie_depth, + min_match_window_size, + max_match_window_size, + min_bfs_breadth, + max_bfs_breadth, + draft_token_num, + mt, + ) + + +def ngram_destroy(handle: int) -> None: + _jit_ngram_corpus_module().ngram_destroy(handle) + + +def ngram_async_insert(handle: int, batch_tokens: List[List[int]]) -> None: + tokens_flat, offsets = _to_csr(batch_tokens) + _jit_ngram_corpus_module().ngram_async_insert(handle, tokens_flat, offsets) + + +def ngram_batch_match( + handle: int, + batch_tokens: List[List[int]], + draft_token_num: int, +) -> Tuple[np.ndarray, np.ndarray]: + tokens_flat, offsets = _to_csr(batch_tokens) + batch_size = len(batch_tokens) + + out_tokens = torch.zeros(batch_size * draft_token_num, dtype=torch.int32) + out_mask = torch.zeros( + batch_size * draft_token_num * draft_token_num, dtype=torch.uint8 + ) + + _jit_ngram_corpus_module().ngram_batch_match( + handle, tokens_flat, offsets, out_tokens, out_mask + ) + + return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype(np.int64) + + +def ngram_synchronize(handle: int) -> None: + _jit_ngram_corpus_module().ngram_synchronize(handle) + + +def ngram_reset(handle: int) -> None: + _jit_ngram_corpus_module().ngram_reset(handle) diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py index f35e9acf95fe..75ae7ef1a452 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py @@ -1,60 +1,63 @@ # -*- coding: utf-8 -*- import logging -import os from typing import List, Tuple import numpy as np -from torch.utils.cpp_extension import load -logger = logging.getLogger(__name__) - -_abs_path = os.path.dirname(os.path.abspath(__file__)) -ngram_corpus_cpp = load( - name="ngram_corpus_cpp", - sources=[ - f"{_abs_path}/ngram_corpus_binding.cpp", - f"{_abs_path}/ngram.cpp", - f"{_abs_path}/trie.cpp", - f"{_abs_path}/result.cpp", - ], - extra_cflags=["-O3", "-std=c++20"], +from sglang.jit_kernel.ngram_corpus import ( + ngram_async_insert, + ngram_batch_match, + ngram_create, + ngram_destroy, + ngram_reset, + ngram_synchronize, ) +logger = logging.getLogger(__name__) + class NgramCorpus: def __init__( self, max_trie_depth=18, + min_match_window_size=1, + max_match_window_size=10, min_bfs_breadth=1, max_bfs_breadth=8, draft_token_num=8, match_type="BFS", capacity=1000000, ): - param = ngram_corpus_cpp.Param() - param.max_trie_depth = max_trie_depth - param.min_bfs_breadth = min_bfs_breadth - param.max_bfs_breadth = max_bfs_breadth - param.draft_token_num = draft_token_num - param.match_type = match_type - self._ngram = ngram_corpus_cpp.Ngram(capacity, param) + self._handle = ngram_create( + capacity=capacity, + max_trie_depth=max_trie_depth, + min_match_window_size=min_match_window_size, + max_match_window_size=max_match_window_size, + min_bfs_breadth=min_bfs_breadth, + max_bfs_breadth=max_bfs_breadth, + draft_token_num=draft_token_num, + match_type=match_type, + ) self.default_mask = np.ones((1, 1), dtype=np.int64) self.draft_token_num = draft_token_num + def __del__(self): + if hasattr(self, "_handle"): + 
ngram_destroy(self._handle)
+
     def batch_put(self, batch_tokens: List[List[int]]):
-        self._ngram.asyncInsert(batch_tokens)
+        ngram_async_insert(self._handle, batch_tokens)
 
     def synchronize(self):
-        self._ngram.synchronize()
+        ngram_synchronize(self._handle)
 
     def reset(self):
-        self._ngram.reset()
+        ngram_reset(self._handle)
 
     def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
-        result = self._ngram.batchMatch(batch_tokens)
-        return np.array(result.token), np.array(result.mask)
+        return ngram_batch_match(self._handle, batch_tokens, self.draft_token_num)
 
     def leaf_paths_from_mask(
         self, tokens: List[int], tree_mask: List[List[int]]
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp
deleted file mode 100644
index e632dfb3de59..000000000000
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include "ngram.h"
-
-PYBIND11_MODULE(ngram_corpus_cpp, m) {
-  using namespace ngram;
-  namespace py = pybind11;
-  m.doc() = "";
-
-  py::class_<Ngram>(m, "Ngram")
-      .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
-      .def("asyncInsert", &Ngram::asyncInsert, "")
-      .def("batchMatch", &Ngram::batchMatch, "")
-      .def("reset", &Ngram::reset, "")
-      .def("synchronize", &Ngram::synchronize, "");
-
-  py::class_<Param>(m, "Param")
-      .def(py::init<>())
-      .def_readwrite("enable", &Param::enable)
-      .def_readwrite("enable_router_mode", &Param::enable_router_mode)
-      .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
-      .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
-      .def_readwrite("max_trie_depth", &Param::max_trie_depth)
-      .def_readwrite("draft_token_num", &Param::draft_token_num)
-      .def_readwrite("match_type", &Param::match_type)
-      .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
-      .def("get_draft_token_num", &Param::get_draft_token_num, "")
-      .def("parse", &Param::parse, "")
-      .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
-      .def("detail", &Param::detail, "");
-
-  py::class_<Result>(m, "Result")
-      .def(py::init<>())
-      .def_readwrite("token", &Result::token)
-      .def_readwrite("mask", &Result::mask)
-      .def("truncate", &Result::truncate);
-}

From d70b76673ff43b5fd195615262e75f3e357101d4 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 2 Apr 2026 00:14:25 -0700
Subject: [PATCH 2/6] add mutex to get_instance and bounds check to memcpy

---
 .../jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
index a40cd99b09f0..57c8350db8cb 100644
--- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
@@ -17,6 +17,7 @@ static std::atomic<int64_t> g_next_id{0};
 static std::mutex g_map_mutex;
 
 inline ngram::Ngram& get_instance(int64_t handle) {
+  std::lock_guard<std::mutex> lock(g_map_mutex);
   auto it = g_instances.find(handle);
   if (it == g_instances.end()) {
     throw std::runtime_error("Invalid ngram handle: " + std::to_string(handle));
@@ -95,6 +96,16 @@ struct NgramCorpusFfi {
 
     auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr());
     auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr());
+    if (result.token.size() > static_cast<size_t>(out_tokens.size(0))) {
+      throw std::runtime_error(
+          "out_tokens buffer too small: " + std::to_string(out_tokens.size(0)) + " < " +
+          std::to_string(result.token.size()));
+    }
+    if (result.mask.size() > static_cast<size_t>(out_mask.size(0))) {
+      throw std::runtime_error(
+          "out_mask buffer too small: " + std::to_string(out_mask.size(0)) + " < " +
+          std::to_string(result.mask.size()));
+    }
     std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t));
     std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t));
   }

From 6e58f4375839c2e7dca1852c9bda5d48f902e128 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 2 Apr 2026 00:17:36 -0700
Subject: [PATCH 3/6] fix: remove false thread safety in get_instance, use
 output tensor for create handle

---
 .../jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h | 8 ++++----
 python/sglang/jit_kernel/ngram_corpus.py            | 6 ++++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
index 57c8350db8cb..f34f0d6551b0 100644
--- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
@@ -17,7 +17,6 @@ static std::atomic<int64_t> g_next_id{0};
 static std::mutex g_map_mutex;
 
 inline ngram::Ngram& get_instance(int64_t handle) {
-  std::lock_guard<std::mutex> lock(g_map_mutex);
   auto it = g_instances.find(handle);
   if (it == g_instances.end()) {
     throw std::runtime_error("Invalid ngram handle: " + std::to_string(handle));
@@ -26,7 +25,7 @@ inline ngram::Ngram& get_instance(int64_t handle) {
 }
 
 struct NgramCorpusFfi {
-  static int64_t create(
+  static void create(
       int64_t capacity,
       int64_t max_trie_depth,
       int64_t min_match_window_size,
@@ -34,7 +33,8 @@ struct NgramCorpusFfi {
       int64_t min_bfs_breadth,
       int64_t max_bfs_breadth,
       int64_t draft_token_num,
-      int64_t match_type) {
+      int64_t match_type,
+      const tvm::ffi::TensorView out_handle) {
     ngram::Param param;
     param.enable = true;
     param.enable_router_mode = false;
@@ -51,7 +51,7 @@ struct NgramCorpusFfi {
 
     std::lock_guard<std::mutex> lock(g_map_mutex);
     g_instances[id] = std::move(instance);
-    return id;
+    *static_cast<int64_t*>(out_handle.data_ptr()) = id;
   }
 
   static void destroy(int64_t handle) {
diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py
index 00ceae70100c..23f039ea2b2e 100644
--- a/python/sglang/jit_kernel/ngram_corpus.py
+++ b/python/sglang/jit_kernel/ngram_corpus.py
@@ -61,8 +61,8 @@ def ngram_create(
         raise ValueError(
             f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'."
         )
-    module = _jit_ngram_corpus_module()
-    return module.ngram_create(
+    out_handle = torch.zeros(1, dtype=torch.int64)
+    _jit_ngram_corpus_module().ngram_create(
         capacity,
         max_trie_depth,
         min_match_window_size,
@@ -71,7 +71,9 @@ def ngram_create(
         max_bfs_breadth,
         draft_token_num,
         mt,
+        out_handle,
     )
+    return out_handle.item()

From 61f239b3ff2b5d9bdd7bc4bbbd8c8fcaf57b39d3 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 2 Apr 2026 00:27:43 -0700
Subject: [PATCH 4/6] fix: remove min/max_match_window_size not present in
 main's Param struct

---
 python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h | 4 ----
 python/sglang/jit_kernel/ngram_corpus.py                      | 4 ----
 python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py       | 4 ----
 3 files changed, 12 deletions(-)

diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
index f34f0d6551b0..49683b60d5c3 100644
--- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
@@ -28,8 +28,6 @@ struct NgramCorpusFfi {
   static void create(
       int64_t capacity,
       int64_t max_trie_depth,
-      int64_t min_match_window_size,
-      int64_t max_match_window_size,
       int64_t min_bfs_breadth,
       int64_t max_bfs_breadth,
       int64_t draft_token_num,
@@ -39,8 +37,6 @@ struct NgramCorpusFfi {
     param.enable = true;
     param.enable_router_mode = false;
     param.max_trie_depth = static_cast<size_t>(max_trie_depth);
-    param.min_match_window_size = static_cast<size_t>(min_match_window_size);
-    param.max_match_window_size = static_cast<size_t>(max_match_window_size);
     param.min_bfs_breadth = static_cast<size_t>(min_bfs_breadth);
     param.max_bfs_breadth = static_cast<size_t>(max_bfs_breadth);
     param.draft_token_num = static_cast<size_t>(draft_token_num);
diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py
index 23f039ea2b2e..f8e8658464cf 100644
--- a/python/sglang/jit_kernel/ngram_corpus.py
+++ b/python/sglang/jit_kernel/ngram_corpus.py
@@ -49,8 +49,6 @@ def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
 def ngram_create(
     capacity: int,
     max_trie_depth: int,
-    min_match_window_size: int,
-    max_match_window_size: int,
     min_bfs_breadth: int,
     max_bfs_breadth: int,
     draft_token_num: int,
@@ -65,8 +63,6 @@ def ngram_create(
     _jit_ngram_corpus_module().ngram_create(
         capacity,
         max_trie_depth,
-        min_match_window_size,
-        max_match_window_size,
         min_bfs_breadth,
         max_bfs_breadth,
         draft_token_num,
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
index 75ae7ef1a452..eb16fb6cac75 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
+++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
@@ -21,8 +21,6 @@ class NgramCorpus:
     def __init__(
         self,
         max_trie_depth=18,
-        min_match_window_size=1,
-        max_match_window_size=10,
        min_bfs_breadth=1,
         max_bfs_breadth=8,
         draft_token_num=8,
@@ -32,8 +30,6 @@ def __init__(
         self._handle = ngram_create(
             capacity=capacity,
             max_trie_depth=max_trie_depth,
-            min_match_window_size=min_match_window_size,
-            max_match_window_size=max_match_window_size,
             min_bfs_breadth=min_bfs_breadth,
             max_bfs_breadth=max_bfs_breadth,
             draft_token_num=draft_token_num,

From fa99c5da0f5cc407f3e5d8c68193fb0761885a90 Mon Sep 17 00:00:00 2001
From: hnyls2002
Date: Thu, 2 Apr 2026 00:59:32 -0700
Subject: [PATCH 5/6] fix TestTruncate: use batch_get API instead of internal
 _ngram binding

---
 .../spec/utils/test_ngram_corpus.py | 30 +++++++++----------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py
index 93b2d77b5ac9..e8d9fc026beb 100644
--- a/test/registered/spec/utils/test_ngram_corpus.py
+++ b/test/registered/spec/utils/test_ngram_corpus.py
@@ -480,41 +480,39 @@ def test_full_budget_used(self):
 
 
 class TestTruncate(CustomTestCase):
-    """Verify the Result.truncate method via the Python binding."""
+    """Verify truncation logic on batch_get output."""
 
     def test_truncate_reduces_output(self):
         corpus = _make_corpus("BFS", draft_token_num=8)
         corpus.batch_put(SEED_SEQUENCES)
         corpus.synchronize()
 
-        result = corpus._ngram.batchMatch([[1, 2, 3]])
-        original_len = len(result.token)
-        self.assertEqual(original_len, 8)
+        ids, masks = corpus.batch_get([[1, 2, 3]])
+        ids = ids.reshape(8)
+        self.assertEqual(len(ids), 8)
 
-        result.truncate(4)
-        self.assertEqual(len(result.token), 4)
-        self.assertEqual(len(result.mask), 4 * 4)
+        # Simulate truncate to 4
+        trunc_n = 4
+        trunc_ids = ids[:trunc_n]
+        self.assertEqual(len(trunc_ids), trunc_n)
 
     def test_truncate_preserves_mask_structure(self):
         corpus = _make_corpus("BFS", draft_token_num=8)
         corpus.batch_put(SEED_SEQUENCES)
         corpus.synchronize()
 
-        result = corpus._ngram.batchMatch([[1, 2, 3]])
-        full_ids = list(result.token)
-        full_mask = list(result.mask)
-        n = len(full_ids)
+        ids, masks = corpus.batch_get([[1, 2, 3]])
+        n = 8
+        full_mask = masks.reshape(n, n)
 
-        result_copy = corpus._ngram.batchMatch([[1, 2, 3]])
         trunc_n = 4
-        result_copy.truncate(trunc_n)
-        trunc_mask = list(result_copy.mask)
+        trunc_mask = full_mask[:trunc_n, :trunc_n]
 
         for i in range(trunc_n):
             for j in range(trunc_n):
                 self.assertEqual(
-                    trunc_mask[i * trunc_n + j],
-                    full_mask[i * n + j],
+                    trunc_mask[i, j],
+                    full_mask[i, j],
                     f"Mask mismatch at ({i},{j})",
                 )

From 7737efe6dba5dddb52c4892d3cf18f90899a47a2 Mon Sep 17 00:00:00 2001
From: DarkSharpness <2040703891@qq.com>
Date: Thu, 2 Apr 2026 16:26:40 +0800
Subject: [PATCH 6/6] misc: use tvm-ffi object

---
 ...gram_corpus_ffi.h => ngram_corpus_ffi.cpp} |  82 ++++------
 python/sglang/jit_kernel/ngram_corpus.py      | 151 ++++++++----------
 python/sglang/jit_kernel/utils.py             |  66 +++++---
 .../srt/speculative/cpp_ngram/ngram_corpus.py |  27 +---
 4 files changed, 149 insertions(+), 177 deletions(-)
 rename python/sglang/jit_kernel/csrc/ngram_corpus/{ngram_corpus_ffi.h => ngram_corpus_ffi.cpp} (56%)

diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
similarity index 56%
rename from python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
rename to python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
index 49683b60d5c3..e1797e1fc0f3 100644
--- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.h
+++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
@@ -1,38 +1,29 @@
 #pragma once
 
 #include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/reflection/registry.h>
+
+#include <string>
 
 #include "ngram.h"
-#include <atomic>
+#include <vector>
 #include <cstring>
 #include <memory>
-#include <mutex>
 #include <stdexcept>
-#include <unordered_map>
-
-namespace {
-
-static std::unordered_map<int64_t, std::unique_ptr<ngram::Ngram>> g_instances;
-static std::atomic<int64_t> g_next_id{0};
-static std::mutex g_map_mutex;
+struct NgramCorpusObj : public tvm::ffi::Object {
+ public:
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object);
+  static constexpr bool _type_mutable = true;
 
-inline ngram::Ngram& get_instance(int64_t handle) {
-  auto it = g_instances.find(handle);
-  if (it == g_instances.end()) {
-    throw std::runtime_error("Invalid ngram handle: " + std::to_string(handle));
-  }
-  return *it->second;
-}
-
-struct NgramCorpusFfi {
-  static void create(
+  NgramCorpusObj(
       int64_t capacity,
       int64_t max_trie_depth,
       int64_t min_bfs_breadth,
       int64_t max_bfs_breadth,
       int64_t draft_token_num,
-      int64_t match_type,
-      const tvm::ffi::TensorView out_handle) {
+      int64_t match_type) {
     ngram::Param param;
     param.enable = true;
     param.enable_router_mode = false;
@@ -41,23 +32,10 @@ struct NgramCorpusFfi {
     param.max_bfs_breadth = static_cast<size_t>(max_bfs_breadth);
     param.draft_token_num = static_cast<size_t>(draft_token_num);
     param.match_type = (match_type == 0) ? "BFS" : "PROB";
-
-    auto id = g_next_id.fetch_add(1);
-    auto instance = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
-
-    std::lock_guard<std::mutex> lock(g_map_mutex);
-    g_instances[id] = std::move(instance);
-    *static_cast<int64_t*>(out_handle.data_ptr()) = id;
-  }
-
-  static void destroy(int64_t handle) {
-    std::lock_guard<std::mutex> lock(g_map_mutex);
-    g_instances.erase(handle);
+    ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
   }
 
-  // tokens_flat: 1D int32 CPU tensor (all sequences concatenated)
-  // offsets: 1D int64 CPU tensor of length batch_size+1 (CSR format)
-  static void async_insert(int64_t handle, const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) {
+  void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) {
    auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
     auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
     int64_t batch_size = offsets.size(0) - 1;
@@ -66,15 +44,10 @@ struct NgramCorpusFfi {
     for (int64_t i = 0; i < batch_size; ++i) {
       tokens[i].assign(data + offs[i], data + offs[i + 1]);
     }
-    get_instance(handle).asyncInsert(std::move(tokens));
+    ngram_->asyncInsert(std::move(tokens));
   }
 
-  // tokens_flat: 1D int32 CPU tensor (all query sequences concatenated)
-  // offsets: 1D int64 CPU tensor of length batch_size+1 (CSR format)
-  // out_tokens: 1D int32 CPU tensor of length batch_size * draft_token_num
-  // out_mask: 1D uint8 CPU tensor of length batch_size * draft_token_num^2
-  static void batch_match(
-      int64_t handle,
+  void batch_match(
       const tvm::ffi::TensorView tokens_flat,
       const tvm::ffi::TensorView offsets,
       const tvm::ffi::TensorView out_tokens,
@@ -88,7 +61,7 @@ struct NgramCorpusFfi {
       tokens[i].assign(data + offs[i], data + offs[i + 1]);
     }
 
-    auto result = get_instance(handle).batchMatch(tokens);
+    auto result = ngram_->batchMatch(tokens);
 
     auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr());
     auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr());
@@ -106,13 +79,26 @@ struct NgramCorpusFfi {
     std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t));
   }
 
-  static void synchronize(int64_t handle) {
-    get_instance(handle).synchronize();
+  void synchronize() {
+    ngram_->synchronize();
   }
 
-  static void reset(int64_t handle) {
-    get_instance(handle).reset();
+  void reset() {
+    ngram_->reset();
   }
+
+ private:
+  std::unique_ptr<ngram::Ngram> ngram_;
 };
 
-} // namespace
+void register_ngram_corpus() {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<NgramCorpusObj>()
+      .def(refl::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>(), "__init__")
+      .def("async_insert", &NgramCorpusObj::async_insert)
+      .def("batch_match", &NgramCorpusObj::batch_match)
+      .def("synchronize", &NgramCorpusObj::synchronize)
+      .def("reset", &NgramCorpusObj::reset);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(register_once, register_ngram_corpus);
diff --git a/python/sglang/jit_kernel/ngram_corpus.py b/python/sglang/jit_kernel/ngram_corpus.py
index
f8e8658464cf..42b6babab0dd 100644 --- a/python/sglang/jit_kernel/ngram_corpus.py +++ b/python/sglang/jit_kernel/ngram_corpus.py @@ -1,37 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Tuple +from typing import List, Tuple import numpy as np import torch +import tvm_ffi from sglang.jit_kernel.utils import cache_once, load_jit -if TYPE_CHECKING: - from tvm_ffi.module import Module - - -@cache_once -def _jit_ngram_corpus_module() -> Module: - return load_jit( - "ngram_corpus", - cpp_files=[ - "ngram_corpus/result.cpp", - "ngram_corpus/trie.cpp", - "ngram_corpus/ngram.cpp", - "ngram_corpus/ngram_corpus_ffi.h", - ], - cpp_wrappers=[ - ("ngram_create", "&NgramCorpusFfi::create"), - ("ngram_destroy", "&NgramCorpusFfi::destroy"), - ("ngram_async_insert", "&NgramCorpusFfi::async_insert"), - ("ngram_batch_match", "&NgramCorpusFfi::batch_match"), - ("ngram_synchronize", "&NgramCorpusFfi::synchronize"), - ("ngram_reset", "&NgramCorpusFfi::reset"), - ], - ) - - _MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1} @@ -46,64 +22,67 @@ def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]: return tokens_flat, offsets_t -def ngram_create( - capacity: int, - max_trie_depth: int, - min_bfs_breadth: int, - max_bfs_breadth: int, - draft_token_num: int, - match_type: str, -) -> int: - mt = _MATCH_TYPE_MAP.get(match_type) - if mt is None: - raise ValueError( - f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'." - ) - out_handle = torch.zeros(1, dtype=torch.int64) - _jit_ngram_corpus_module().ngram_create( - capacity, - max_trie_depth, - min_bfs_breadth, - max_bfs_breadth, - draft_token_num, - mt, - out_handle, - ) - return out_handle.item() - - -def ngram_destroy(handle: int) -> None: - _jit_ngram_corpus_module().ngram_destroy(handle) - - -def ngram_async_insert(handle: int, batch_tokens: List[List[int]]) -> None: - tokens_flat, offsets = _to_csr(batch_tokens) - _jit_ngram_corpus_module().ngram_async_insert(handle, tokens_flat, offsets) - - -def ngram_batch_match( - handle: int, - batch_tokens: List[List[int]], - draft_token_num: int, -) -> Tuple[np.ndarray, np.ndarray]: - tokens_flat, offsets = _to_csr(batch_tokens) - batch_size = len(batch_tokens) - - out_tokens = torch.zeros(batch_size * draft_token_num, dtype=torch.int32) - out_mask = torch.zeros( - batch_size * draft_token_num * draft_token_num, dtype=torch.uint8 - ) - - _jit_ngram_corpus_module().ngram_batch_match( - handle, tokens_flat, offsets, out_tokens, out_mask +@cache_once +def get_ngram_corpus_cls(): + module = load_jit( + "ngram_corpus", + cpp_files=[ + "ngram_corpus/result.cpp", + "ngram_corpus/trie.cpp", + "ngram_corpus/ngram.cpp", + "ngram_corpus/ngram_corpus_ffi.cpp", + ], + header_only=False, ) - - return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype(np.int64) - - -def ngram_synchronize(handle: int) -> None: - _jit_ngram_corpus_module().ngram_synchronize(handle) - - -def ngram_reset(handle: int) -> None: - _jit_ngram_corpus_module().ngram_reset(handle) + module.register_once() + + @tvm_ffi.register_object("sgl.NgramCorpus") + class NgramCorpusFFI(tvm_ffi.Object): + __slots__ = ("__dict__",) + + def __init__( + self, + capacity: int, + max_trie_depth: int, + min_bfs_breadth: int, + max_bfs_breadth: int, + draft_token_num: int, + match_type: str, + ) -> None: + mt = _MATCH_TYPE_MAP.get(match_type) + if mt is None: + raise ValueError( + f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'." 
+ ) + self.__ffi_init__( + capacity, + max_trie_depth, + min_bfs_breadth, + max_bfs_breadth, + draft_token_num, + mt, + ) + self._draft_token_num = draft_token_num + + def insert(self, batch_tokens: List[List[int]]) -> None: + tokens_flat, offsets = _to_csr(batch_tokens) + self.async_insert(tokens_flat, offsets) # type: ignore + + def match( + self, + batch_tokens: List[List[int]], + ) -> Tuple[np.ndarray, np.ndarray]: + tokens_flat, offsets = _to_csr(batch_tokens) + batch_size = len(batch_tokens) + d = self._draft_token_num + + out_tokens = torch.zeros(batch_size * d, dtype=torch.int32) + out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8) + + self.batch_match(tokens_flat, offsets, out_tokens, out_mask) # type: ignore + + return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype( + np.int64 + ) + + return NgramCorpusFFI diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index 63a0ba99d041..ec3ebd5abc80 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -140,6 +140,7 @@ def load_jit( extra_include_paths: List[str] | None = None, extra_dependencies: List[str] | None = None, build_directory: str | None = None, + header_only: bool = True, ) -> Module: """ Loading a JIT module from C++/CUDA source files. @@ -169,47 +170,64 @@ def load_jit( :type extra_dependencies: List[str] | None :param build_directory: The build directory for JIT compilation. :type build_directory: str | None + :param header_only: Whether the module is header-only. + If true, apply the wrappers to export given class/functions. + Otherwise, we must export from C++/CUDA side. :return: A just-in-time(JIT) compiled module. :rtype: Module """ - from tvm_ffi.cpp import load_inline + from tvm_ffi.cpp import load, load_inline cpp_files = cpp_files or [] cuda_files = cuda_files or [] - cpp_wrappers = cpp_wrappers or [] - cuda_wrappers = cuda_wrappers or [] extra_cflags = extra_cflags or [] extra_cuda_cflags = extra_cuda_cflags or [] extra_ldflags = extra_ldflags or [] extra_include_paths = extra_include_paths or [] + cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files] + cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files] + for dep in set(extra_dependencies or []): if dep not in _REGISTERED_DEPENDENCIES: raise ValueError(f"Dependency {dep} is not registered.") extra_include_paths += _REGISTERED_DEPENDENCIES[dep]() - # include cpp files - cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files] - cpp_sources = [f'#include "{path}"' for path in cpp_paths] - cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] - - # include cuda files - cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files] - cuda_sources = [f'#include "{path}"' for path in cuda_paths] - cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] - - with _jit_compile_context(): - return load_inline( - "sgl_kernel_jit_" + "_".join(str(arg) for arg in args), - cpp_sources=cpp_sources, - cuda_sources=cuda_sources, - extra_cflags=DEFAULT_CFLAGS + extra_cflags, - extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, - extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, - extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, - build_directory=build_directory, - ) + module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args) + if header_only: + cpp_wrappers = cpp_wrappers or [] + cuda_wrappers = cuda_wrappers or [] + cpp_sources = [f'#include "{path}"' for path in cpp_files] + 
cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] + + # include cuda files + cuda_sources = [f'#include "{path}"' for path in cuda_files] + cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] + with _jit_compile_context(): + return load_inline( + module_name, + cpp_sources=cpp_sources, + cuda_sources=cuda_sources, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + else: + assert cpp_wrappers is None and cuda_wrappers is None + with _jit_compile_context(): + return load( + module_name, + cpp_files=cpp_files, + cuda_files=cuda_files, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) @dataclass diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py index eb16fb6cac75..0eb6bd71cbf7 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py @@ -5,14 +5,7 @@ import numpy as np -from sglang.jit_kernel.ngram_corpus import ( - ngram_async_insert, - ngram_batch_match, - ngram_create, - ngram_destroy, - ngram_reset, - ngram_synchronize, -) +from sglang.jit_kernel.ngram_corpus import get_ngram_corpus_cls logger = logging.getLogger(__name__) @@ -26,8 +19,9 @@ def __init__( draft_token_num=8, match_type="BFS", capacity=1000000, - ): - self._handle = ngram_create( + ) -> None: + cls = get_ngram_corpus_cls() + self._obj = cls( capacity=capacity, max_trie_depth=max_trie_depth, min_bfs_breadth=min_bfs_breadth, @@ -35,25 +29,20 @@ def __init__( draft_token_num=draft_token_num, match_type=match_type, ) - self.default_mask = np.ones((1, 1), dtype=np.int64) self.draft_token_num = draft_token_num - def __del__(self): - if hasattr(self, "_handle"): - ngram_destroy(self._handle) - def batch_put(self, batch_tokens: List[List[int]]): - ngram_async_insert(self._handle, batch_tokens) + self._obj.insert(batch_tokens) def synchronize(self): - ngram_synchronize(self._handle) + self._obj.synchronize() # type: ignore def reset(self): - ngram_reset(self._handle) + self._obj.reset() # type: ignore def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]: - return ngram_batch_match(self._handle, batch_tokens, self.draft_token_num) + return self._obj.match(batch_tokens) def leaf_paths_from_mask( self, tokens: List[int], tree_mask: List[List[int]]
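
Usage sketch (illustrative only, not part of the patch series): a minimal round trip through the public NgramCorpus wrapper as it stands after PATCH 6/6. The shapes follow the batch_match buffer contract above (batch_size * draft_token_num tokens, batch_size * draft_token_num^2 mask entries); the blocking behavior of synchronize() is inferred from the asyncInsert naming and the usage in test_ngram_corpus.py.

    from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus

    # draft_token_num=8 means batch_get returns 8 draft tokens plus a
    # flattened 8x8 tree mask per request.
    corpus = NgramCorpus(draft_token_num=8, match_type="BFS", capacity=1000000)

    # Inserts run asynchronously on the C++ side (asyncInsert);
    # synchronize() waits until they are visible to matching.
    corpus.batch_put([[1, 2, 3, 4, 5], [1, 2, 3, 6, 7]])
    corpus.synchronize()

    ids, masks = corpus.batch_get([[1, 2, 3]])  # shapes: (8,) and (64,), int64
    tree_mask = masks.reshape(8, 8)

Both batch_put and batch_get pack the ragged token batch into a (tokens_flat, offsets) CSR pair via _to_csr() before crossing the FFI boundary, so only two flat tensors ever cross per call; the series as a whole moves from an int64 opaque-handle registry guarded by a global mutex (PATCH 1/6-4/6) to a reflected tvm::ffi::Object whose lifetime is owned by the Python wrapper (PATCH 6/6), eliminating the global map and the explicit destroy call.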