Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions python/llguidance/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -526,18 +526,49 @@ class LLExecutor:
self,
interpreters: List[Tuple[LLMatcher, int]],
trg_pointer: int,
one_mask_byte_size: int,
one_mask_bytes: int,
trg_batch_size: int,
) -> None:
"""
Compute the token mask directly into memory at the specified pointer.
For each matcher, provide the index of the target mask.
If index is K, the memory will be written at trg_pointer + K * one_mask_byte_size,
If index is K, the memory will be written at trg_pointer + K * one_mask_bytes,
where K < trg_batch_size.
Memory has to have size trg_batch_size * one_mask_byte_size.
Memory has to have size trg_batch_size * one_mask_bytes.
Prefer to use fill_next_token_bitmask_par(), which wraps this.
"""

def unsafe_compute_mask_ptr_with_draft_token(
    self,
    interpreters: List[Tuple[LLMatcher, int, List[int]]],
    trg_pointer: int,
    one_mask_bytes: int,
    trg_batch_size: int,
) -> None:
    """
    Compute the token mask directly into memory at the specified pointer, including draft tokens.

    This function extends unsafe_compute_mask_ptr() to handle draft tokens in speculative decoding.
    For each matcher in the batch, it computes masks for both the current position and the
    positions reached after each accepted draft token.

    Args:
        interpreters: Non-empty list of tuples containing:
            - LLMatcher: The matcher object for constrained generation
            - int: Index K of the first target mask row (K < trg_batch_size)
            - List[int]: Draft tokens for speculative decoding; must not be empty
        trg_pointer: Memory address where mask data will be written
        one_mask_bytes: Size in bytes of a single token mask
        trg_batch_size: Number of mask rows available at trg_pointer

    Memory Layout:
        - Main mask written at: trg_pointer + K * one_mask_bytes
        - Mask after draft token i (0-based) written at: trg_pointer + (K + i + 1) * one_mask_bytes
        - Total memory required: trg_batch_size * one_mask_bytes
        - A matcher with N draft tokens may therefore touch rows K .. K + N, so the caller
          must ensure K + N < trg_batch_size and that rows of different matchers do not overlap.

    The function processes each matcher's draft tokens sequentially, advancing the matcher state
    for each valid token until it hits an invalid token or the matcher stops. Rows that would
    follow a rejected draft token are not written and keep their previous contents.
    Accepted draft tokens are rolled back afterwards, so the matcher state is unchanged on return.

    Raises:
        ValueError: if the list is empty, an entry is not a 3-tuple, K is out of range,
            or a draft token list is empty.

    Prefer to use fill_next_token_bitmask_par_with_draft_tokens(), which wraps this.
    """

class JsonCompileOptions(TypedDict, total=False):
# defaults to ","
Expand Down
14 changes: 14 additions & 0 deletions python/llguidance/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,17 @@ def fill_next_token_bitmask_par(executor: LLExecutor,
batch, vocab = bitmask.shape
assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous"
executor.unsafe_compute_mask_ptr(matchers, bitmask.ctypes.data, vocab * 4, batch)


def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor,
                                                  matchers: List[Tuple[LLMatcher, int, List[int]]],
                                                  bitmask: NDArray[np.int32]) -> None:
    """
    Compute token masks, including masks for draft tokens, directly into the specified array.

    Each entry of `matchers` is a (matcher, K, draft_tokens) tuple. The mask for the
    matcher's current state is written to row K of `bitmask`, and the mask valid after
    draft token i (0-based) is written to row K + i + 1. Every entry therefore needs
    rows K .. K + len(draft_tokens) to exist in `bitmask`.

    Raises:
        ValueError: if an entry's rows would extend past the end of `bitmask`
            (checked here, because the executor writes through a raw pointer).
    """
    assert bitmask.dtype == np.int32, "Mask must be int32"
    assert bitmask.ndim == 2, "Mask must be 2D"
    batch, words_per_row = bitmask.shape
    assert bitmask.flags["C_CONTIGUOUS"], "Mask must be contiguous"
    # The executor writes len(draft_tokens) + 1 consecutive rows starting at K;
    # reject anything that would run off the end of the buffer before handing
    # out the raw pointer.
    for _, idx, draft_tokens in matchers:
        if idx + len(draft_tokens) >= batch:
            raise ValueError(
                "Target index out of bounds: rows %d..%d do not fit in a bitmask with %d rows"
                % (idx, idx + len(draft_tokens), batch))
    # Each int32 word is 4 bytes, so one mask row spans words_per_row * 4 bytes.
    executor.unsafe_compute_mask_ptr_with_draft_token(
        matchers, bitmask.ctypes.data, words_per_row * 4, batch)
