Merged
2 changes: 1 addition & 1 deletion docs/advanced_features/server_arguments.md
@@ -299,7 +299,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--speculative-ngram-min-bfs-breadth` | The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `1` | Type: int |
| `--speculative-ngram-max-bfs-breadth` | The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `10` | Type: int |
| `--speculative-ngram-match-type` | Ngram tree-building mode. `BFS` selects recency-based expansion and `PROB` selects frequency-based expansion. This setting is forwarded to the ngram cache implementation. | `BFS` | `BFS`, `PROB` |
| `--speculative-ngram-branch-length` | The branch length for ngram speculative decoding. | `18` | Type: int |
| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | Type: int |
| `--speculative-ngram-capacity` | The cache capacity for ngram speculative decoding. | `10000000` | Type: int |

## Multi-layer Eagle speculative decoding
4 changes: 2 additions & 2 deletions docs/advanced_features/speculative_decoding.md
@@ -393,7 +393,7 @@ Enable it with:
| `--speculative-ngram-min-bfs-breadth` | Minimum BFS breadth. | `1` |
| `--speculative-ngram-max-bfs-breadth` | Maximum BFS breadth. | `10` |
| `--speculative-ngram-match-type` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion. | `"BFS"` |
| `--speculative-ngram-branch-length` | How many recent tokens to insert into the cache. | `18` |
| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` |
| `--speculative-ngram-capacity` | Cache capacity (number of entries). | `10,000,000` |

Notes:
@@ -469,7 +469,7 @@ Below is a comprehensive list of all speculative decoding parameters available i
| `--speculative-ngram-min-bfs-breadth` | `int` | `1` | Minimum BFS breadth |
| `--speculative-ngram-max-bfs-breadth` | `int` | `10` | Maximum BFS breadth |
| `--speculative-ngram-match-type` | `str` | `"BFS"` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion |
| `--speculative-ngram-branch-length` | `int` | `18` | Recent tokens to insert into cache |
| `--speculative-ngram-max-trie-depth` | `int` | `18` | Max trie depth for ngram speculative decoding |
| `--speculative-ngram-capacity` | `int` | `10,000,000` | Cache capacity |

### Environment variables
2 changes: 1 addition & 1 deletion docs/platforms/ascend_npu_support_features.md
@@ -229,7 +229,7 @@ click [Server Arguments](https://docs.sglang.io/advanced_features/server_argumen
| `--speculative-ngram-`<br/>`min-bfs-breadth` | `1` | Type: int | Experimental |
| `--speculative-ngram-`<br/>`max-bfs-breadth` | `10` | Type: int | Experimental |
| `--speculative-ngram-`<br/>`match-type` | `BFS` | `BFS`,<br/> `PROB` | Experimental. `BFS` uses recency-based expansion; `PROB` uses frequency-based expansion. |
| `--speculative-ngram-`<br/>`branch-length` | `18` | Type: int | Experimental |
| `--speculative-ngram-`<br/>`max-trie-depth` | `18` | Type: int | Experimental |
| `--speculative-ngram-`<br/>`capacity` | `10000000` | Type: int | Experimental |

## Expert parallelism
8 changes: 4 additions & 4 deletions python/sglang/srt/server_args.py
@@ -506,7 +506,7 @@ class ServerArgs:
speculative_ngram_min_bfs_breadth: int = 1
speculative_ngram_max_bfs_breadth: int = 10
speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
speculative_ngram_branch_length: int = 18
speculative_ngram_max_trie_depth: int = 18
speculative_ngram_capacity: int = 10 * 1000 * 1000
enable_multi_layer_eagle: bool = False

@@ -4765,10 +4765,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The match type for cache tree.",
)
parser.add_argument(
"--speculative-ngram-branch-length",
"--speculative-ngram-max-trie-depth",
type=int,
default=ServerArgs.speculative_ngram_branch_length,
help="The branch length for ngram speculative decoding.",
default=ServerArgs.speculative_ngram_max_trie_depth,
help="The max trie depth for ngram speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-capacity",
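As a rough illustration of what the flag rename means for callers, the sketch below registers only the new spelling with `argparse`. `ServerArgsSketch` and `parse` are hypothetical stand-ins introduced for this example; the real `ServerArgs` class has many more fields, and this sketch does not claim to mirror how the project handles the old flag.

```python
import argparse
from dataclasses import dataclass


@dataclass
class ServerArgsSketch:
    # Hypothetical stand-in for ServerArgs; only the renamed field is modeled.
    speculative_ngram_max_trie_depth: int = 18


def add_cli_args(parser: argparse.ArgumentParser) -> None:
    # Same shape as the add_argument call in the diff, using the new flag name.
    parser.add_argument(
        "--speculative-ngram-max-trie-depth",
        type=int,
        default=ServerArgsSketch.speculative_ngram_max_trie_depth,
        help="The max trie depth for ngram speculative decoding.",
    )


def parse(argv):
    # Helper for this sketch only: build a parser and parse the given argv list.
    parser = argparse.ArgumentParser()
    add_cli_args(parser)
    return parser.parse_args(argv)
```

In this sketch, `parse(["--speculative-ngram-max-trie-depth", "12"])` yields a namespace whose `speculative_ngram_max_trie_depth` is `12`, while the old `--speculative-ngram-branch-length` spelling is not registered, so `argparse` exits with a usage error.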
10 changes: 5 additions & 5 deletions python/sglang/srt/speculative/cpp_ngram/ngram.cpp
@@ -11,9 +11,9 @@
namespace ngram {

Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
if (!(param_.branch_length > 1)) {
if (!(param_.max_trie_depth > 1)) {
throw std::runtime_error(
"param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
"param_.max_trie_depth must be greater than 1, current value: " + std::to_string(param_.max_trie_depth));
}
if (!(param_.min_match_window_size > 0)) {
throw std::runtime_error(
@@ -26,11 +26,11 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
std::to_string(param_.min_match_window_size) +
", max_match_window_size: " + std::to_string(param_.max_match_window_size));
}
if (!(param_.max_match_window_size < param_.branch_length)) {
if (!(param_.max_match_window_size < param_.max_trie_depth)) {
throw std::runtime_error(
"max_match_window_size must be less than branch_length, current "
"max_match_window_size must be less than max_trie_depth, current "
"max_match_window_size: " +
std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
std::to_string(param_.max_match_window_size) + ", max_trie_depth: " + std::to_string(param_.max_trie_depth));
}
if (!(param_.min_bfs_breadth > 0)) {
throw std::runtime_error(
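The constructor checks visible in this hunk can be mirrored in Python to sanity-check a configuration before launching. This is a minimal sketch, not the project's code: `validate_ngram_params` is a hypothetical helper, it covers only the three conditions whose predicates appear in the diff above, and the wording of the `min_match_window_size` message is an assumption because the original message is elided.

```python
def validate_ngram_params(max_trie_depth, min_match_window_size, max_match_window_size):
    """Hypothetical Python mirror of the checks shown in Ngram::Ngram."""
    # Mirrors: param_.max_trie_depth > 1
    if not max_trie_depth > 1:
        raise ValueError(
            f"max_trie_depth must be greater than 1, current value: {max_trie_depth}"
        )
    # Mirrors: param_.min_match_window_size > 0 (message text is assumed)
    if not min_match_window_size > 0:
        raise ValueError(
            f"min_match_window_size must be greater than 0, current value: {min_match_window_size}"
        )
    # Mirrors: param_.max_match_window_size < param_.max_trie_depth
    if not max_match_window_size < max_trie_depth:
        raise ValueError(
            "max_match_window_size must be less than max_trie_depth, current "
            f"max_match_window_size: {max_match_window_size}, max_trie_depth: {max_trie_depth}"
        )
```

With the documented defaults (`max_trie_depth=18`, `min_match_window_size=1`, `max_match_window_size=10`) the helper returns without raising; a configuration such as `max_trie_depth=6, max_match_window_size=6` fails the last check.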
6 changes: 3 additions & 3 deletions python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py
@@ -25,7 +25,7 @@
class NgramCorpus:
def __init__(
self,
branch_length=18,
max_trie_depth=18,
min_match_window_size=1,
max_match_window_size=10,
min_bfs_breadth=1,
@@ -35,7 +35,7 @@ def __init__(
capacity=1000000,
):
param = ngram_corpus_cpp.Param()
param.branch_length = branch_length
param.max_trie_depth = max_trie_depth
param.min_match_window_size = min_match_window_size
param.max_match_window_size = max_match_window_size
param.min_bfs_breadth = min_bfs_breadth
@@ -131,7 +131,7 @@ def debug_result(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
]
corpus = NgramCorpus(branch_length=12, draft_token_num=8)
corpus = NgramCorpus(max_trie_depth=12, draft_token_num=8)
corpus.batch_put(token_ids)

corpus.synchronize()
@@ -23,7 +23,7 @@ PYBIND11_MODULE(ngram_corpus_cpp, m) {
.def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
.def_readwrite("min_match_window_size", &Param::min_match_window_size)
.def_readwrite("max_match_window_size", &Param::max_match_window_size)
.def_readwrite("branch_length", &Param::branch_length)
.def_readwrite("max_trie_depth", &Param::max_trie_depth)
.def_readwrite("draft_token_num", &Param::draft_token_num)
.def_readwrite("match_type", &Param::match_type)
.def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/cpp_ngram/param.h
@@ -19,7 +19,7 @@ struct Param {
size_t max_bfs_breadth;
size_t min_match_window_size;
size_t max_match_window_size;
size_t branch_length;
size_t max_trie_depth;
size_t draft_token_num;
std::string match_type;

@@ -109,7 +109,7 @@ struct Param {
ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
<< ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
<< ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
<< ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
<< ", max_trie_depth = " << max_trie_depth << ", draft_token_num = " << draft_token_num
<< ", match_type = " << match_type;
ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
for (int i = 0; i < batch_min_match_window_size.size(); ++i) {
2 changes: 1 addition & 1 deletion python/sglang/srt/speculative/cpp_ngram/trie.cpp
@@ -21,7 +21,7 @@ Trie::Trie(size_t capacity, const Param& param) : param_(param) {
void Trie::insert(const int32_t* tokens, size_t len) {
for (size_t i = 0; i + param_.min_match_window_size < len; ++i) {
auto start = tokens + i;
auto end = start + std::min(len - i, param_.branch_length);
auto end = start + std::min(len - i, param_.max_trie_depth);

if (static_cast<size_t>(end - start) > free_node_count_) {
squeeze(end - start - free_node_count_);
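To see why `max_trie_depth` describes the parameter's role in `Trie::insert`, the loop bounds above can be replayed in Python. The sketch below only enumerates the token windows that the loop would pass on for insertion; `trie_insert_windows` is a hypothetical helper, and node allocation, `squeeze`, and the trie structure itself are left out.

```python
def trie_insert_windows(tokens, min_match_window_size=1, max_trie_depth=18):
    """Enumerate the [start, end) slices that the insert loop walks.

    Each slice starts at position i and is capped at max_trie_depth tokens,
    so no inserted path can be deeper than max_trie_depth.
    """
    n = len(tokens)
    windows = []
    i = 0
    # Mirrors: for (size_t i = 0; i + min_match_window_size < len; ++i)
    while i + min_match_window_size < n:
        # Mirrors: end = start + std::min(len - i, max_trie_depth)
        end = i + min(n - i, max_trie_depth)
        windows.append(list(tokens[i:end]))
        i += 1
    return windows


print(trie_insert_windows([1, 2, 3, 4, 5], min_match_window_size=1, max_trie_depth=3))
# → [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5]]
```

Every window is at most `max_trie_depth` tokens long, and the final position is skipped because a window starting there could not hold more than `min_match_window_size` tokens.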
6 changes: 3 additions & 3 deletions python/sglang/srt/speculative/ngram_worker.py
@@ -40,7 +40,7 @@ def __init__(
self.tp_rank = tp_rank
self.page_size = server_args.page_size
self.draft_token_num: int = server_args.speculative_num_draft_tokens
self.branch_length: int = server_args.speculative_ngram_branch_length
self.max_trie_depth: int = server_args.speculative_ngram_max_trie_depth
self.max_match_window_size: int = (
server_args.speculative_ngram_max_match_window_size
)
@@ -57,7 +57,7 @@ def __init__(
max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
match_type=server_args.speculative_ngram_match_type,
capacity=server_args.speculative_ngram_capacity,
branch_length=server_args.speculative_ngram_branch_length,
max_trie_depth=server_args.speculative_ngram_max_trie_depth,
draft_token_num=server_args.speculative_num_draft_tokens,
)

@@ -209,7 +209,7 @@ def _update_ngram_corpus(self, batch: ScheduleBatch):
# put_ids = req.origin_input_ids + req.output_ids
# else:
put_ids = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.branch_length
req.origin_input_ids, req.output_ids, self.max_trie_depth
)
batch_tokens.append(put_ids)
self.ngram_corpus.batch_put(batch_tokens)
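The body of `_efficient_concat_last_n` is not part of this diff. Judging only from its name and call site, it appears to return the last `n` tokens of the prompt followed by the output without building the full concatenation. The sketch below is one possible way to do that under this assumption; `concat_last_n` is a hypothetical function and may differ from the real helper.

```python
def concat_last_n(seq1, seq2, n):
    """Return the last n items of seq1 + seq2 without concatenating both in full.

    Assumed behavior, inferred from the helper's name in the diff above.
    """
    if n <= 0:
        return []
    if len(seq2) >= n:
        # The tail of seq2 alone already supplies n items.
        return list(seq2[-n:])
    # Otherwise take the missing items from the end of seq1 (need > 0 here).
    need = n - len(seq2)
    return list(seq1[-need:]) + list(seq2)
```

For example, `concat_last_n([1, 2, 3, 4], [5, 6], 3)` gives `[4, 5, 6]`. With `n` set to `self.max_trie_depth`, each request would contribute at most `max_trie_depth` recent tokens to `batch_put`.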
16 changes: 8 additions & 8 deletions test/registered/spec/utils/test_ngram_corpus.py
@@ -11,7 +11,7 @@

def _make_corpus(match_type="BFS", **kwargs):
defaults = dict(
branch_length=12,
max_trie_depth=12,
min_match_window_size=1,
max_match_window_size=10,
min_bfs_breadth=1,
@@ -240,7 +240,7 @@ def test_small_capacity_does_not_crash(self):

def test_eviction_preserves_recent(self):
corpus = _make_corpus(
"BFS", capacity=500, branch_length=6, max_match_window_size=5
"BFS", capacity=500, max_trie_depth=6, max_match_window_size=5
)

old_seq = list(range(1000, 1050))
@@ -358,7 +358,7 @@ def test_repeated_insert_promotes_token(self):
max_bfs_breadth=1,
min_bfs_breadth=1,
max_match_window_size=3,
branch_length=5,
max_trie_depth=5,
)
corpus.batch_put([[1, 2, 3, 10, 11]])
corpus.synchronize()
@@ -387,7 +387,7 @@ def test_most_recent_insert_selected(self):
max_bfs_breadth=1,
min_bfs_breadth=1,
max_match_window_size=3,
branch_length=5,
max_trie_depth=5,
)
corpus.batch_put([[1, 2, 3, 10, 11]])
corpus.synchronize()
@@ -433,10 +433,10 @@ def test_single_token_query(self):


class TestLongContext(CustomTestCase):
"""Verify behavior when query context exceeds branch_length."""
"""Verify behavior when query context exceeds max_trie_depth."""

def test_context_longer_than_branch_length(self):
corpus = _make_corpus("BFS", branch_length=6, max_match_window_size=5)
def test_context_longer_than_max_trie_depth(self):
corpus = _make_corpus("BFS", max_trie_depth=6, max_match_window_size=5)
seq = list(range(1, 20))
corpus.batch_put([seq])
corpus.synchronize()
@@ -539,7 +539,7 @@ class TestSqueezeEvictsOld(CustomTestCase):

def test_old_data_evicted(self):
corpus = _make_corpus(
"BFS", capacity=150, branch_length=6, max_match_window_size=5
"BFS", capacity=150, max_trie_depth=6, max_match_window_size=5
)

old_seq = list(range(5000, 5030))