Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#include "ngram.h"

#include "trie.h"
#include <limits>
#include <stdexcept>
#include <string>

#include "trie.h"

namespace ngram {

Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once

#include "param.h"
#include "queue.h"
#include "result.h"
#include "trie.h"
#include <condition_variable>
#include <cstddef>
#include <cstdint>
Expand All @@ -8,11 +12,6 @@
#include <thread>
#include <vector>

#include "param.h"
#include "queue.h"
#include "result.h"
#include "trie.h"

namespace ngram {

class Ngram {
Expand Down
104 changes: 104 additions & 0 deletions python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#pragma once

#include <sgl_kernel/ffi.h>
#include <sgl_kernel/tensor.h>

#include <tvm/ffi/reflection/registry.h>

#include "ngram.h"
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>

struct NgramCorpusObj : public tvm::ffi::Object {
public:
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object);
static constexpr bool _type_mutable = true;

NgramCorpusObj(
int64_t capacity,
int64_t max_trie_depth,
int64_t min_bfs_breadth,
int64_t max_bfs_breadth,
int64_t draft_token_num,
int64_t match_type) {
ngram::Param param;
param.enable = true;
param.enable_router_mode = false;
param.max_trie_depth = static_cast<size_t>(max_trie_depth);
param.min_bfs_breadth = static_cast<size_t>(min_bfs_breadth);
param.max_bfs_breadth = static_cast<size_t>(max_bfs_breadth);
param.draft_token_num = static_cast<size_t>(draft_token_num);
param.match_type = (match_type == 0) ? "BFS" : "PROB";
ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
}

void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) {
auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
int64_t batch_size = offsets.size(0) - 1;

std::vector<std::vector<int32_t>> tokens(batch_size);
for (int64_t i = 0; i < batch_size; ++i) {
tokens[i].assign(data + offs[i], data + offs[i + 1]);
}
ngram_->asyncInsert(std::move(tokens));
}

void batch_match(
const tvm::ffi::TensorView tokens_flat,
const tvm::ffi::TensorView offsets,
const tvm::ffi::TensorView out_tokens,
const tvm::ffi::TensorView out_mask) {
auto* data = static_cast<const int32_t*>(tokens_flat.data_ptr());
auto* offs = static_cast<const int64_t*>(offsets.data_ptr());
int64_t batch_size = offsets.size(0) - 1;

std::vector<std::vector<int32_t>> tokens(batch_size);
for (int64_t i = 0; i < batch_size; ++i) {
tokens[i].assign(data + offs[i], data + offs[i + 1]);
}

auto result = ngram_->batchMatch(tokens);

auto* out_tok = static_cast<int32_t*>(out_tokens.data_ptr());
auto* out_msk = static_cast<uint8_t*>(out_mask.data_ptr());
if (result.token.size() > static_cast<size_t>(out_tokens.size(0))) {
throw std::runtime_error(
"out_tokens buffer too small: " + std::to_string(out_tokens.size(0)) + " < " +
std::to_string(result.token.size()));
}
if (result.mask.size() > static_cast<size_t>(out_mask.size(0))) {
throw std::runtime_error(
"out_mask buffer too small: " + std::to_string(out_mask.size(0)) + " < " +
std::to_string(result.mask.size()));
}
std::memcpy(out_tok, result.token.data(), result.token.size() * sizeof(int32_t));
std::memcpy(out_msk, result.mask.data(), result.mask.size() * sizeof(uint8_t));
}

void synchronize() {
ngram_->synchronize();
}

void reset() {
ngram_->reset();
}

private:
std::unique_ptr<ngram::Ngram> ngram_;
};

// Registers the NgramCorpus object type and its methods with the TVM FFI
// reflection registry, making them constructible/callable from Python under
// the type key "sgl.NgramCorpus".
void register_ngram_corpus() {
namespace refl = tvm::ffi::reflection;
// Constructor takes the six int64 parameters mirrored by NgramCorpusObj's ctor
// (capacity, max_trie_depth, min_bfs_breadth, max_bfs_breadth,
// draft_token_num, match_type).
refl::ObjectDef<NgramCorpusObj>()
.def(refl::init<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>(), "__init__")
.def("async_insert", &NgramCorpusObj::async_insert)
.def("batch_match", &NgramCorpusObj::batch_match)
.def("synchronize", &NgramCorpusObj::synchronize)
.def("reset", &NgramCorpusObj::reset);
}

// Exports the registration entry point to the loader as `register_once`.
// NOTE(review): idempotence of repeated calls depends on the FFI registry's
// behavior on re-registration — confirm the Python side calls it only once.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(register_once, register_ngram_corpus);
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "param.h"
#include "result.h"
#include <cstddef>
#include <cstdint>
#include <functional>
Expand All @@ -10,9 +12,6 @@
#include <unordered_map>
#include <vector>

#include "param.h"
#include "result.h"

namespace ngram {

struct TrieNode {
Expand Down
88 changes: 88 additions & 0 deletions python/sglang/jit_kernel/ngram_corpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import List, Tuple

import numpy as np
import torch
import tvm_ffi

from sglang.jit_kernel.utils import cache_once, load_jit

_MATCH_TYPE_MAP = {"BFS": 0, "PROB": 1}


def _to_csr(batch_tokens: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
flat = []
offsets = [0]
for seq in batch_tokens:
flat.extend(seq)
offsets.append(len(flat))
tokens_flat = torch.tensor(flat, dtype=torch.int32)
offsets_t = torch.tensor(offsets, dtype=torch.int64)
return tokens_flat, offsets_t


@cache_once
def get_ngram_corpus_cls():
    """Compile the n-gram corpus JIT module (once) and return its wrapper class.

    Builds the C++ sources, calls the module's `register_once` entry point to
    register the "sgl.NgramCorpus" FFI type, and returns a Python wrapper
    class exposing list-based `insert`/`match` convenience methods.
    """
    module = load_jit(
        "ngram_corpus",
        cpp_files=[
            "ngram_corpus/result.cpp",
            "ngram_corpus/trie.cpp",
            "ngram_corpus/ngram.cpp",
            "ngram_corpus/ngram_corpus_ffi.cpp",
        ],
        # Sources export their own FFI symbols; no header-only wrappers.
        header_only=False,
    )
    module.register_once()

    @tvm_ffi.register_object("sgl.NgramCorpus")
    class NgramCorpusFFI(tvm_ffi.Object):
        # A __dict__ slot lets instances carry extra Python attributes
        # (e.g. _draft_token_num) on top of the FFI-managed object.
        __slots__ = ("__dict__",)

        def __init__(
            self,
            capacity: int,
            max_trie_depth: int,
            min_bfs_breadth: int,
            max_bfs_breadth: int,
            draft_token_num: int,
            match_type: str,
        ) -> None:
            # Translate the string match type to the integer code the C++
            # constructor expects (0 = BFS, 1 = PROB).
            mt = _MATCH_TYPE_MAP.get(match_type)
            if mt is None:
                raise ValueError(
                    f"Unknown match_type: '{match_type}'. Must be 'BFS' or 'PROB'."
                )
            self.__ffi_init__(
                capacity,
                max_trie_depth,
                min_bfs_breadth,
                max_bfs_breadth,
                draft_token_num,
                mt,
            )
            # Kept Python-side to size the match() output buffers.
            self._draft_token_num = draft_token_num

        def insert(self, batch_tokens: List[List[int]]) -> None:
            """Queue token sequences for asynchronous insertion into the corpus."""
            tokens_flat, offsets = _to_csr(batch_tokens)
            self.async_insert(tokens_flat, offsets)  # type: ignore

        def match(
            self,
            batch_tokens: List[List[int]],
        ) -> Tuple[np.ndarray, np.ndarray]:
            """Match each sequence against the corpus.

            Returns a pair of flat int64 arrays: draft tokens of length
            batch_size * draft_token_num, and a mask of length
            batch_size * draft_token_num**2.
            NOTE(review): the mask appears to be a per-sequence d x d matrix
            flattened row-major — confirm against the C++ batchMatch layout.
            """
            tokens_flat, offsets = _to_csr(batch_tokens)
            batch_size = len(batch_tokens)
            d = self._draft_token_num

            # Pre-zeroed buffers; the C++ side writes up to these capacities
            # and raises if the results would not fit.
            out_tokens = torch.zeros(batch_size * d, dtype=torch.int32)
            out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8)

