diff --git a/c_sample/c_sample.cpp b/c_sample/c_sample.cpp index 94cacf14..62b82767 100644 --- a/c_sample/c_sample.cpp +++ b/c_sample/c_sample.cpp @@ -9,28 +9,73 @@ #include "llguidance.h" +// Create an LlgTokenizer using the v2 API. +// eos_tokens[0] is the primary EOS; any remaining entries are extra EOS token IDs. +LlgTokenizer *create_tokenizer_v2(std::vector> &tokens, + std::vector eos_tokens, + LlgTokenizeFn tokenize_fn, + const void *tokenize_user_data) { + assert(!eos_tokens.empty()); + std::vector token_lens(tokens.size()); + size_t total_size = 0; + for (size_t i = 0; i < tokens.size(); i++) { + token_lens[i] = tokens[i].size(); + total_size += token_lens[i]; + } + std::vector token_bytes(total_size); + size_t offset = 0; + for (size_t i = 0; i < tokens.size(); i++) { + std::copy(tokens[i].begin(), tokens[i].end(), token_bytes.data() + offset); + offset += token_lens[i]; + } + + LlgTokenizerInitV2 tok_init = {}; + tok_init.struct_size = sizeof(tok_init); + tok_init.vocab_size = (uint32_t)tokens.size(); + tok_init.tok_eos = eos_tokens[0]; + tok_init.token_lens = token_lens.data(); + tok_init.token_bytes = token_bytes.data(); + tok_init.tokenize_assumes_string = false; + tok_init.tokenize_user_data = tokenize_user_data; + tok_init.tokenize_fn = tokenize_fn; + if (eos_tokens.size() > 1) { + tok_init.tok_eos_extra = eos_tokens.data() + 1; + tok_init.tok_eos_extra_count = (uint32_t)(eos_tokens.size() - 1); + } + + char error_buf[128]; + auto tok = llg_new_tokenizer_v2(&tok_init, error_buf, sizeof(error_buf)); + + if (tok == nullptr) { + printf("Error (v2): %s\n", error_buf); + exit(1); + } + + return tok; +} + // Create an LlgTokenizer; tokens[token_id] is a byte sequence corresponding to // given token_id; see below for tokenize_fn LlgTokenizer *create_tokenizer(std::vector> &tokens, uint32_t tok_eos, LlgTokenizeFn tokenize_fn, const void *tokenize_user_data) { - auto token_lens = new uint32_t[tokens.size()]; + std::vector token_lens(tokens.size()); size_t 
total_size = 0; for (size_t i = 0; i < tokens.size(); i++) { token_lens[i] = tokens[i].size(); total_size += token_lens[i]; } - auto token_bytes = new uint8_t[total_size]; + std::vector token_bytes(total_size); size_t offset = 0; for (size_t i = 0; i < tokens.size(); i++) { - memcpy(token_bytes + offset, tokens[i].data(), token_lens[i]); + std::copy(tokens[i].begin(), tokens[i].end(), token_bytes.data() + offset); offset += token_lens[i]; } LlgTokenizerInit tok_init = {}; tok_init.vocab_size = (uint32_t)tokens.size(); tok_init.tok_eos = tok_eos; - tok_init.token_lens = token_lens; - tok_init.token_bytes = token_bytes; + tok_init.token_lens = token_lens.data(); + tok_init.token_bytes = token_bytes.data(); tok_init.tokenize_assumes_string = false; tok_init.tokenize_user_data = tokenize_user_data; tok_init.tokenize_fn = tokenize_fn; @@ -63,8 +108,8 @@ size_t tokenize_callback(const void *user_data, const uint8_t *bytes, (void)user_data; auto tokens = bogus_tokenize(bytes, bytes_len); if (output_tokens_len > 0) { - memcpy(output_tokens, tokens.data(), - std::min(output_tokens_len, tokens.size()) * sizeof(uint32_t)); + auto n = std::min(output_tokens_len, tokens.size()); + std::copy(tokens.begin(), tokens.begin() + n, output_tokens); } return tokens.size(); } @@ -72,6 +117,7 @@ size_t tokenize_callback(const void *user_data, const uint8_t *bytes, // This creates a tokenizer that treats each byte as a token. LlgTokenizer *create_byte_tokenizer(void) { std::vector> tokens; + tokens.reserve(257); // 256 byte tokens + 1 EOS // every byte is a token for (size_t i = 0; i < 256; i++) { tokens.push_back({(uint8_t)i}); @@ -82,6 +128,23 @@ LlgTokenizer *create_byte_tokenizer(void) { nullptr); } +// Same as above but using the v2 API with an extra (unused) EOS token. 
+LlgTokenizer *create_byte_tokenizer_v2(void) { + std::vector> tokens; + tokens.reserve(258); // 256 byte tokens + 2 EOS + for (size_t i = 0; i < 256; i++) { + tokens.push_back({(uint8_t)i}); + } + const char *eos = ""; + tokens.push_back(std::vector(eos, eos + strlen(eos))); + const char *eos2 = ""; + tokens.push_back(std::vector(eos2, eos2 + strlen(eos2))); + // EOS tokens: token 256 () is primary, token 257 () is extra + std::vector eos_tokens = {(uint32_t)(tokens.size() - 2), + (uint32_t)(tokens.size() - 1)}; + return create_tokenizer_v2(tokens, eos_tokens, tokenize_callback, nullptr); +} + LlgTokenizer *create_hf_tokenizer(std::string tokenizer_json, uint32_t tok_eos) { LlgTokenizerInit tok_init = {}; @@ -141,21 +204,8 @@ std::string do_llg_stringify_tokens(const LlgTokenizer *tok, } } -int main(int argc, const char *argv[]) { - if (argc < 3) { - printf("Usage: %s [tokenizer.json]\n", - argv[0]); - return 1; - } - - // the tokenizer can (and should) be shared between constraints - LlgTokenizer *tokenizer = argc > 3 - ? 
create_hf_tokenizer(read_file(argv[3]), 2) - : create_byte_tokenizer(); - - auto schema_json = read_file(argv[1]); - auto sample_json = read_file(argv[2]); - +void run_constraint_test(LlgTokenizer *tokenizer, const std::string &schema_json, + const std::string &sample_json, const char *label) { LlgConstraintInit init; llg_constraint_init_set_defaults(&init, tokenizer); init.log_stderr_level = 0; // default to 1 (warnings only) @@ -167,14 +217,6 @@ int main(int argc, const char *argv[]) { fail_constraint(c); } - // for debugging the tokenizer: - // for (int i = 0; i < 320; ++i) { - // std::vector tokens; - // tokens.push_back(i); - // std::string s = do_llg_stringify_tokens(tokenizer, tokens); - // printf("Token %d: %s\n", i, s.c_str()); - // } - // we assume our "LLM" will generate these tokens auto tokens = do_llg_tokenize(tokenizer, sample_json); @@ -212,6 +254,35 @@ int main(int argc, const char *argv[]) { // we assume the constraint will force EOS at the end of the input assert(mask_res.is_stop); - printf("OK!\n"); + llg_free_constraint(c); + printf("%s: OK!\n", label); +} + +int main(int argc, const char *argv[]) { + if (argc < 3) { + printf("Usage: %s [tokenizer.json]\n", + argv[0]); + return 1; + } + + auto schema_json = read_file(argv[1]); + auto sample_json = read_file(argv[2]); + + // Test with v1 API (LlgTokenizerInit + llg_new_tokenizer) + { + LlgTokenizer *tokenizer = argc > 3 + ? 
create_hf_tokenizer(read_file(argv[3]), 2) + : create_byte_tokenizer(); + run_constraint_test(tokenizer, schema_json, sample_json, "v1"); + llg_free_tokenizer(tokenizer); + } + + // Test with v2 API (LlgTokenizerInitV2 + llg_new_tokenizer_v2) + { + LlgTokenizer *tokenizer = create_byte_tokenizer_v2(); + run_constraint_test(tokenizer, schema_json, sample_json, "v2"); + llg_free_tokenizer(tokenizer); + } + return 0; } diff --git a/parser/llguidance.h b/parser/llguidance.h index fa102c12..0d6c494c 100644 --- a/parser/llguidance.h +++ b/parser/llguidance.h @@ -186,6 +186,11 @@ typedef size_t (*LlgTokenizeFn)(const void *user_data, uint32_t *output_tokens, size_t output_tokens_len); +/** + * This struct must be zero-initialized (e.g., `= {}` in C/C++) before setting fields. + * New fields may be appended in future versions, and zero-initialization ensures + * they receive safe default values. + */ typedef struct LlgTokenizerInit { /** * The number of tokens in the vocabulary @@ -241,6 +246,87 @@ typedef struct LlgTokenizerInit { const char *const *slices; } LlgTokenizerInit; +/** + * V2 of the tokenizer initialization struct. + * Extends LlgTokenizerInit with support for multiple EOS tokens. + * Use with `llg_new_tokenizer_v2()`. + * + * Initialize with: `LlgTokenizerInitV2 init = {}; init.struct_size = sizeof(init);` + * The library only reads `struct_size` bytes from the pointer, so callers + * compiled against an older header (with a smaller struct) will work with + * newer library versions — any new fields default to zero. + */ +typedef struct LlgTokenizerInitV2 { + /** + * Must be set to `sizeof(LlgTokenizerInitV2)`. + * The library uses this to determine how many bytes to read, enabling + * forward compatibility when new fields are appended in future versions. 
+ */ + size_t struct_size; + /** + * The number of tokens in the vocabulary + */ + uint32_t vocab_size; + /** + * The token ID for the end of sentence token + * For chat mode, set it to end-of-turn token + */ + LlgToken tok_eos; + /** + * An array of the lengths of the token strings (vocab_size elements) + */ + const uint32_t *token_lens; + /** + * A pointer to the token strings + * The length of this is the sum of all token_lens + */ + const uint8_t *token_bytes; + /** + * Instead of passing token_lens and token_bytes, this can be set to + * the contents of HF tokenizer.json file. + */ + const char *tokenizer_json; + /** + * Set to true to enable hack that works around the tokenize_fn only + * accepting valid UTF-8 strings and possibly adding `` etc. + * TODO: the `` bit not implemented yet + */ + bool tokenize_assumes_string; + /** + * Tokenization function, see LlgTokenizeFn docs. + * It should only tokenize the bytes and not add + * any `` etc. It should also work on any byte sequence, including + * invalid UTF-8. If this is not the case, set tokenize_assumes_string to true. + * Either way, this function has to be thread-safe! + */ + LlgTokenizeFn tokenize_fn; + /** + * Set to true to not use tokenize_fn and instead tokenize greedily, + * which is often incorrect and may reduce accuracy. + */ + bool use_approximate_greedy_tokenize_fn; + /** + * User data to pass to the tokenize_fn + */ + const void *tokenize_user_data; + /** + * Tokenizer partitions for the slicer optimization. + * This is array of pointers to strings, terminated with NULL (argv style). + * Pass NULL to use defaults. Pass empty array to disable. + */ + const char *const *slices; + /** + * Additional EOS token IDs beyond `tok_eos`. + * Points to an array of `tok_eos_extra_count` elements. + * When NULL (the default for zero-initialized structs), only `tok_eos` is used. + */ + const LlgToken *tok_eos_extra; + /** + * Number of elements in the `tok_eos_extra` array. 
+ */ + uint32_t tok_eos_extra_count; +} LlgTokenizerInitV2; + #ifdef __cplusplus @@ -347,6 +433,25 @@ struct LlgTokenizer *llg_new_tokenizer(const struct LlgTokenizerInit *tok_init, char *error_string, size_t error_string_len); +/** + * Create a new tokenizer from a LlgTokenizerInitV2 struct. + * This is the v2 API that supports multiple EOS tokens. + * + * The `tok_init` pointer must be valid and `tok_init->struct_size` must be set + * to `sizeof(LlgTokenizerInitV2)` as known by the caller. The library will + * only read `struct_size` bytes, so callers compiled against an older (smaller) + * version of the struct will work with newer library versions — new fields + * default to zero. + * + * `tok_init` must point to at least `tok_init->struct_size` bytes of + * initialized memory, and `struct_size` must be at least + * `offsetof(LlgTokenizerInitV2, token_lens)` (i.e., include struct_size, + * vocab_size, and the complete tok_eos field). + */ +struct LlgTokenizer *llg_new_tokenizer_v2(const struct LlgTokenizerInitV2 *tok_init, + char *error_string, + size_t error_string_len); + /** * Clone a tokenizer. * This increments a reference count and does a small allocation. 
diff --git a/parser/src/ffi.rs b/parser/src/ffi.rs index c1483c9a..dc610d34 100644 --- a/parser/src/ffi.rs +++ b/parser/src/ffi.rs @@ -100,6 +100,10 @@ unsafe fn slice_from_ptr_or_empty<'a, T>(data: *const T, len: usize) -> &'a [T] impl LlgTokenizer { fn from_init(init: &LlgTokenizerInit) -> Result { + Self::from_init_v2(&LlgTokenizerInitV2::from_v1(init)) + } + + fn from_init_v2(init: &LlgTokenizerInitV2) -> Result { ensure!( init.tokenize_fn.is_some() || init.use_approximate_greedy_tokenize_fn, "Either tokenize_fn or use_approximate_greedy_tokenize_fn must be set" @@ -137,7 +141,26 @@ impl LlgTokenizer { token_bytes }; - let trie = TokTrie::from(&TokRxInfo::new(tokens.len() as u32, init.tok_eos), &tokens); + let mut trie = TokTrie::from(&TokRxInfo::new(tokens.len() as u32, init.tok_eos), &tokens); + + // Apply additional EOS tokens if provided + if !init.tok_eos_extra.is_null() && init.tok_eos_extra_count > 0 { + let extra = unsafe { + std::slice::from_raw_parts(init.tok_eos_extra, init.tok_eos_extra_count as usize) + }; + let mut eos_tokens = vec![init.tok_eos]; + eos_tokens.extend_from_slice(extra); + + let vocab_size = trie.vocab_size() as u32; + for &id in &eos_tokens { + ensure!( + id < vocab_size, + "EOS token ID {id} is out of range (vocab_size={vocab_size})" + ); + } + + trie = trie.with_eos_tokens(&eos_tokens); + } let tok_env: TokEnv = Arc::new(CTokenizerInner { trie, @@ -206,6 +229,9 @@ pub type LlgTokenizeFn = Option< /// Function which llg calls when an operation is done. pub type LlgCallback = Option; +/// This struct must be zero-initialized (e.g., `= {}` in C/C++) before setting fields. +/// New fields may be appended in future versions, and zero-initialization ensures +/// they receive safe default values. #[repr(C)] pub struct LlgTokenizerInit { /// The number of tokens in the vocabulary @@ -251,6 +277,92 @@ pub struct LlgTokenizerInit { pub slices: *const *const c_char, } +/// V2 of the tokenizer initialization struct. 
+/// Extends LlgTokenizerInit with support for multiple EOS tokens. +/// Use with `llg_new_tokenizer_v2()`. +/// +/// Initialize with: `LlgTokenizerInitV2 init = {}; init.struct_size = sizeof(init);` +/// The library only reads `struct_size` bytes from the pointer, so callers +/// compiled against an older header (with a smaller struct) will work with +/// newer library versions — any new fields default to zero. +#[repr(C)] +pub struct LlgTokenizerInitV2 { + /// Must be set to `sizeof(LlgTokenizerInitV2)`. + /// The library uses this to determine how many bytes to read, enabling + /// forward compatibility when new fields are appended in future versions. + pub struct_size: usize, + + /// The number of tokens in the vocabulary + pub vocab_size: u32, + + /// The token ID for the end of sentence token + /// For chat mode, set it to end-of-turn token + pub tok_eos: LlgToken, + + /// An array of the lengths of the token strings (vocab_size elements) + pub token_lens: *const u32, + + /// A pointer to the token strings + /// The length of this is the sum of all token_lens + pub token_bytes: *const u8, + + /// Instead of passing token_lens and token_bytes, this can be set to + /// the contents of HF tokenizer.json file. + pub tokenizer_json: *const c_char, + + /// Set to true to enable hack that works around the tokenize_fn only + /// accepting valid UTF-8 strings and possibly adding `` etc. + /// TODO: the `` bit not implemented yet + pub tokenize_assumes_string: bool, + + /// Tokenization function, see LlgTokenizeFn docs. + /// It should only tokenize the bytes and not add + /// any `` etc. It should also work on any byte sequence, including + /// invalid UTF-8. If this is not the case, set tokenize_assumes_string to true. + /// Either way, this function has to be thread-safe! + pub tokenize_fn: LlgTokenizeFn, + + /// Set to true to not use tokenize_fn and instead tokenize greedily, + /// which is often incorrect and may reduce accuracy. 
+ pub use_approximate_greedy_tokenize_fn: bool, + + /// User data to pass to the tokenize_fn + pub tokenize_user_data: *const c_void, + + /// Tokenizer partitions for the slicer optimization. + /// This is array of pointers to strings, terminated with NULL (argv style). + /// Pass NULL to use defaults. Pass empty array to disable. + pub slices: *const *const c_char, + + /// Additional EOS token IDs beyond `tok_eos`. + /// Points to an array of `tok_eos_extra_count` elements. + /// When NULL (the default for zero-initialized structs), only `tok_eos` is used. + pub tok_eos_extra: *const LlgToken, + + /// Number of elements in the `tok_eos_extra` array. + pub tok_eos_extra_count: u32, +} + +impl LlgTokenizerInitV2 { + fn from_v1(v1: &LlgTokenizerInit) -> Self { + LlgTokenizerInitV2 { + struct_size: std::mem::size_of::(), + vocab_size: v1.vocab_size, + tok_eos: v1.tok_eos, + token_lens: v1.token_lens, + token_bytes: v1.token_bytes, + tokenizer_json: v1.tokenizer_json, + tokenize_assumes_string: v1.tokenize_assumes_string, + tokenize_fn: v1.tokenize_fn, + use_approximate_greedy_tokenize_fn: v1.use_approximate_greedy_tokenize_fn, + tokenize_user_data: v1.tokenize_user_data, + slices: v1.slices, + tok_eos_extra: std::ptr::null(), + tok_eos_extra_count: 0, + } + } +} + #[derive(Clone)] #[repr(C)] pub struct LlgConstraintInit { @@ -669,6 +781,71 @@ pub unsafe extern "C" fn llg_new_tokenizer( } } +/// Create a new tokenizer from a LlgTokenizerInitV2 struct. +/// This is the v2 API that supports multiple EOS tokens. +/// +/// The `tok_init` pointer must be valid and `tok_init->struct_size` must be set +/// to `sizeof(LlgTokenizerInitV2)` as known by the caller. The library will +/// only read `struct_size` bytes, so callers compiled against an older (smaller) +/// version of the struct will work with newer library versions — new fields +/// default to zero. 
+/// +/// # Safety +/// `tok_init` must point to at least `tok_init->struct_size` bytes of +/// initialized memory, and `struct_size` must be at least +/// `offsetof(LlgTokenizerInitV2, token_lens)` (i.e., include struct_size, +/// vocab_size, and the complete tok_eos field). +#[no_mangle] +pub unsafe extern "C" fn llg_new_tokenizer_v2( + tok_init: *const LlgTokenizerInitV2, + error_string: *mut c_char, + error_string_len: usize, +) -> *mut LlgTokenizer { + if tok_init.is_null() { + save_error_string( + anyhow::anyhow!("tok_init is NULL"), + error_string, + error_string_len, + ); + return std::ptr::null_mut(); + } + + // Read struct_size from the first field (always safe if pointer is valid) + let struct_size = unsafe { std::ptr::read(tok_init as *const usize) }; + let min_size = std::mem::offset_of!(LlgTokenizerInitV2, token_lens); + if struct_size < min_size { + save_error_string( + anyhow::anyhow!( + "LlgTokenizerInitV2.struct_size is {struct_size} but expected at least {min_size}. \ + Set struct_size = sizeof(LlgTokenizerInitV2)." + ), + error_string, + error_string_len, + ); + return std::ptr::null_mut(); + } + + // Copy the caller's data into a zero-initialized local struct. + // Fields beyond what the caller provides default to zero. + let mut local: LlgTokenizerInitV2 = unsafe { std::mem::zeroed() }; + let copy_size = std::cmp::min(struct_size, std::mem::size_of::()); + unsafe { + std::ptr::copy_nonoverlapping( + tok_init as *const u8, + &mut local as *mut LlgTokenizerInitV2 as *mut u8, + copy_size, + ); + } + + match LlgTokenizer::from_init_v2(&local) { + Ok(tok) => Box::into_raw(Box::new(tok)), + Err(e) => { + save_error_string(e, error_string, error_string_len); + std::ptr::null_mut() + } + } +} + /// Clone a tokenizer. /// This increments a reference count and does a small allocation. 
#[no_mangle] diff --git a/parser/src/matcher.rs b/parser/src/matcher.rs index b11660ee..f49f45ad 100644 --- a/parser/src/matcher.rs +++ b/parser/src/matcher.rs @@ -109,8 +109,7 @@ impl Matcher { pub fn compute_mask_or_eos(&mut self) -> Result { self.with_inner(|inner| { if inner.parser.stop_reason() != StopReason::NotStopped { - let trie = inner.parser.token_env.tok_trie(); - Ok(trie.singleton_token_set(trie.eos_token())) + Ok(inner.parser.token_env.tok_trie().eos_token_set()) } else { inner.parser.compute_mask() } diff --git a/parser/src/tokenparser.rs b/parser/src/tokenparser.rs index 28deeca9..6e47113d 100644 --- a/parser/src/tokenparser.rs +++ b/parser/src/tokenparser.rs @@ -21,7 +21,7 @@ pub struct TokenParser { pub dbg_grammar: String, last_step_stats: ParserStats, max_step_stats: ParserStats, - eos_token: TokenId, + eos_tokens: Vec, had_rollback: bool, had_backtrack: bool, @@ -91,7 +91,7 @@ impl TokenParser { factory.perf_counters(), )?; parser.metrics_mut().rand = factory.next_rng(); - let eos_token = token_env.tok_trie().eos_token(); + let eos_tokens = token_env.tok_trie().eos_tokens().to_vec(); Ok(TokenParser { bias_computer: factory.slicer().clone(), @@ -108,7 +108,7 @@ impl TokenParser { error_message: None, parser, dbg_grammar: String::new(), - eos_token, + eos_tokens, llm_tokens: Vec::new(), llm_bytes: Vec::new(), grm_prefix: Vec::new(), @@ -393,7 +393,7 @@ impl TokenParser { let new_len = self.llm_tokens.len() - n_tokens; let mut bytes_to_drop = 0; for tok in &self.llm_tokens[new_len..] { - if *tok == self.eos_token { + if self.eos_tokens.contains(tok) { // doesn't count; we hope it's last though... 
bytes_to_drop += 0; } else { @@ -496,8 +496,12 @@ impl TokenParser { return Err(self.stop_for_parser_error("", s)); } - if self.eos_token != INVALID_TOKEN && self.is_accepting() { - allowed_tokens.allow_token(self.eos_token); + if self.is_accepting() { + for &eos in &self.eos_tokens { + if eos != INVALID_TOKEN { + allowed_tokens.allow_token(eos); + } + } } self.log_final(&prefix, &allowed_tokens); @@ -801,7 +805,7 @@ impl TokenParser { } self.max_tokens_total -= 1; - if token == self.eos_token { + if self.eos_tokens.contains(&token) { if self.parser.scan_eos() { // it got scanned correctly, so we remove it // this only happens for gen() terminated by EOS @@ -842,7 +846,10 @@ impl TokenParser { /// This generally should be called after consume_token(). pub fn check_stop(&mut self) -> Result { let empty_token_prefix = !self.has_ff_bytes(); - let pending_eos = self.llm_tokens.last() == Some(&self.eos_token); + let pending_eos = self + .llm_tokens + .last() + .is_some_and(|t| self.eos_tokens.contains(t)); let lexer_bytes = self.parser.has_pending_lexeme_bytes(); let is_accepting = self.is_accepting(); let can_advance = self.parser.can_advance(); diff --git a/python/llguidance/_lib.pyi b/python/llguidance/_lib.pyi index c2156c11..be09f478 100644 --- a/python/llguidance/_lib.pyi +++ b/python/llguidance/_lib.pyi @@ -6,13 +6,14 @@ from ._tokenizer import TokenizerWrapper class LLTokenizer: vocab_size: int eos_token: TokenId + eos_tokens: List[TokenId] is_canonical: bool def __new__( cls, tokenizer: Union[str, TokenizerWrapper], n_vocab: Optional[int] = None, - eos_token: Optional[TokenId] = None, + eos_token: Optional[Union[TokenId, List[TokenId]]] = None, slices: Optional[List[str]] = None, ) -> "LLTokenizer": """ @@ -23,6 +24,7 @@ class LLTokenizer: Args: tokenizer: str or TokenizerWrapper - if str, it is the name or path to the HF tokenizers tokenizer; otherwise it is a TokenizerWrapper n_vocab: int - override the size of the vocabulary + eos_token: int or list of ints 
- override the EOS token(s) slices: List[str] - configuration for slicer optimization; pass [] to disable, or None to use general_slices() """ diff --git a/python/llguidance/hf.py b/python/llguidance/hf.py index 11da7d45..05c78c03 100644 --- a/python/llguidance/hf.py +++ b/python/llguidance/hf.py @@ -1,5 +1,5 @@ from copy import copy -from typing import List, Optional +from typing import List, Optional, Union import transformers @@ -9,7 +9,7 @@ def from_tokenizer( hf_tokenizer: transformers.PreTrainedTokenizerFast, n_vocab: Optional[int] = None, - eos_token: Optional[int] = None, + eos_token: Optional[Union[int, List[int]]] = None, slices: Optional[List[str]] = None, ) -> LLTokenizer: """ @@ -21,7 +21,7 @@ def from_tokenizer( Args: hf_tokenizer: transformers.PreTrainedTokenizerFast - the tokenizer to wrap n_vocab: int - override the size of the vocabulary - eos_token: int - override the EOS token + eos_token: int or list of ints - override the EOS token(s) slices: List[str] - configuration for slicer optimization; pass [] to disable, or None to use the default configuration """ diff --git a/python/llguidance/llamacpp.py b/python/llguidance/llamacpp.py index 7d495d5d..5e2cedc3 100644 --- a/python/llguidance/llamacpp.py +++ b/python/llguidance/llamacpp.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from ._lib import LLTokenizer @@ -8,7 +8,7 @@ def lltokenizer_from_vocab( vocab: llama_cpp.llama_vocab_p, n_vocab: Optional[int] = None, - eos_token: Optional[int] = None, + eos_token: Optional[Union[int, List[int]]] = None, slices: Optional[List[str]] = None, ) -> LLTokenizer: """ @@ -18,7 +18,8 @@ def lltokenizer_from_vocab( Args: vocab: llama_cpp.llama_vocab_p - the vocab object to use n_vocab: int - override the size of the vocabulary - eos_token: int - override the EOS token + eos_token: int or list of ints - override the EOS token(s) + slices: List[str] - configuration for slicer optimization; pass [] to disable, or 
None to use the default configuration """ diff --git a/python/llguidance/tiktoken.py b/python/llguidance/tiktoken.py index 4e4c6668..dced8f97 100644 --- a/python/llguidance/tiktoken.py +++ b/python/llguidance/tiktoken.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING from ._lib import LLTokenizer @@ -10,7 +10,7 @@ def lltokenizer_from_encoding( encoding: 'tiktoken.Encoding', *, n_vocab: Optional[int] = None, - eos_token: Optional[int] = None, + eos_token: Optional[Union[int, List[int]]] = None, slices: Optional[List[str]] = None, ) -> LLTokenizer: """ @@ -20,7 +20,7 @@ def lltokenizer_from_encoding( Args: encoding: tiktoken.Encoding - the encoding object to use n_vocab: int - override the size of the vocabulary - eos_token: int - override the EOS token + eos_token: int or list of ints - override the EOS token(s) slices: List[str] - configuration for slicer optimization; pass [] to disable, or None to use the default configuration """ diff --git a/python/torch_tests/test_matcher.py b/python/torch_tests/test_matcher.py index ede9fc50..8a8194fe 100644 --- a/python/torch_tests/test_matcher.py +++ b/python/torch_tests/test_matcher.py @@ -1,16 +1,16 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple + import llguidance +import numpy as np +import pytest +from llguidance import LLMatcher, LLParserLimits, LLTokenizer, StructTag from llguidance.numpy import ( - fill_next_token_bitmask_par, - fill_next_token_bitmask_par_with_draft_tokens, allocate_token_bitmask, consume_token_par, + fill_next_token_bitmask_par, + fill_next_token_bitmask_par_with_draft_tokens, ) - -from llguidance import LLMatcher, LLTokenizer, StructTag, LLParserLimits -import pytest from numpy.typing import NDArray -import numpy as np _tokenizer = None @@ -685,3 +685,45 @@ def test_get_capture() -> None: assert m.get_capture("body") == b"1234" assert m.get_capture("non-existent-group") is 
None assert m.get_captures() == [("group1", b"w"), ("body", b"orld"), ("body", b"abcd"), ("body", b"1234")] + + +def test_multi_eos_tokens_property() -> None: + """Test that eos_tokens returns the expected list for a single-EOS tokenizer.""" + tok = LLTokenizer("byte") + assert tok.eos_tokens == [tok.eos_token] + assert len(tok.eos_tokens) == 1 + + + +def test_multi_eos_wrapper_override() -> None: + """Test that eos_token override works with TokenizerWrapper path.""" + + class MockTokenizerWrapper: + """Minimal mock that satisfies the TokenizerWrapper interface for testing.""" + + def __init__(self, tokens: List[bytes], eos_token_id: int): + self.tokens = tokens + self.eos_token_id = eos_token_id + self.bos_token_id = None + self.special_token_ids: List[int] = [] + self.is_tokenizer_wrapper = True + + def __call__(self, s: str) -> List[int]: + return [b for b in s.encode("utf-8")] + + # Create a minimal byte-level tokenizer with 258 tokens: + # tokens 0-255 are single bytes, 256 is , 257 is + tokens = [bytes([i]) for i in range(256)] + tokens.append(b"") + tokens.append(b"") + wrapper = MockTokenizerWrapper(tokens, eos_token_id=256) + + # Without override: single EOS + tok1 = LLTokenizer(wrapper) # type: ignore[arg-type] + assert tok1.eos_token == 256 + assert tok1.eos_tokens == [256] + + # With override: multiple EOS + tok2 = LLTokenizer(wrapper, eos_token=[256, 257]) # type: ignore[arg-type] + assert tok2.eos_token == 256 + assert tok2.eos_tokens == [256, 257] diff --git a/python_ext/src/llamatokenizer.rs b/python_ext/src/llamatokenizer.rs index 15e92445..7af93931 100644 --- a/python_ext/src/llamatokenizer.rs +++ b/python_ext/src/llamatokenizer.rs @@ -119,13 +119,25 @@ pub fn tokenv_from_llamacpp( tokens: Vec>, vocab_ptr: usize, tokenize_fptr: usize, - eos_token: u32, + eos_tokens: &[u32], ) -> Result { + ensure!(!eos_tokens.is_empty(), "eos_tokens must not be empty"); ensure!(vocab_ptr != 0, "vocab_ptr must be non-null"); ensure!(tokenize_fptr != 0, 
"tokenize_fptr must be non-null"); - let info = TokRxInfo::new(tokens.len() as u32, eos_token); - let trie = TokTrie::from(&info, &tokens); + let vocab_size = tokens.len() as u32; + for &id in eos_tokens { + ensure!( + id < vocab_size, + "EOS token ID {id} is out of range (vocab_size={vocab_size})" + ); + } + + let info = TokRxInfo::new(tokens.len() as u32, eos_tokens[0]); + let mut trie = TokTrie::from(&info, &tokens); + if eos_tokens.len() > 1 { + trie = trie.with_eos_tokens(eos_tokens); + } let mut llama_tok = LlamaTokenizer { trie, diff --git a/python_ext/src/llmatcher.rs b/python_ext/src/llmatcher.rs index e4c17a87..636c89e5 100644 --- a/python_ext/src/llmatcher.rs +++ b/python_ext/src/llmatcher.rs @@ -265,8 +265,7 @@ impl LLMatcher { } fn eos_token_set(&self) -> SimpleVob { - let trie = self.tok_env.tok_trie(); - trie.singleton_token_set(trie.eos_token()) + self.tok_env.tok_trie().eos_token_set() } fn compute_mask_or_eos(&mut self) -> SimpleVob { @@ -280,7 +279,13 @@ impl LLMatcher { } fn consume_token_inner(&mut self, sampled_token: TokenId) -> bool { - if self.inner.is_stopped() && sampled_token == self.tok_env.tok_trie().eos_token() { + if self.inner.is_stopped() + && self + .tok_env + .tok_trie() + .eos_tokens() + .contains(&sampled_token) + { true } else { self.inner.consume_token(sampled_token).is_ok() diff --git a/python_ext/src/py.rs b/python_ext/src/py.rs index ec595432..337fa8ef 100644 --- a/python_ext/src/py.rs +++ b/python_ext/src/py.rs @@ -16,6 +16,35 @@ use toktrie_tiktoken::TikTokenBPE; use crate::llamatokenizer::tokenv_from_llamacpp; +/// Extract EOS tokens from a Python value that must be an int or a non-empty list[int]. +/// Returns a Vec on success, or raises PyValueError if the value is invalid or the list is empty. 
+fn extract_eos_tokens(obj: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(single) = obj.extract::() { + Ok(vec![single]) + } else if let Ok(list) = obj.extract::>() { + if list.is_empty() { + return Err(PyValueError::new_err("eos_token list must not be empty")); + } + Ok(list) + } else { + Err(PyValueError::new_err( + "eos_token must be an int or a non-empty list of ints", + )) + } +} + +/// Validate that all EOS token IDs are within vocab range. +fn validate_eos_tokens(eos_tokens: &[u32], vocab_size: u32) -> PyResult<()> { + for &id in eos_tokens { + if id >= vocab_size { + return Err(PyValueError::new_err(format!( + "EOS token ID {id} is out of range (vocab_size={vocab_size})" + ))); + } + } + Ok(()) +} + struct PyTokenizer { tok_trie: Arc, tokenizer_fun: Py, @@ -36,9 +65,10 @@ impl LLTokenizer { fn py_new( tokenizer: Bound<'_, PyAny>, n_vocab: Option, - eos_token: Option, + eos_token: Option>, slices: Option>, ) -> PyResult { + let eos_tokens = eos_token.as_ref().map(extract_eos_tokens).transpose()?; let tok_env: TokEnv = if let Ok(tokenizer_str) = tokenizer.extract::() { if tokenizer_str == "byte" { ApproximateTokEnv::single_byte_env() @@ -48,13 +78,19 @@ impl LLTokenizer { } else { ByteTokenizer::from_file(&tokenizer_str).map_err(val_error)? }; - if let Some(eos_token) = eos_token { - tok.set_eos_token(eos_token); + if let Some(ref eos_tokens) = eos_tokens { + validate_eos_tokens(eos_tokens, tok.tokrx_info().vocab_size)?; + tok.set_eos_tokens(eos_tokens); } tok.into_tok_env(n_vocab).map_err(val_error)? } } else { - Arc::new(PyTokenizer::py_new(tokenizer)?) 
+ let mut py_tok = PyTokenizer::py_new(tokenizer)?; + if let Some(ref eos_tokens) = eos_tokens { + validate_eos_tokens(eos_tokens, py_tok.tok_trie.vocab_size() as u32)?; + py_tok.tok_trie = Arc::new(py_tok.tok_trie.with_eos_tokens(eos_tokens)); + } + Arc::new(py_tok) }; let factory = ParserFactory::new( &tok_env, @@ -77,18 +113,23 @@ impl LLTokenizer { encoder: HashMap, u32>, special_tokens: HashMap, pattern: &str, - eos_token: u32, + eos_token: Bound<'_, PyAny>, n_vocab: Option, slices: Option>, ) -> PyResult { - let bpe = TikTokenBPE::new( + let eos_tokens = extract_eos_tokens(&eos_token)?; + let mut bpe = TikTokenBPE::new( encoder.into_iter().collect(), special_tokens.into_iter().collect(), pattern, n_vocab, - eos_token, + eos_tokens[0], ) .map_err(val_error)?; + validate_eos_tokens(&eos_tokens, bpe.tokrx_info().vocab_size)?; + if eos_tokens.len() > 1 { + bpe.set_eos_tokens(&eos_tokens); + } let tok_env = bpe.to_env(); let factory = ParserFactory::new( @@ -108,11 +149,12 @@ impl LLTokenizer { tokens: Vec>, vocab_ptr: usize, tokenize_fptr: usize, - eos_token: u32, + eos_token: Bound<'_, PyAny>, slices: Option>, ) -> PyResult { - let tok_env = - tokenv_from_llamacpp(tokens, vocab_ptr, tokenize_fptr, eos_token).map_err(val_error)?; + let eos_tokens = extract_eos_tokens(&eos_token)?; + let tok_env = tokenv_from_llamacpp(tokens, vocab_ptr, tokenize_fptr, &eos_tokens) + .map_err(val_error)?; let factory = ParserFactory::new( &tok_env, @@ -244,6 +286,11 @@ impl LLTokenizer { fn eos_token(&self) -> u32 { self.tok_trie().eos_token() } + + #[getter] + fn eos_tokens(&self) -> Vec { + self.tok_trie().eos_tokens().to_vec() + } } impl LLTokenizer { diff --git a/sample_parser/tests/test_raw_parser.rs b/sample_parser/tests/test_raw_parser.rs index f67dceff..b4a39332 100644 --- a/sample_parser/tests/test_raw_parser.rs +++ b/sample_parser/tests/test_raw_parser.rs @@ -2,10 +2,11 @@ use lazy_static::lazy_static; use llguidance::{ api::TopLevelGrammar, earley::SlicedBiasComputer, - 
toktrie::{InferenceCapabilities, TokEnv}, + toktrie::{ApproximateTokEnv, InferenceCapabilities, TokEnv, TokenizerEnv}, Matcher, ParserFactory, TokenParser, }; use serde_json::{json, Value}; +use std::sync::Arc; lazy_static! { static ref PARSER_FACTORY_PHI: ParserFactory = { @@ -355,3 +356,45 @@ fn test_try_consume_eos_consistency() { assert!(eos_consumed <= 1); assert_eq!(n_consumed_no_eos + eos_consumed, n_consumed_all); } + +#[test] +fn test_multi_eos_mask_when_stopped() { + // Build a byte-level tokenizer with two EOS tokens + let base = ApproximateTokEnv::single_byte(); + let base_trie = base.tok_trie(); + let primary_eos = base_trie.eos_token(); + // Pick a special token as the second EOS + let extra_eos = primary_eos - 1; + let multi_trie = base_trie.clone().with_eos_tokens(&[primary_eos, extra_eos]); + let tok_env: TokEnv = Arc::new(ApproximateTokEnv::new(multi_trie)); + + let factory = ParserFactory::new( + &tok_env, + InferenceCapabilities::default(), + &SlicedBiasComputer::general_slices(), + ) + .unwrap(); + + let grm = TopLevelGrammar::from_lark(r#"start: "a""#.to_string()); + let mut parser = factory.create_parser(grm).unwrap(); + parser.start_without_prompt(); + let mut matcher = Matcher::new(Ok(parser)); + + // Consume "a" — grammar should accept + let mask = matcher.compute_mask().unwrap(); + assert!(mask.is_allowed(b'a' as u32)); + matcher.consume_token(b'a' as u32).unwrap(); + + // Parser stops after accepting the full input. + // compute_mask_or_eos should include BOTH EOS tokens. 
+ let mask = matcher.compute_mask_or_eos().unwrap(); + assert!( + mask.is_allowed(primary_eos), + "primary EOS should be in stopped mask" + ); + assert!( + mask.is_allowed(extra_eos), + "extra EOS should be in stopped mask" + ); + assert!(matcher.is_stopped()); +} diff --git a/toktrie/src/toktree.rs b/toktrie/src/toktree.rs index 5c542e39..6f04347d 100644 --- a/toktrie/src/toktree.rs +++ b/toktrie/src/toktree.rs @@ -101,6 +101,7 @@ pub struct TokTrie { token_data: Vec, nodes: Vec, max_token_len: usize, + eos_tokens: Vec, } #[derive(Clone, Copy, Zeroable, Pod)] @@ -194,6 +195,7 @@ impl TokTrie { token_data, nodes, max_token_len, + eos_tokens: vec![info.tok_eos], }; r.validate(); r @@ -209,19 +211,34 @@ impl TokTrie { }; words.push(b.to_vec()); } - Self::from(self.info(), &words) + let mut r = Self::from(self.info(), &words); + r.eos_tokens = self.eos_tokens.clone(); + r } pub fn with_eos_token(&self, eos_token: TokenId) -> Self { - self.with_info(TokRxInfo { - tok_eos: eos_token, - ..self.info - }) + self.with_eos_tokens(&[eos_token]) + } + + pub fn with_eos_tokens(&self, eos_tokens: &[TokenId]) -> Self { + assert!(!eos_tokens.is_empty(), "eos_tokens must not be empty"); + let vocab = self.vocab_size() as u32; + for &tok in eos_tokens { + assert!( + tok < vocab, + "EOS token ID {tok} is out of range (vocab_size={vocab})" + ); + } + let mut r = self.clone(); + r.info.tok_eos = eos_tokens[0]; + r.eos_tokens = eos_tokens.to_vec(); + r } pub fn with_info(&self, info: TokRxInfo) -> Self { let mut r = self.clone(); r.info = info; + r.eos_tokens = vec![info.tok_eos]; r } @@ -248,6 +265,10 @@ impl TokTrie { self.info.tok_eos } + pub fn eos_tokens(&self) -> &[TokenId] { + &self.eos_tokens + } + pub fn vocab_size(&self) -> usize { self.info.vocab_size as usize } @@ -262,6 +283,18 @@ impl TokTrie { r } + /// Returns a token set containing all EOS tokens. 
+ pub fn eos_token_set(&self) -> SimpleVob { + let mut r = self.alloc_token_set(); + let vocab = self.vocab_size() as u32; + for &eos in self.eos_tokens() { + if eos != INVALID_TOKEN && eos < vocab { + r.allow_token(eos); + } + } + r + } + pub fn token_set_dbg(&self, ts: &SimpleVob) -> String { let max_examples = 50; @@ -1189,3 +1222,82 @@ impl Recognizer for AnythingGoes { true } } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_test_trie(eos: TokenId) -> TokTrie { + let info = TokRxInfo::new(4, eos); + let words = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()]; + TokTrie::from(&info, &words) + } + + #[test] + fn test_default_single_eos() { + let trie = make_test_trie(2); + assert_eq!(trie.eos_token(), 2); + assert_eq!(trie.eos_tokens(), &[2]); + } + + #[test] + fn test_with_eos_tokens_multiple() { + let trie = make_test_trie(0).with_eos_tokens(&[1, 3]); + assert_eq!(trie.eos_token(), 1); + assert_eq!(trie.eos_tokens(), &[1, 3]); + assert_eq!(trie.info().tok_eos, 1); + } + + #[test] + fn test_with_eos_token_backwards_compat() { + let trie = make_test_trie(0).with_eos_token(2); + assert_eq!(trie.eos_token(), 2); + assert_eq!(trie.eos_tokens(), &[2]); + } + + #[test] + fn test_with_info_resets_eos_tokens() { + let trie = make_test_trie(0).with_eos_tokens(&[1, 2]); + let trie2 = trie.with_info(TokRxInfo::new(4, 3)); + assert_eq!(trie2.eos_token(), 3); + assert_eq!(trie2.eos_tokens(), &[3]); + } + + #[test] + fn test_filter_preserves_eos_tokens() { + let trie = make_test_trie(0).with_eos_tokens(&[1, 2]); + let mut filter = trie.alloc_token_set(); + for i in 0..4 { + filter.allow_token(i); + } + let filtered = trie.filter(&filter); + assert_eq!(filtered.eos_tokens(), &[1, 2]); + } + + #[test] + #[should_panic(expected = "eos_tokens must not be empty")] + fn test_with_eos_tokens_empty_panics() { + make_test_trie(0).with_eos_tokens(&[]); + } + + #[test] + fn test_eos_token_set_single() { + let trie = make_test_trie(2); + let set = 
trie.eos_token_set(); + assert!(set.is_allowed(2)); + assert!(!set.is_allowed(0)); + assert!(!set.is_allowed(1)); + assert_eq!(set.num_set(), 1); + } + + #[test] + fn test_eos_token_set_multiple() { + let trie = make_test_trie(0).with_eos_tokens(&[1, 3]); + let set = trie.eos_token_set(); + assert!(set.is_allowed(1)); + assert!(set.is_allowed(3)); + assert!(!set.is_allowed(0)); + assert!(!set.is_allowed(2)); + assert_eq!(set.num_set(), 2); + } +} diff --git a/toktrie_hf_tokenizers/src/lib.rs b/toktrie_hf_tokenizers/src/lib.rs index f77f19cf..29d126ad 100644 --- a/toktrie_hf_tokenizers/src/lib.rs +++ b/toktrie_hf_tokenizers/src/lib.rs @@ -12,6 +12,7 @@ pub struct ByteTokenizer { pub hf_tokenizer: Tokenizer, info: TokRxInfo, token_bytes: Vec>, + eos_tokens_extra: Vec, } // useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi @@ -148,6 +149,7 @@ impl ByteTokenizer { info: TokRxInfo::new(vocab_size, 0), token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(), hf_tokenizer: hft, + eos_tokens_extra: Vec::new(), }; let mut specials = HashSet::new(); @@ -230,7 +232,32 @@ impl ByteTokenizer { } pub fn set_eos_token(&mut self, tok_id: u32) { + assert!( + tok_id < self.info.vocab_size, + "EOS token ID {tok_id} is out of range (vocab_size={})", + self.info.vocab_size + ); self.info.tok_eos = tok_id; + self.eos_tokens_extra.clear(); + } + + pub fn set_eos_tokens(&mut self, tokens: &[TokenId]) { + assert!(!tokens.is_empty(), "eos_tokens must not be empty"); + for &tok in tokens { + assert!( + tok < self.info.vocab_size, + "EOS token ID {tok} is out of range (vocab_size={})", + self.info.vocab_size + ); + } + self.info.tok_eos = tokens[0]; + self.eos_tokens_extra = tokens[1..].to_vec(); + } + + pub fn eos_tokens(&self) -> Vec { + let mut r = vec![self.info.tok_eos]; + r.extend_from_slice(&self.eos_tokens_extra); + r } pub fn into_tok_env(self, n_vocab: Option) -> Result { @@ -259,7 +286,11 @@ impl ByteTokenizerEnv { } info.vocab_size = n_vocab as u32; } - 
let tok_trie = TokTrie::from(&info, &token_bytes); + let eos_tokens = tokenizer.eos_tokens(); + let mut tok_trie = TokTrie::from(&info, &token_bytes); + if eos_tokens.len() > 1 { + tok_trie = tok_trie.with_eos_tokens(&eos_tokens); + } Ok(ByteTokenizerEnv { tokenizer, tok_trie, @@ -352,6 +383,7 @@ mod tests { hf_tokenizer, info, token_bytes, + eos_tokens_extra: Vec::new(), }; let env = ByteTokenizerEnv::new(tokenizer, None).unwrap(); let special_id = env.tok_trie().get_special_token("<|end|>").unwrap(); diff --git a/toktrie_tiktoken/src/lib.rs b/toktrie_tiktoken/src/lib.rs index 27d826e9..e76d57cc 100644 --- a/toktrie_tiktoken/src/lib.rs +++ b/toktrie_tiktoken/src/lib.rs @@ -79,6 +79,10 @@ impl TikTokenBPE { *self.tok_trie.info() } + pub fn set_eos_tokens(&mut self, tokens: &[TokenId]) { + self.tok_trie = self.tok_trie.with_eos_tokens(tokens); + } + pub fn to_env(self) -> TokEnv { Arc::new(self) }