11 changes: 11 additions & 0 deletions python/llguidance/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ def fill_next_token_bitmask_par(executor: LLExecutor,
assert bitmask.is_contiguous(), "Mask must be contiguous"
executor.unsafe_compute_mask_ptr(matchers, bitmask.data_ptr(), vocab * 4,
batch)


def fill_next_token_bitmask_par_with_draft_tokens(executor: LLExecutor,
                                                  matchers: List[Tuple[LLMatcher, int, List[int]]],
                                                  bitmask: torch.Tensor) -> None:
    """
    Compute token masks, including masks for draft tokens, directly into the specified tensor.

    Each entry of `matchers` is a (matcher, K, draft_tokens) tuple. The mask for the
    matcher's current state is written to row K of `bitmask`, and the mask valid after
    draft token i (0-based) is written to row K + i + 1. Every entry therefore needs
    rows K .. K + len(draft_tokens) to exist in `bitmask`.

    Raises:
        ValueError: if an entry's rows would extend past the end of `bitmask`
            (checked here, because the executor writes through a raw pointer).
    """
    assert bitmask.dtype == torch.int32, "Mask must be int32"
    assert bitmask.is_cpu, "Mask must be on CPU"
    assert bitmask.dim() == 2, "Mask must be 2D"
    batch, words_per_row = bitmask.shape
    assert bitmask.is_contiguous(), "Mask must be contiguous"
    # The executor writes len(draft_tokens) + 1 consecutive rows starting at K;
    # reject anything that would run off the end of the buffer before handing
    # out the raw pointer.
    for _, idx, draft_tokens in matchers:
        if idx + len(draft_tokens) >= batch:
            raise ValueError(
                "Target index out of bounds: rows %d..%d do not fit in a bitmask with %d rows"
                % (idx, idx + len(draft_tokens), batch))
    # Each int32 word is 4 bytes, so one mask row spans words_per_row * 4 bytes.
    executor.unsafe_compute_mask_ptr_with_draft_token(
        matchers, bitmask.data_ptr(), words_per_row * 4, batch)
151 changes: 149 additions & 2 deletions python/torch_tests/test_matcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any, Dict, List, Tuple, Union
import llguidance
from llguidance.numpy import fill_next_token_bitmask_par, allocate_token_bitmask
from llguidance.numpy import (
fill_next_token_bitmask_par,
fill_next_token_bitmask_par_with_draft_tokens,
allocate_token_bitmask,
)

from llguidance import LLMatcher, LLTokenizer, StructTag, LLParserLimits
import pytest
from numpy.typing import NDArray
Expand Down Expand Up @@ -156,7 +161,8 @@ def test_slices() -> None:

def mask_has(mask: NDArray[np.int32], t: int) -> bool:
    """
    Return True when the bit for token `t` is set in a packed bitmask row.

    Token `t` is stored in word `t // 32` at bit position `t % 32`.
    """
    # Convert the word to a Python int before doing bit arithmetic: Python ints
    # are arbitrary precision, so testing bit 31 (the int32 sign bit) cannot
    # overflow the way mixing an np.int32 with the Python int `1 << 31` can.
    word = int(mask[t // 32])
    return ((word >> (t % 32)) & 1) == 1


def test_par_errors() -> None:
Expand Down Expand Up @@ -207,6 +213,147 @@ def test_par_errors() -> None:
assert mask_has(mask[2, :], t_1)


def retrieve_tokens_from_bitmask(
    bitmask: NDArray[np.int32], vocab_size: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Split the token ids 0..vocab_size-1 into accepted and rejected lists per mask row.

    Returns (accepted, rejected), each holding one list per row of `bitmask`.
    A token is accepted when mask_has() reports its bit as set in that row.
    Both lists are in increasing token-id order.
    """
    batch_accepted_tokens: List[List[int]] = []
    batch_rejected_tokens: List[List[int]] = []
    for batch_idx in range(bitmask.shape[0]):
        row = bitmask[batch_idx]
        accepted: List[int] = []
        rejected: List[int] = []
        for token_id in range(vocab_size):
            if mask_has(row, token_id):
                accepted.append(token_id)
            else:
                rejected.append(token_id)
        batch_accepted_tokens.append(accepted)
        batch_rejected_tokens.append(rejected)
    return batch_accepted_tokens, batch_rejected_tokens


def test_par_draft_tokens() -> None:
    """
    Check fill_next_token_bitmask_par_with_draft_tokens() for four matchers in one call.

    Each matcher gets len(draft_tokens) + 1 consecutive mask rows. After the call,
    every row is compared against the grammar while the matcher is advanced by hand
    through the same draft tokens.
    """
    t = tokenizer()
    executor = llguidance.LLExecutor()
    g0 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/")
    g1 = matcher(r"start: /[0-9]/ /[a-zA-Z]*/")
    g2 = matcher(r"start: <[*]>*")
    g3 = matcher(r"start: /[a-zA-Z]/ /[0-9]*/")

    # should be OK
    g0_draft_tokens = t.tokenize_str("a1")
    g1_draft_tokens = t.tokenize_str("2b")
    g2_draft_tokens = t.tokenize_str("cc")
    # g3 index 1 draft is reject
    g3_draft_tokens = t.tokenize_str("aa")

    # Derive each matcher's first row from the draft lengths so the target
    # indices always agree with the number of rows allocated below.
    g0_start = 0
    g1_start = g0_start + len(g0_draft_tokens) + 1
    g2_start = g1_start + len(g1_draft_tokens) + 1
    g3_start = g2_start + len(g2_draft_tokens) + 1
    total_rows = g3_start + len(g3_draft_tokens) + 1

    mask = allocate_token_bitmask(total_rows, t.vocab_size)
    fill_next_token_bitmask_par_with_draft_tokens(
        executor,
        [
            (g0, g0_start, g0_draft_tokens),
            (g1, g1_start, g1_draft_tokens),
            (g2, g2_start, g2_draft_tokens),
            (g3, g3_start, g3_draft_tokens),
        ],
        mask,
    )

    batch_accepted_tokens, batch_rejected_tokens = retrieve_tokens_from_bitmask(
        mask, t.vocab_size
    )
    for batch_idx in range(len(batch_accepted_tokens)):
        assert (
            len(batch_accepted_tokens[batch_idx])
            + len(batch_rejected_tokens[batch_idx])
            == t.vocab_size
        )

    # for g0, first token should be Letters
    # other tokens should be Numbers
    for idx, mask_idx in enumerate(range(g0_start, g1_start)):
        # A set keeps the per-token membership test O(1) inside the vocab loop.
        g0_accepted_tokens = set(batch_accepted_tokens[mask_idx])
        for token_id in range(t.vocab_size):
            if token_id in g0_accepted_tokens:
                if idx == 0:
                    assert t.decode_str([token_id]).isalpha()
                else:
                    assert token_id == t.eos_token or t.decode_str([token_id]).isdigit()
            else:
                assert not g0.try_consume_tokens([token_id])
        # Advance by hand so the next row is checked against the matching state.
        if idx < len(g0_draft_tokens):
            assert g0.consume_token(g0_draft_tokens[idx])

    # for g1, first token should be Numbers
    # other tokens should be Letters
    for idx, mask_idx in enumerate(range(g1_start, g2_start)):
        g1_accepted_tokens = set(batch_accepted_tokens[mask_idx])
        for token_id in range(t.vocab_size):
            if token_id in g1_accepted_tokens:
                if idx == 0:
                    assert t.decode_str([token_id]).isdigit()
                else:
                    assert token_id == t.eos_token or t.decode_str([token_id]).isalpha()
            else:
                assert not g1.try_consume_tokens([token_id])
        if idx < len(g1_draft_tokens):
            assert g1.consume_token(g1_draft_tokens[idx])

    # for g2, all tokens should be accept
    for idx, mask_idx in enumerate(range(g2_start, g3_start)):
        g2_rejected_tokens = batch_rejected_tokens[mask_idx]
        g2_accepted_tokens = batch_accepted_tokens[mask_idx]
        assert len(g2_rejected_tokens) == 0
        assert len(g2_accepted_tokens) == t.vocab_size
        g2_accepted_set = set(g2_accepted_tokens)
        for token_id in range(t.vocab_size):
            if token_id in g2_accepted_set:
                assert mask_has(mask[mask_idx, :], token_id)
            else:
                assert not mask_has(mask[mask_idx, :], token_id)
        if idx < len(g2_draft_tokens):
            assert g2.consume_token(g2_draft_tokens[idx])

    # for g3
    # g3_draft_tokens[0] is accept
    # g3_draft_tokens[1] is reject
    for idx, mask_idx in enumerate(range(g3_start, total_rows)):
        g3_rejected_tokens = batch_rejected_tokens[mask_idx]
        g3_accepted_tokens = batch_accepted_tokens[mask_idx]
        if idx <= 1:
            g3_accepted_set = set(g3_accepted_tokens)
            for token_id in range(t.vocab_size):
                if token_id in g3_accepted_set:
                    assert mask_has(mask[mask_idx, :], token_id)
                else:
                    assert not mask_has(mask[mask_idx, :], token_id)
            if idx == 0:
                assert g3.consume_token(g3_draft_tokens[idx])
            else:
                assert not g3.consume_token(g3_draft_tokens[idx])
        else:
            # Rows after the rejected draft token are expected to still have every
            # bit set (presumably the fill value from allocate_token_bitmask).
            assert len(g3_rejected_tokens) == 0
            assert len(g3_accepted_tokens) == t.vocab_size


def consume_tokens(m: LLMatcher, tokens: List[int]) -> None:
print("Consume", tokenizer().dbg_tokens(tokens))
assert m.stop_reason() == "NotStopped"
Expand Down
109 changes: 109 additions & 0 deletions python_ext/src/llmatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,73 @@ impl LLExecutor {

Ok(())
}

/// Compute token masks for a batch of matchers and their draft tokens, writing
/// directly into the memory starting at `trg_ptr`.
///
/// Every element of `interpreters` must be a `(LLMatcher, int, List[int])` tuple:
/// the matcher, the index `K` of its first mask row, and a non-empty list of
/// draft tokens. The matcher's current mask is written to row `K`, and the mask
/// after each accepted draft token to the following rows, so rows
/// `K ..= K + draft_tokens.len()` must lie inside the `trg_batch_size` rows of
/// `one_mask_bytes` bytes each.
///
/// A single entry is processed on the calling thread; several entries are
/// processed in parallel on `self.pool` with the GIL released. Overlap between
/// the row ranges of different entries is not checked here and is the caller's
/// responsibility.
///
/// # Errors
/// Returns `PyValueError` when the list is empty, an entry is not a 3-tuple,
/// a row range falls outside the target buffer, or a draft token list is empty.
/// Extraction failures and `validate_mask_ptr` errors are propagated.
fn unsafe_compute_mask_ptr_with_draft_token(
    &self,
    interpreters: Bound<'_, PyList>,
    trg_ptr: usize,
    one_mask_bytes: usize,
    trg_batch_size: usize,
    py: Python<'_>,
) -> PyResult<()> {
    if interpreters.len() == 0 {
        return Err(PyValueError::new_err("No interpreters"));
    }

    let mut mut_refs = Vec::with_capacity(interpreters.len());
    for ent in interpreters.iter() {
        let tupl = ent.downcast::<PyTuple>()?;
        if tupl.len() != 3 {
            return Err(PyValueError::new_err(
                "Expecting (LLMatcher, int, List[int]) tuple",
            ));
        }
        let interp = tupl.get_item(0)?.extract::<PyRefMut<LLMatcher>>()?;
        let idx = tupl.get_item(1)?.extract::<usize>()?;
        if idx >= trg_batch_size {
            return Err(PyValueError::new_err("Target index out of bounds"));
        }
        let draft_tokens = tupl.get_item(2)?.extract::<Vec<TokenId>>()?;
        if draft_tokens.is_empty() {
            return Err(PyValueError::new_err("Draft tokens must not be empty"));
        }
        // Up to one extra row is written per draft token, so the last row that
        // may be touched is idx + draft_tokens.len(); it has to be in bounds too.
        match idx.checked_add(draft_tokens.len()) {
            Some(last_row) if last_row < trg_batch_size => {}
            _ => {
                return Err(PyValueError::new_err(
                    "Target index out of bounds for draft tokens",
                ));
            }
        }
        interp.validate_mask_ptr(trg_ptr, one_mask_bytes)?;
        mut_refs.push((interp, idx, draft_tokens));
    }

    // Single matcher: skip the thread pool and run on the calling thread.
    if mut_refs.len() == 1 {
        let (mut interp, idx, draft_tokens) = mut_refs.pop().unwrap();
        return interp.unsafe_compute_mask_ptr_with_draft_token(
            trg_ptr + idx * one_mask_bytes,
            one_mask_bytes,
            draft_tokens,
            py,
        );
    }

    // Borrow the matchers out of their PyRefMut guards (which stay alive in
    // `mut_refs`) and move the draft token vectors instead of cloning them.
    let mut_refs2: Vec<_> = mut_refs
        .iter_mut()
        .map(|(x, idx, draft_tokens)| (x.deref_mut(), *idx, std::mem::take(draft_tokens)))
        .collect();

    use rayon::prelude::*;

    py.allow_threads(|| {
        self.pool.install(|| {
            mut_refs2
                .into_par_iter()
                .for_each(|(interp, idx, draft_tokens)| {
                    interp.unsafe_compute_mask_ptr_inner_with_draft_tokens(
                        trg_ptr + idx * one_mask_bytes,
                        one_mask_bytes,
                        draft_tokens,
                    );
                })
        })
    });

    Ok(())
}
}

impl LLMatcher {
Expand All @@ -125,6 +192,34 @@ impl LLMatcher {
trg_slice.copy_from_slice(&src[0..trg_slice.len()]);
}

/// Write the mask for the current state at `trg_ptr`, then one further mask
/// (each `trg_bytes` after the previous one) for every draft token that the
/// matcher accepts.
///
/// Processing stops at the first draft token that is not consumed, or as soon
/// as the matcher reports it is stopped; later mask slots are left untouched.
/// All accepted draft tokens are rolled back before returning.
fn unsafe_compute_mask_ptr_inner_with_draft_tokens(
    &mut self,
    trg_ptr: usize,
    trg_bytes: usize,
    draft_tokens: Vec<TokenId>,
) {
    // Slot 0 always holds the mask for the state before any draft token.
    self.unsafe_compute_mask_ptr_inner(trg_ptr, trg_bytes);

    let mut accepted = 0usize;
    for &token in &draft_tokens {
        if self.inner.is_stopped() {
            break;
        }
        match self.inner.try_consume_tokens(&[token]) {
            Ok(consumed) if consumed > 0 => accepted += 1,
            _ => break,
        }
        // The draft token was accepted: emit the mask for the state after it.
        self.unsafe_compute_mask_ptr_inner(trg_ptr + accepted * trg_bytes, trg_bytes);
    }

    // Undo the speculative advances made above.
    if accepted > 0 {
        self.rollback(accepted);
    }
}

fn eos_token_set(&self) -> SimpleVob {
let trie = self.tok_env.tok_trie();
trie.singleton_token_set(trie.eos_token())
Expand Down Expand Up @@ -337,6 +432,20 @@ impl LLMatcher {
Ok(())
}

/// Write this matcher's mask, followed by the masks after each accepted draft
/// token, into the memory starting at `trg_ptr` (one slot of `trg_bytes` each).
///
/// The pointer and slot size are checked with `validate_mask_ptr` first and its
/// error is propagated; the mask computation itself runs with the GIL released.
fn unsafe_compute_mask_ptr_with_draft_token(
    &mut self,
    trg_ptr: usize,
    trg_bytes: usize,
    draft_tokens: Vec<TokenId>,
    py: Python<'_>,
) -> PyResult<()> {
    self.validate_mask_ptr(trg_ptr, trg_bytes)?;

    let fill_masks = || {
        self.unsafe_compute_mask_ptr_inner_with_draft_tokens(trg_ptr, trg_bytes, draft_tokens);
    };
    py.allow_threads(fill_masks);

    Ok(())
}

fn compute_logit_bias(&mut self, py: Python<'_>) -> Cow<[u8]> {
py.allow_threads(|| {
let m = self.compute_mask_or_eos();
Expand Down