            self.batch_match(tokens_flat, offsets, out_tokens, out_mask)  # type: ignore

            return out_tokens.numpy().astype(np.int64), out_mask.numpy().astype(
                np.int64
            )

    return NgramCorpusFFI
66 changes: 42 additions & 24 deletions python/sglang/jit_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def load_jit(
extra_include_paths: List[str] | None = None,
extra_dependencies: List[str] | None = None,
build_directory: str | None = None,
header_only: bool = True,
) -> Module:
"""
Loading a JIT module from C++/CUDA source files.
Expand Down Expand Up @@ -169,47 +170,64 @@ def load_jit(
:type extra_dependencies: List[str] | None
:param build_directory: The build directory for JIT compilation.
:type build_directory: str | None
:param header_only: Whether the module is header-only.
If true, apply the wrappers to export given class/functions.
Otherwise, we must export from C++/CUDA side.
:return: A just-in-time(JIT) compiled module.
:rtype: Module
"""

from tvm_ffi.cpp import load_inline
from tvm_ffi.cpp import load, load_inline

cpp_files = cpp_files or []
cuda_files = cuda_files or []
cpp_wrappers = cpp_wrappers or []
cuda_wrappers = cuda_wrappers or []
extra_cflags = extra_cflags or []
extra_cuda_cflags = extra_cuda_cflags or []
extra_ldflags = extra_ldflags or []
extra_include_paths = extra_include_paths or []

cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files]
cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files]

for dep in set(extra_dependencies or []):
if dep not in _REGISTERED_DEPENDENCIES:
raise ValueError(f"Dependency {dep} is not registered.")
extra_include_paths += _REGISTERED_DEPENDENCIES[dep]()

# include cpp files
cpp_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cpp_files]
cpp_sources = [f'#include "{path}"' for path in cpp_paths]
cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]

