diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index d7fdc311..82cecf47 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -526,18 +526,49 @@ class LLExecutor: self, interpreters: List[Tuple[LLMatcher, int]], trg_pointer: int, - one_mask_byte_size: int, + one_mask_bytes: int, trg_batch_size: int, ) -> None: """ Compute the token mask directly into memory at the specified pointer. For each matcher, provide the index of the target mask. - If index is K, the memory will be written at trg_pointer + K * one_mask_byte_size, + If index is K, the memory will be written at trg_pointer + K * one_mask_bytes, where K < trg_batch_size. - Memory has to have size trg_batch_size * one_mask_byte_size. + Memory has to have size trg_batch_size * one_mask_bytes. Prefer to use fill_next_token_bitmask_par(), which wraps this. """ + def unsafe_compute_mask_ptr_with_draft_token( + self, + interpreters: List[Tuple[LLMatcher, int, List[int]]], + trg_pointer: int, + one_mask_bytes: int, + trg_batch_size: int, + ) -> None: + """ + Compute the token mask directly into memory at the specified pointer, including draft tokens. + + This function extends unsafe_compute_mask_ptr() to handle draft tokens in speculative decoding. + For each matcher in the batch, it computes masks for both the current position and all draft tokens. + + Args: + interpreters: List of tuples containing: + - LLMatcher: The matcher object for constrained generation + - int: Index K indicating the target mask position (K < trg_batch_size) + - List[int]: Draft tokens to be processed for speculative decoding + trg_pointer: Memory address where mask data will be written + one_mask_bytes: Size in bytes of a single token mask + trg_batch_size: Total batch size for memory allocation validation + + Memory Layout: + - Main mask written at: trg_pointer + K * one_mask_bytes + - Draft token i mask written at: trg_pointer + (K + i + 1) * one_mask_bytes + - Total memory required: trg_batch_size * one_mask_bytes + + The function processes each matcher's draft tokens sequentially, advancing the matcher state + for each valid token until encountering an invalid token or termination condition. + State rollback is performed to maintain matcher consistency. + """ class JsonCompileOptions(TypedDict, total=False): # defaults to "," diff --git a/python/llguidance/numpy.py b/python/llguidance/numpy.py index 21452a6a..414c2c51 100644 --- a/python/llguidance/numpy.py +++ b/python/llguidance/numpy.py @@ -66,3 +66,17 @@ def fill_next_token_bitmask_par(executor: LLExecutor, batch, vocab = bitmask.shape assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous" executor.unsafe_compute_mask_ptr(matchers, bitmask.ctypes.data, vocab * 4, batch) + + +def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor, + matchers: List[Tuple[LLMatcher, int, List[int]]], + bitmask: NDArray[np.int32]) -> None: + """ + Compute the token mask directly into the specified array. + For each matcher, provide the index of the target mask. + """ + assert bitmask.dtype == np.int32, "Mask must be int32" + assert bitmask.ndim == 2, "Mask must be 2D" + batch, vocab = bitmask.shape + assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous" + executor.unsafe_compute_mask_ptr_with_draft_token(matchers, bitmask.ctypes.data, vocab * 4, batch) diff --git a/python/llguidance/torch.py b/python/llguidance/torch.py index 2dea212d..f09eaae0 100644 --- a/python/llguidance/torch.py +++ b/python/llguidance/torch.py @@ -66,3 +66,14 @@ def fill_next_token_bitmask_par(executor: LLExecutor, assert bitmask.is_contiguous(), "Mask must be contiguous" executor.unsafe_compute_mask_ptr(matchers, bitmask.data_ptr(), vocab * 4, batch) + + +def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor, + matchers: List[Tuple[LLMatcher, int, List[int]]], + bitmask: torch.Tensor) -> None: + assert bitmask.dtype == torch.int32, "Mask must be int32" + assert bitmask.is_cpu, "Mask must be on CPU" + assert bitmask.dim() == 2, "Mask must be 2D" + batch, vocab = bitmask.shape + assert bitmask.is_contiguous(), "Mask must be contiguous" + executor.unsafe_compute_mask_ptr_with_draft_token(matchers, bitmask.data_ptr(), vocab * 4, batch) diff --git a/python/torch_tests/test_matcher.py b/python/torch_tests/test_matcher.py index 51f80e51..40140719 100644 --- a/python/torch_tests/test_matcher.py +++ b/python/torch_tests/test_matcher.py @@ -1,6 +1,11 @@ from typing import Any, Dict, List, Tuple, Union import llguidance -from llguidance.numpy import fill_next_token_bitmask_par, allocate_token_bitmask +from llguidance.numpy import ( + fill_next_token_bitmask_par, + fill_next_token_bitmask_par_with_draft_tokens, + allocate_token_bitmask, +) + from llguidance import LLMatcher, LLTokenizer, StructTag, LLParserLimits import pytest from numpy.typing import NDArray @@ -156,7 +161,8 @@ def test_slices() -> None: def mask_has(mask: NDArray[np.int32], t: int) -> bool: v: int = mask[t // 32] - return v & (1 << (t % 32)) != 0 + # use np.int32 to avoid int32 overflow errors + return bool(v & (np.int32(1) << (t % 32)) != 0) def test_par_errors() -> None: @@ -207,6 +213,147 @@ def test_par_errors() -> None: assert mask_has(mask[2, :], t_1) +def retrieve_tokens_from_bitmask( + bitmask: NDArray[np.int32], vocab_size: int +) -> Tuple[List[List[int]], List[List[int]]]: + batch_accepted_tokens: List[List[int]] = [] + batch_rejected_tokens: List[List[int]] = [] + for batch_idx in range(bitmask.shape[0]): + batch_accepted_tokens.append([]) + batch_rejected_tokens.append([]) + for token_id in range(vocab_size): + print(bitmask.shape) + if mask_has(bitmask[batch_idx], token_id): + batch_accepted_tokens[-1].append(token_id) + else: + batch_rejected_tokens[-1].append(token_id) + return batch_accepted_tokens, batch_rejected_tokens + + +def test_par_draft_tokens() -> None: + t = tokenizer() + exec = llguidance.LLExecutor() + g0 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/") + g1 = matcher(r"start: /[0-9]/ /[a-zA-Z]*/") + g2 = matcher(r"start: <[*]>*") + g3 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/") + + # should be OK + g0_draft_tokens = t.tokenize_str("a1") + g1_draft_tokens = t.tokenize_str("2b") + g2_draft_tokens = t.tokenize_str("cc") + # g3 index 1 draft is reject + g3_draft_tokens = t.tokenize_str("aa") + mask = allocate_token_bitmask( + len(g0_draft_tokens) + + 1 + + len(g1_draft_tokens) + + 1 + + len(g2_draft_tokens) + + 1 + + len(g3_draft_tokens) + + 1, + t.vocab_size, + ) + fill_next_token_bitmask_par_with_draft_tokens( + exec, + [ + (g0, 0, g0_draft_tokens), + (g1, 3, g1_draft_tokens), + (g2, 6, g2_draft_tokens), + (g3, 9, g3_draft_tokens), + ], + mask, + ) + + batch_accepted_tokens, batch_rejected_tokens = retrieve_tokens_from_bitmask( + mask, t.vocab_size + ) + for batch_idx in range(len(batch_accepted_tokens)): + assert ( + len(batch_accepted_tokens[batch_idx]) + + len(batch_rejected_tokens[batch_idx]) + == t.vocab_size + ) + + # for g0, first token should be Letters + # other tokens should be Numbers + mask_start_idx = 0 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g0_draft_tokens) + 1) + ): + g0_accepted_tokens = batch_accepted_tokens[mask_idx] + for token_id in range(t.vocab_size): + if token_id in g0_accepted_tokens: + if idx == 0: + assert t.decode_str([token_id]).isalpha() + else: + assert token_id == t.eos_token or t.decode_str([token_id]).isdigit() + else: + assert not g0.try_consume_tokens([token_id]) + if idx < len(g0_draft_tokens): + assert g0.consume_token(g0_draft_tokens[idx]) + + # for g1, first token should be Numbers + # other tokens should be Letters + mask_start_idx += len(g0_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g1_draft_tokens) + 1) + ): + g1_accepted_tokens = batch_accepted_tokens[mask_idx] + for token_id in range(t.vocab_size): + if token_id in g1_accepted_tokens: + if idx == 0: + assert t.decode_str([token_id]).isdigit() + else: + assert token_id == t.eos_token or t.decode_str([token_id]).isalpha() + else: + assert not g1.try_consume_tokens([token_id]) + if idx < len(g1_draft_tokens): + assert g1.consume_token(g1_draft_tokens[idx]) + + # for g2, all tokens should be accept + mask_start_idx += len(g1_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g2_draft_tokens) + 1) + ): + g2_rejected_tokens = batch_rejected_tokens[mask_idx] + g2_accepted_tokens = batch_accepted_tokens[mask_idx] + assert len(g2_rejected_tokens) == 0 + assert len(g2_accepted_tokens) == t.vocab_size + for token_id in range(t.vocab_size): + if token_id in g2_accepted_tokens: + assert mask_has(mask[mask_idx, :], token_id) + else: + assert not mask_has(mask[mask_idx, :], token_id) + if idx < len(g2_draft_tokens): + assert g2.consume_token(g2_draft_tokens[idx]) + + # for g3 + # g3_draft_tokens[0] is accept + # g3_draft_tokens[1] is reject + mask_start_idx += len(g2_draft_tokens) + 1 + for idx, mask_idx in enumerate( + range(mask_start_idx, mask_start_idx + len(g3_draft_tokens) + 1) + ): + g3_rejected_tokens = batch_rejected_tokens[mask_idx] + g3_accepted_tokens = batch_accepted_tokens[mask_idx] + if idx <= 1: + for token_id in range(t.vocab_size): + if token_id in g3_accepted_tokens: + assert mask_has(mask[mask_idx, :], token_id) + else: + assert not mask_has(mask[mask_idx, :], token_id) + if idx == 0: + assert g3.consume_token(g3_draft_tokens[idx]) + else: + assert not g3.consume_token(g3_draft_tokens[idx]) + else: + # the bitmask of all tokens has a bit value of 1. + assert len(g3_rejected_tokens) == 0 + assert len(g3_accepted_tokens) == t.vocab_size + + def consume_tokens(m: LLMatcher, tokens: List[int]) -> None: print("Consume", tokenizer().dbg_tokens(tokens)) assert m.stop_reason() == "NotStopped" diff --git a/python_ext/src/llmatcher.rs b/python_ext/src/llmatcher.rs index 04e3d08b..9cecc7ed 100644 --- a/python_ext/src/llmatcher.rs +++ b/python_ext/src/llmatcher.rs @@ -100,6 +100,73 @@ impl LLExecutor { Ok(()) } + + fn unsafe_compute_mask_ptr_with_draft_token( + &self, + interpreters: Bound<'_, PyList>, + trg_ptr: usize, + one_mask_bytes: usize, + trg_batch_size: usize, + py: Python<'_>, + ) -> PyResult<()> { + if interpreters.len() == 0 { + return Err(PyValueError::new_err("No interpreters")); + } + + let mut mut_refs = vec![]; + for ent in interpreters.iter() { + let tupl = ent.downcast::()?; + if tupl.len() != 3 { + return Err(PyValueError::new_err( + "Expecting (LLMatcher, int, List[int]) tuple", + )); + } + let interp = tupl.get_item(0)?.extract::>()?; + let idx = tupl.get_item(1)?.extract::()?; + if idx >= trg_batch_size { + return Err(PyValueError::new_err("Target index out of bounds")); + } + let draft_tokens = tupl.get_item(2)?.extract::>()?; + if draft_tokens.is_empty() { + return Err(PyValueError::new_err("Draft tokens must not be empty")); + } + interp.validate_mask_ptr(trg_ptr, one_mask_bytes)?; + mut_refs.push((interp, idx, draft_tokens)); + } + + if mut_refs.len() == 1 { + let (mut interp, idx, draft_tokens) = mut_refs.pop().unwrap(); + return interp.unsafe_compute_mask_ptr_with_draft_token( + trg_ptr + idx * one_mask_bytes, + one_mask_bytes, + draft_tokens, + py, + ); + } + + let mut_refs2: Vec<_> = mut_refs + .iter_mut() + .map(|(x, idx, draft_tokens)| (x.deref_mut(), *idx, draft_tokens.clone())) + .collect(); + + use rayon::prelude::*; + + py.allow_threads(|| { + self.pool.install(|| { + mut_refs2 + .into_par_iter() + .for_each(|(interp, idx, draft_tokens)| { + interp.unsafe_compute_mask_ptr_inner_with_draft_tokens( + trg_ptr + idx * one_mask_bytes, + one_mask_bytes, + draft_tokens, + ); + }) + }) + }); + + Ok(()) + } } impl LLMatcher { @@ -125,6 +192,34 @@ impl LLMatcher { trg_slice.copy_from_slice(&src[0..trg_slice.len()]); } + fn unsafe_compute_mask_ptr_inner_with_draft_tokens( + &mut self, + trg_ptr: usize, + trg_bytes: usize, + draft_tokens: Vec, + ) { + let mut state_advancements = 0; + let spec_k = draft_tokens.len(); + #[allow(clippy::needless_range_loop)] + for token_idx in 0..=spec_k { + self.unsafe_compute_mask_ptr_inner(trg_ptr + token_idx * trg_bytes, trg_bytes); + + if token_idx == spec_k || self.inner.is_stopped() { + break; + } + + let token = draft_tokens[token_idx]; + + match self.inner.try_consume_tokens(&[token]) { + Ok(cosumed) if cosumed > 0 => state_advancements += 1, + _ => break, + } + } + if state_advancements > 0 { + self.rollback(state_advancements); + } + } + fn eos_token_set(&self) -> SimpleVob { let trie = self.tok_env.tok_trie(); trie.singleton_token_set(trie.eos_token()) @@ -337,6 +432,20 @@ impl LLMatcher { Ok(()) } + fn unsafe_compute_mask_ptr_with_draft_token( + &mut self, + trg_ptr: usize, + trg_bytes: usize, + draft_tokens: Vec, + py: Python<'_>, + ) -> PyResult<()> { + self.validate_mask_ptr(trg_ptr, trg_bytes)?; + py.allow_threads(|| { + self.unsafe_compute_mask_ptr_inner_with_draft_tokens(trg_ptr, trg_bytes, draft_tokens) + }); + Ok(()) + } + fn compute_logit_bias(&mut self, py: Python<'_>) -> Cow<[u8]> { py.allow_threads(|| { let m = self.compute_mask_or_eos();