From fa7fdfb302111c6acd51e5202a96666c1acd891b Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Thu, 31 Jul 2025 15:34:30 +0800 Subject: [PATCH 1/8] support multithread compute bitmask for spec decode --- python/llguidance/_lib.pyi | 15 ++++ python/llguidance/numpy.py | 14 ++++ python/llguidance/torch.py | 11 +++ python/torch_tests/test_matcher.py | 41 ++++++++++- python_ext/src/llmatcher.rs | 108 +++++++++++++++++++++++++++++ 5 files changed, 188 insertions(+), 1 deletion(-) diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index d7fdc311..0599fc8b 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -538,6 +538,21 @@ class LLExecutor: Prefer to use fill_next_token_bitmask_par(), which wraps this. """ + def unsafe_compute_mask_ptr_with_draft_token( + self, + interpreters: List[Tuple[LLMatcher, int, List[int]]], + tgt_pointer: int, + one_mask_byte_size: int, + trg_batch_size: int, + ) -> None: + """ + Used for speculative decoding. + Compute the token mask directly into memory at the specified pointer. + For each matcher, provide the index of the target mask and a list of draft tokens. + If index is K, the memory will be written at tgt_pointer + K * one_mask_byte_size, + where K < trg_batch_size. + Memory has to have size trg_batch_size * one_mask_byte_size. 
+ """ class JsonCompileOptions(TypedDict, total=False): # defaults to "," diff --git a/python/llguidance/numpy.py b/python/llguidance/numpy.py index 21452a6a..414c2c51 100644 --- a/python/llguidance/numpy.py +++ b/python/llguidance/numpy.py @@ -66,3 +66,17 @@ def fill_next_token_bitmask_par(executor: LLExecutor, batch, vocab = bitmask.shape assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous" executor.unsafe_compute_mask_ptr(matchers, bitmask.ctypes.data, vocab * 4, batch) + + +def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor, + matchers: List[Tuple[LLMatcher, int, List[int]]], + bitmask: NDArray[np.int32]) -> None: + """ + Compute the token masks for speculative decoding directly into the specified array. + For each matcher, provide the index of its first mask row and its draft tokens; up to len(draft_tokens) + 1 consecutive rows are written starting at that index. + """ + assert bitmask.dtype == np.int32, "Mask must be int32" + assert bitmask.ndim == 2, "Mask must be 2D" + batch, vocab = bitmask.shape + assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous" + executor.unsafe_compute_mask_ptr_with_draft_token(matchers, bitmask.ctypes.data, vocab * 4, batch) diff --git a/python/llguidance/torch.py b/python/llguidance/torch.py index 2dea212d..f09eaae0 100644 --- a/python/llguidance/torch.py +++ b/python/llguidance/torch.py @@ -66,3 +66,14 @@ def fill_next_token_bitmask_par(executor: LLExecutor, assert bitmask.is_contiguous(), "Mask must be contiguous" executor.unsafe_compute_mask_ptr(matchers, bitmask.data_ptr(), vocab * 4, batch) + + +def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor, + matchers: List[Tuple[LLMatcher, int, List[int]]], + bitmask: torch.Tensor) -> None: + assert bitmask.dtype == torch.int32, "Mask must be int32" + assert bitmask.is_cpu, "Mask must be on CPU" + assert bitmask.dim() == 2, "Mask must be 2D" + batch, vocab = bitmask.shape + assert bitmask.is_contiguous(), "Mask must be contiguous" + executor.unsafe_compute_mask_ptr_with_draft_token(matchers, bitmask.data_ptr(), vocab * 4, batch) diff --git 
a/python/torch_tests/test_matcher.py b/python/torch_tests/test_matcher.py index 51f80e51..2658e684 100644 --- a/python/torch_tests/test_matcher.py +++ b/python/torch_tests/test_matcher.py @@ -1,6 +1,11 @@ from typing import Any, Dict, List, Tuple, Union import llguidance -from llguidance.numpy import fill_next_token_bitmask_par, allocate_token_bitmask +from llguidance.numpy import ( + fill_next_token_bitmask_par, + fill_next_token_bitmask_par_with_draft_tokens, + allocate_token_bitmask, +) + from llguidance import LLMatcher, LLTokenizer, StructTag, LLParserLimits import pytest from numpy.typing import NDArray @@ -206,6 +211,40 @@ def test_par_errors() -> None: assert not mask_has(mask[2, :], t_a) assert mask_has(mask[2, :], t_1) +def test_par_draft_tokens() -> None: + t = tokenizer() + exec = llguidance.LLExecutor() + g0 = matcher(r"start: /[a-zA-Z]*/") + g1 = matcher(r"start: /[0-9]*/") + + # should be OK + t_a = t.tokenize_str("ab") + t_1 = t.tokenize_str("12") + mask = allocate_token_bitmask(len(t_a) + 1 + len(t_1) + 1, t.vocab_size) + fill_next_token_bitmask_par_with_draft_tokens(exec, [(g0, 0, t_a), (g1, 3, t_1)], mask) + + assert mask_has(mask[0, :], t_a[0]) + assert mask_has(mask[1, :], t_a[1]) + assert mask_has(mask[2, :], t_a[0]) + assert mask_has(mask[2, :], t_a[1]) + assert not mask_has(mask[0, :], t_1[0]) + assert not mask_has(mask[0, :], t_1[1]) + assert not mask_has(mask[1, :], t_1[0]) + assert not mask_has(mask[1, :], t_1[1]) + assert not mask_has(mask[2, :], t_1[0]) + assert not mask_has(mask[2, :], t_1[1]) + + assert mask_has(mask[3, :], t_1[0]) + assert mask_has(mask[4, :], t_1[1]) + assert mask_has(mask[5, :], t_1[0]) + assert mask_has(mask[5, :], t_1[1]) + assert not mask_has(mask[3, :], t_a[0]) + assert not mask_has(mask[3, :], t_a[1]) + assert not mask_has(mask[4, :], t_a[0]) + assert not mask_has(mask[4, :], t_a[1]) + assert not mask_has(mask[5, :], t_a[0]) + assert not mask_has(mask[5, :], t_a[1]) + def consume_tokens(m: LLMatcher, tokens: 
List[int]) -> None: print("Consume", tokenizer().dbg_tokens(tokens)) diff --git a/python_ext/src/llmatcher.rs b/python_ext/src/llmatcher.rs index 04e3d08b..e29cd77c 100644 --- a/python_ext/src/llmatcher.rs +++ b/python_ext/src/llmatcher.rs @@ -100,6 +100,73 @@ impl LLExecutor { Ok(()) } + + fn unsafe_compute_mask_ptr_with_draft_token( + &self, + interpreters: Bound<'_, PyList>, + trg_ptr: usize, + one_mask_bytes: usize, + trg_batch_size: usize, + py: Python<'_>, + ) -> PyResult<()> { + if interpreters.len() == 0 { + return Err(PyValueError::new_err("No interpreters")); + } + + let mut mut_refs = vec![]; + for ent in interpreters.iter() { + let tupl = ent.downcast::()?; + if tupl.len() != 3 { + return Err(PyValueError::new_err( + "Expecting (LLMatcher, int, List[int]) tuple", + )); + } + let interp = tupl.get_item(0)?.extract::>()?; + let idx = tupl.get_item(1)?.extract::()?; + if idx >= trg_batch_size { + return Err(PyValueError::new_err("Target index out of bounds")); + } + let draft_tokens = tupl.get_item(2)?.extract::>()?; + if draft_tokens.len() == 0 { + return Err(PyValueError::new_err("Draft tokens must not be empty")); + } + interp.validate_mask_ptr(trg_ptr, one_mask_bytes)?; + mut_refs.push((interp, idx, draft_tokens)); + } + + if mut_refs.len() == 1 { + let (mut interp, idx, draft_tokens) = mut_refs.pop().unwrap(); + return interp.unsafe_compute_mask_ptr_with_draft_token( + trg_ptr + idx * one_mask_bytes, + one_mask_bytes, + draft_tokens, + py, + ); + } + + let mut_refs2: Vec<_> = mut_refs + .iter_mut() + .map(|(x, idx, draft_tokens)| (x.deref_mut(), *idx, draft_tokens.clone())) + .collect(); + + use rayon::prelude::*; + + py.allow_threads(|| { + self.pool.install(|| { + mut_refs2 + .into_par_iter() + .for_each(|(interp, idx, draft_tokens)| { + interp.unsafe_compute_mask_ptr_inner_with_draft_tokens( + trg_ptr + idx * one_mask_bytes, + one_mask_bytes, + draft_tokens, + ); + }) + }) + }); + + Ok(()) + } } impl LLMatcher { @@ -125,6 +192,33 @@ impl LLMatcher 
{ trg_slice.copy_from_slice(&src[0..trg_slice.len()]); } + fn unsafe_compute_mask_ptr_inner_with_draft_tokens( + &mut self, + trg_ptr: usize, + trg_bytes: usize, + draft_tokens: Vec, + ) { + let mut state_advancements = 0; + let spec_k = draft_tokens.len(); + for token_idx in 0..=spec_k { + self.unsafe_compute_mask_ptr_inner(trg_ptr + token_idx * trg_bytes, trg_bytes); + + if token_idx == spec_k || self.inner.is_stopped() { + break; + } + + let token = draft_tokens[token_idx]; + + match self.inner.try_consume_tokens(&[token]) { + Ok(cosumed) if cosumed > 0 => state_advancements += 1, + _ => break, + } + } + if state_advancements > 0 { + self.rollback(state_advancements); + } + } + fn eos_token_set(&self) -> SimpleVob { let trie = self.tok_env.tok_trie(); trie.singleton_token_set(trie.eos_token()) @@ -337,6 +431,20 @@ impl LLMatcher { Ok(()) } + fn unsafe_compute_mask_ptr_with_draft_token( + &mut self, + trg_ptr: usize, + trg_bytes: usize, + draft_tokens: Vec, + py: Python<'_>, + ) -> PyResult<()> { + self.validate_mask_ptr(trg_ptr, trg_bytes)?; + py.allow_threads(|| { + self.unsafe_compute_mask_ptr_inner_with_draft_tokens(trg_ptr, trg_bytes, draft_tokens) + }); + Ok(()) + } + fn compute_logit_bias(&mut self, py: Python<'_>) -> Cow<[u8]> { py.allow_threads(|| { let m = self.compute_mask_or_eos(); From 5dd26a6bc8d7b48532673ca4dfc880d6d69295b6 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Thu, 7 Aug 2025 04:29:32 +0800 Subject: [PATCH 2/8] update more explicit comment on unsafe_compute_mask_ptr_with_draft_token func --- python/llguidance/_lib.pyi | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index 0599fc8b..d3a58e8a 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -546,12 +546,28 @@ class LLExecutor: trg_batch_size: int, ) -> None: """ - Used for speculative decoding. 
- Compute the token mask directly into memory at the specified pointer. - For each matcher, provide the index of the target mask and a list of draft tokens. - If index is K, the memory will be written at tgt_pointer + K * one_mask_byte_size, - where K < trg_batch_size. - Memory has to have size trg_batch_size * one_mask_byte_size. + Compute the token mask directly into memory at the specified pointer, including draft tokens. + + This function extends unsafe_compute_mask_ptr() to handle draft tokens in speculative decoding. + For each matcher in the batch, it computes masks for both the current position and all draft tokens. + + Args: + interpreters: List of tuples containing: + - LLMatcher: The matcher object for constrained generation + - int: Index K indicating the target mask position (K < trg_batch_size) + - List[int]: Draft tokens to be processed for speculative decoding + tgt_pointer: Memory address where mask data will be written + one_mask_byte_size: Size in bytes of a single token mask + trg_batch_size: Total batch size for memory allocation validation + + Memory Layout: + - Main mask written at: tgt_pointer + K * one_mask_byte_size + - Draft token i mask written at: tgt_pointer + (K + i + 1) * one_mask_byte_size + - Total memory required: trg_batch_size * (spec_k + 1) * one_mask_byte_size + + The function processes each matcher's draft tokens sequentially, advancing the matcher state + for each valid token until encountering an invalid token or termination condition. + State rollback is performed to maintain matcher consistency. 
""" class JsonCompileOptions(TypedDict, total=False): From 40afda190e5b91fc7982a59a029538a40e1ff382 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Thu, 7 Aug 2025 14:58:05 +0800 Subject: [PATCH 3/8] update test for test_par_draft_tokens --- python/torch_tests/test_matcher.py | 164 ++++++++++++++++++++++++----- 1 file changed, 135 insertions(+), 29 deletions(-) diff --git a/python/torch_tests/test_matcher.py b/python/torch_tests/test_matcher.py index 2658e684..5ec1c1c6 100644 --- a/python/torch_tests/test_matcher.py +++ b/python/torch_tests/test_matcher.py @@ -161,7 +161,8 @@ def test_slices() -> None: def mask_has(mask: NDArray[np.int32], t: int) -> bool: v: int = mask[t // 32] - return v & (1 << (t % 32)) != 0 + # use np.int32 to avoid int32 overflow errors + return v & (np.int32(1) << (t % 32)) != 0 def test_par_errors() -> None: @@ -211,39 +212,144 @@ def test_par_errors() -> None: assert not mask_has(mask[2, :], t_a) assert mask_has(mask[2, :], t_1) + +def retrieve_tokens_from_bitmask(bitmask: NDArray[np.int32], vocab_size) -> List[List[int]]: + batch_accepted_tokens = [] + batch_rejected_tokens = [] + for batch_idx in range(bitmask.shape[0]): + batch_accepted_tokens.append([]) + batch_rejected_tokens.append([]) + for token_id in range(vocab_size): + print(bitmask.shape) + if mask_has(bitmask[batch_idx], token_id): + batch_accepted_tokens[-1].append(token_id) + else: + batch_rejected_tokens[-1].append(token_id) + return batch_accepted_tokens, batch_rejected_tokens + + def test_par_draft_tokens() -> None: t = tokenizer() exec = llguidance.LLExecutor() - g0 = matcher(r"start: /[a-zA-Z]*/") - g1 = matcher(r"start: /[0-9]*/") + g0 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/") + g1 = matcher(r"start: /[0-9]/ /[a-zA-Z]*/") + g2 = matcher(r"start: <[*]>*") + g3 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/") # should be OK - t_a = t.tokenize_str("ab") - t_1 = t.tokenize_str("12") - mask = allocate_token_bitmask(len(t_a) + 1 + len(t_1) + 1, t.vocab_size) - 
fill_next_token_bitmask_par_with_draft_tokens(exec, [(g0, 0, t_a), (g1, 3, t_1)], mask) - - assert mask_has(mask[0, :], t_a[0]) - assert mask_has(mask[1, :], t_a[1]) - assert mask_has(mask[2, :], t_a[0]) - assert mask_has(mask[2, :], t_a[1]) - assert not mask_has(mask[0, :], t_1[0]) - assert not mask_has(mask[0, :], t_1[1]) - assert not mask_has(mask[1, :], t_1[0]) - assert not mask_has(mask[1, :], t_1[1]) - assert not mask_has(mask[2, :], t_1[0]) - assert not mask_has(mask[2, :], t_1[1]) - - assert mask_has(mask[3, :], t_1[0]) - assert mask_has(mask[4, :], t_1[1]) - assert mask_has(mask[5, :], t_1[0]) - assert mask_has(mask[5, :], t_1[1]) - assert not mask_has(mask[3, :], t_a[0]) - assert not mask_has(mask[3, :], t_a[1]) - assert not mask_has(mask[4, :], t_a[0]) - assert not mask_has(mask[4, :], t_a[1]) - assert not mask_has(mask[5, :], t_a[0]) - assert not mask_has(mask[5, :], t_a[1]) + g0_draft_tokens = t.tokenize_str("a1") + g1_draft_tokens = t.tokenize_str("2b") + g2_draft_tokens = t.tokenize_str("cc") + # g3 index 1 draft is reject + g3_draft_tokens = t.tokenize_str("aa") + mask = allocate_token_bitmask( + len(g0_draft_tokens) + + 1 + + len(g1_draft_tokens) + + 1 + + len(g2_draft_tokens) + + 1 + + len(g3_draft_tokens) + + 1, + t.vocab_size, + ) + fill_next_token_bitmask_par_with_draft_tokens( + exec, + [ + (g0, 0, g0_draft_tokens), + (g1, 3, g1_draft_tokens), + (g2, 6, g2_draft_tokens), + (g3, 9, g3_draft_tokens), + ], + mask, + ) + + batch_accepted_tokens, batch_rejected_tokens = retrieve_tokens_from_bitmask( + mask, t.vocab_size + ) + for batch_idx in range(len(batch_accepted_tokens)): + assert ( + len(batch_accepted_tokens[batch_idx]) + + len(batch_rejected_tokens[batch_idx]) + == t.vocab_size + ) + + # for g0, first token should be Letters + # other tokens should be Numbers + mask_start_idx = 0 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g0_draft_tokens) + 1) + ): + g0_accepted_tokens = batch_accepted_tokens[mask_idx] + 
for token_id in range(t.vocab_size): + if token_id in g0_accepted_tokens: + if idx == 0: + assert t.decode_str([token_id]).isalpha() + else: + assert token_id == t.eos_token or t.decode_str([token_id]).isdigit() + else: + assert not g0.try_consume_tokens([token_id]) + if idx < len(g0_draft_tokens): + assert g0.consume_token(g0_draft_tokens[idx]) + + # for g1, first token should be Numbers + # other tokens should be Letters + mask_start_idx += len(g0_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g1_draft_tokens) + 1) + ): + g1_accepted_tokens = batch_accepted_tokens[mask_idx] + for token_id in range(t.vocab_size): + if token_id in g1_accepted_tokens: + if idx == 0: + assert t.decode_str([token_id]).isdigit() + else: + assert token_id == t.eos_token or t.decode_str([token_id]).isalpha() + else: + assert not g1.try_consume_tokens([token_id]) + if idx < len(g1_draft_tokens): + assert g1.consume_token(g1_draft_tokens[idx]) + + # for g2, all tokens should be accepted + mask_start_idx += len(g1_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g2_draft_tokens) + 1) + ): + g2_rejected_tokens = batch_rejected_tokens[mask_idx] + g2_accepted_tokens = batch_accepted_tokens[mask_idx] + assert len(g2_rejected_tokens) == 0 + assert len(g2_accepted_tokens) == t.vocab_size + for token_id in range(t.vocab_size): + if token_id in g2_accepted_tokens: + assert mask_has(mask[mask_idx, :], token_id) + else: + assert not mask_has(mask[mask_idx, :], token_id) + if idx < len(g2_draft_tokens): + assert g2.consume_token(g2_draft_tokens[idx]) + + # for g3 + # g3_draft_tokens[0] is accepted + # g3_draft_tokens[1] is rejected + mask_start_idx += len(g2_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g3_draft_tokens) + 1) + ): + g3_rejected_tokens = batch_rejected_tokens[mask_idx] + g3_accepted_tokens = batch_accepted_tokens[mask_idx] + if idx <= 1: + for 
token_id in range(t.vocab_size): + if token_id in g3_accepted_tokens: + assert mask_has(mask[mask_idx, :], token_id) + else: + assert not mask_has(mask[mask_idx, :], token_id) + if idx == 0: + assert g3.consume_token(g3_draft_tokens[idx]) + else: + assert not g3.consume_token(g3_draft_tokens[idx]) + else: + # the rejected draft token stops processing, so this row is never written and all of its bits are still set. + assert len(g3_rejected_tokens) == 0 + assert len(g3_accepted_tokens) == t.vocab_size def consume_tokens(m: LLMatcher, tokens: List[int]) -> None: From 91bd2bcd657f1692bacfd03bee85e5dbf63f661d Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Fri, 8 Aug 2025 11:17:49 +0800 Subject: [PATCH 4/8] fix clippy warnings in llmatcher.rs --- python_ext/src/llmatcher.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python_ext/src/llmatcher.rs b/python_ext/src/llmatcher.rs index e29cd77c..9cecc7ed 100644 --- a/python_ext/src/llmatcher.rs +++ b/python_ext/src/llmatcher.rs @@ -127,7 +127,7 @@ impl LLExecutor { return Err(PyValueError::new_err("Target index out of bounds")); } let draft_tokens = tupl.get_item(2)?.extract::>()?; - if draft_tokens.len() == 0 { + if draft_tokens.is_empty() { return Err(PyValueError::new_err("Draft tokens must not be empty")); } interp.validate_mask_ptr(trg_ptr, one_mask_bytes)?; @@ -200,6 +200,7 @@ impl LLMatcher { ) { let mut state_advancements = 0; let spec_k = draft_tokens.len(); + #[allow(clippy::needless_range_loop)] for token_idx in 0..=spec_k { self.unsafe_compute_mask_ptr_inner(trg_ptr + token_idx * trg_bytes, trg_bytes); From 12c8972bc1ec6b87ffe7818006c416bcf971d8e1 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Sat, 9 Aug 2025 15:55:49 +0800 Subject: [PATCH 5/8] fix type annotations in test_matcher.py --- python/torch_tests/test_matcher.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/torch_tests/test_matcher.py b/python/torch_tests/test_matcher.py index 5ec1c1c6..40140719 100644 --- 
a/python/torch_tests/test_matcher.py +++ b/python/torch_tests/test_matcher.py @@ -162,7 +162,7 @@ def test_slices() -> None: def mask_has(mask: NDArray[np.int32], t: int) -> bool: v: int = mask[t // 32] # use np.int32 to avoid int32 overflow errors - return v & (np.int32(1) << (t % 32)) != 0 + return bool(v & (np.int32(1) << (t % 32)) != 0) def test_par_errors() -> None: @@ -213,9 +213,11 @@ def test_par_errors() -> None: assert mask_has(mask[2, :], t_1) -def retrieve_tokens_from_bitmask(bitmask: NDArray[np.int32], vocab_size) -> List[List[int]]: - batch_accepted_tokens = [] - batch_rejected_tokens = [] +def retrieve_tokens_from_bitmask( + bitmask: NDArray[np.int32], vocab_size: int +) -> Tuple[List[List[int]], List[List[int]]]: + batch_accepted_tokens: List[List[int]] = [] + batch_rejected_tokens: List[List[int]] = [] for batch_idx in range(bitmask.shape[0]): batch_accepted_tokens.append([]) batch_rejected_tokens.append([]) From b89c8ebaefde9c20d15b06266fb02686a60fc216 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Mon, 11 Aug 2025 03:38:49 +0800 Subject: [PATCH 6/8] typo fix --- python/llguidance/_lib.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index d3a58e8a..ba2bbef2 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -563,7 +563,7 @@ class LLExecutor: Memory Layout: - Main mask written at: tgt_pointer + K * one_mask_byte_size - Draft token i mask written at: tgt_pointer + (K + i + 1) * one_mask_byte_size - - Total memory required: trg_batch_size * (spec_k + 1) * one_mask_byte_size + - Total memory required: trg_batch_size * one_mask_byte_size The function processes each matcher's draft tokens sequentially, advancing the matcher state for each valid token until encountering an invalid token or termination condition. 
From 380a374701e78ff5c173b5cef874664829b982d2 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Mon, 11 Aug 2025 11:15:53 +0800 Subject: [PATCH 7/8] typo fix --- python/llguidance/_lib.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index ba2bbef2..819337f8 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -541,7 +541,7 @@ class LLExecutor: def unsafe_compute_mask_ptr_with_draft_token( self, interpreters: List[Tuple[LLMatcher, int, List[int]]], - tgt_pointer: int, + trg_pointer: int, one_mask_byte_size: int, trg_batch_size: int, ) -> None: @@ -556,13 +556,13 @@ class LLExecutor: - LLMatcher: The matcher object for constrained generation - int: Index K indicating the target mask position (K < trg_batch_size) - List[int]: Draft tokens to be processed for speculative decoding - tgt_pointer: Memory address where mask data will be written + trg_pointer: Memory address where mask data will be written one_mask_byte_size: Size in bytes of a single token mask trg_batch_size: Total batch size for memory allocation validation Memory Layout: - - Main mask written at: tgt_pointer + K * one_mask_byte_size - - Draft token i mask written at: tgt_pointer + (K + i + 1) * one_mask_byte_size + - Main mask written at: trg_pointer + K * one_mask_byte_size + - Draft token i mask written at: trg_pointer + (K + i + 1) * one_mask_byte_size - Total memory required: trg_batch_size * one_mask_byte_size The function processes each matcher's draft tokens sequentially, advancing the matcher state From b674b66c63b1bf471d8ef19df0358a769e6ae699 Mon Sep 17 00:00:00 2001 From: "zonepg666@gmail.com" Date: Tue, 12 Aug 2025 13:21:43 +0800 Subject: [PATCH 8/8] rename one_mask_byte_size => one_mask_bytes --- python/llguidance/_lib.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index 
819337f8..82cecf47 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -526,15 +526,15 @@ class LLExecutor: self, interpreters: List[Tuple[LLMatcher, int]], trg_pointer: int, - one_mask_byte_size: int, + one_mask_bytes: int, trg_batch_size: int, ) -> None: """ Compute the token mask directly into memory at the specified pointer. For each matcher, provide the index of the target mask. - If index is K, the memory will be written at trg_pointer + K * one_mask_byte_size, + If index is K, the memory will be written at trg_pointer + K * one_mask_bytes, where K < trg_batch_size. - Memory has to have size trg_batch_size * one_mask_byte_size. + Memory has to have size trg_batch_size * one_mask_bytes. Prefer to use fill_next_token_bitmask_par(), which wraps this. """ @@ -542,12 +542,12 @@ class LLExecutor: self, interpreters: List[Tuple[LLMatcher, int, List[int]]], trg_pointer: int, - one_mask_byte_size: int, + one_mask_bytes: int, trg_batch_size: int, ) -> None: """ Compute the token mask directly into memory at the specified pointer, including draft tokens. - + This function extends unsafe_compute_mask_ptr() to handle draft tokens in speculative decoding. For each matcher in the batch, it computes masks for both the current position and all draft tokens. 
@@ -557,13 +557,13 @@ class LLExecutor: - int: Index K indicating the target mask position (K < trg_batch_size) - List[int]: Draft tokens to be processed for speculative decoding trg_pointer: Memory address where mask data will be written - one_mask_byte_size: Size in bytes of a single token mask + one_mask_bytes: Size in bytes of a single token mask trg_batch_size: Total batch size for memory allocation validation Memory Layout: - - Main mask written at: trg_pointer + K * one_mask_byte_size - - Draft token i mask written at: trg_pointer + (K + i + 1) * one_mask_byte_size - - Total memory required: trg_batch_size * one_mask_byte_size + - Main mask written at: trg_pointer + K * one_mask_bytes + - Draft token i mask written at: trg_pointer + (K + i + 1) * one_mask_bytes + - Total memory required: trg_batch_size * one_mask_bytes The function processes each matcher's draft tokens sequentially, advancing the matcher state for each valid token until encountering an invalid token or termination condition.