# include cuda files
cuda_paths = [(KERNEL_PATH / "csrc" / f).resolve() for f in cuda_files]
cuda_sources = [f'#include "{path}"' for path in cuda_paths]
cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]

with _jit_compile_context():
return load_inline(
"sgl_kernel_jit_" + "_".join(str(arg) for arg in args),
cpp_sources=cpp_sources,
cuda_sources=cuda_sources,
extra_cflags=DEFAULT_CFLAGS + extra_cflags,
extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
build_directory=build_directory,
)
module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args)
if header_only:
cpp_wrappers = cpp_wrappers or []
cuda_wrappers = cuda_wrappers or []
cpp_sources = [f'#include "{path}"' for path in cpp_files]
cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers]

# include cuda files
cuda_sources = [f'#include "{path}"' for path in cuda_files]
cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers]
with _jit_compile_context():
return load_inline(
module_name,
cpp_sources=cpp_sources,
cuda_sources=cuda_sources,
extra_cflags=DEFAULT_CFLAGS + extra_cflags,
extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
build_directory=build_directory,
)
else:
assert cpp_wrappers is None and cuda_wrappers is None
with _jit_compile_context():
return load(
module_name,
cpp_files=cpp_files,
cuda_files=cuda_files,
extra_cflags=DEFAULT_CFLAGS + extra_cflags,
extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags,
extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags,
extra_include_paths=DEFAULT_INCLUDE + extra_include_paths,
build_directory=build_directory,
)


@dataclass
Expand Down
Loading
Loading