diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi index 488dc7708..c9876092e 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi +++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi @@ -4,7 +4,7 @@ class DecodeStream: Class needed for streaming decode """ - def __init__(self, skip_special_tokens): + def __init__(self, ids=None, skip_special_tokens=False): pass class Decoder: diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index f09ac592f..4b9367ef4 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -646,21 +646,44 @@ pub struct PyDecodeStream { prefix_index: usize, } +#[derive(Clone)] +enum StreamInput { + Id(u32), + Ids(Vec<u32>), +} + +impl FromPyObject<'_> for StreamInput { + fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> { + if let Ok(id) = obj.extract::<u32>() { + Ok(StreamInput::Id(id)) + } else if let Ok(ids) = obj.extract::<Vec<u32>>() { + Ok(StreamInput::Ids(ids)) + } else { + Err(PyErr::new::<exceptions::PyTypeError, _>( + "StreamInput must be either an integer or a list of integers", + )) + } + } +} + #[pymethods] impl PyDecodeStream { #[new] - #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")] - fn new(skip_special_tokens: bool) -> Self { + #[pyo3(signature = (ids=None, skip_special_tokens=false), text_signature = "(self, ids=None, skip_special_tokens=False)")] + fn new(ids: Option<Vec<u32>>, skip_special_tokens: Option<bool>) -> Self { PyDecodeStream { - skip_special_tokens, - ids: vec![], - prefix: "".to_string(), + skip_special_tokens: skip_special_tokens.unwrap_or(false), + ids: ids.unwrap_or_default(), + prefix: String::new(), prefix_index: 0, } } - #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")] - fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> { + fn step(&mut self, tokenizer: &PyTokenizer, id: StreamInput) -> PyResult<Option<String>> { + let id: Vec<u32> = match id { 
+ StreamInput::Id(id) => vec![id], + StreamInput::Ids(ids) => ids, + }; ToPyResult(tk::tokenizer::step_decode_stream( &tokenizer.tokenizer, id, diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 9aea79477..03fa6bdf7 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -341,7 +341,7 @@ impl PyBertProcessing { } #[getter] - fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> { + fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> { let py = self_.py(); let (tok, id) = getter!(self_, Bert, get_sep_copy()); PyTuple::new( @@ -358,7 +358,7 @@ impl PyBertProcessing { } #[getter] - fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> { + fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> { let py = self_.py(); let (tok, id) = getter!(self_, Bert, get_cls_copy()); PyTuple::new( @@ -422,7 +422,7 @@ impl PyRobertaProcessing { } #[getter] - fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> { + fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> { let py = self_.py(); let (tok, id) = getter!(self_, Roberta, get_sep_copy()); PyTuple::new( @@ -439,7 +439,7 @@ impl PyRobertaProcessing { } #[getter] - fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> { + fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> { let py = self_.py(); let (tok, id) = getter!(self_, Roberta, get_cls_copy()); PyTuple::new( diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index d50f283e7..98ce3ac5e 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -371,6 +371,110 @@ def test_decode(self): assert stream.step(tokenizer, 2) == " is" assert stream.step(tokenizer, 3) == " john" + stream = DecodeStream(ids=[0, 1, 2]) + assert stream.step(tokenizer, 3) == " john" + + def test_decode_stream_fallback(self): + tokenizer = Tokenizer.from_pretrained("gpt2") + # tokenizer.decode([255]) fails because its a fallback + # 
tokenizer.encode("อั").ids = [19567, 255, 19567, 109] + stream = DecodeStream() + stream.step(tokenizer, [19567]) + stream.step(tokenizer, [255]) + stream.step(tokenizer, [19567]) + out = stream.step(tokenizer, [109]) + assert out == "ั" + + stream = DecodeStream() + out = stream.step(tokenizer, [19567, 255, 19567, 109]) + assert out == "อั" + stream = DecodeStream() + stream.step(tokenizer, [19567]) + out = stream.step(tokenizer, [255, 19567, 109]) + assert out == "อั" + + stream = DecodeStream() + stream.step(tokenizer, [19567]) + first_out = stream.step(tokenizer, [255]) + assert first_out == "อ" + # since we emitted the 'อ', we can't produce 'อั' + out = stream.step(tokenizer, [19567, 109]) + assert out == "ั" + + stream = DecodeStream([19567, 255, 19567]) + # the stream's prefix is 'อ�' which is invalid, thus all ids are kept for the next step + out = stream.step(tokenizer, [109]) + assert out == "อั" + + def test_decode_skip_special_tokens(self): + tokenizer = Tokenizer.from_pretrained("hf-internal-testing/Llama-3.1-8B-Instruct") + + stream = DecodeStream([40]) + out = stream.step(tokenizer, [2846, 40, 40, 40]) + assert out == "'mIII" + + stream = DecodeStream( + [ + 128000, + 128006, + 9125, + 128007, + 271, + 38766, + 1303, + 33025, + 2696, + 25, + 6790, + 220, + 2366, + 18, + 198, + 15724, + 2696, + 25, + 220, + 1627, + 10263, + 220, + 2366, + 19, + 271, + 9514, + 527, + 264, + 11190, + 18328, + 13, + 128009, + 128006, + 882, + 128007, + 271, + 15339, + 11, + 1268, + 527, + 499, + 30, + 128009, + 128006, + 78191, + 128007, + 271, + ] + ) + out = stream.step(tokenizer, 40) + assert out == "I" + + stream = DecodeStream([40]) + out = stream.step(tokenizer, 2846) + assert out == "'m" + + stream = DecodeStream([40]) + out = stream.step(tokenizer, [2846, 40, 40, 40]) + assert out == "'mIII" + def test_decode_stream(self): vocab = [ ("<unk>", 0.0), diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 267cf8b57..955ac7b5c 
100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -356,7 +356,7 @@ impl Unigram { } /// Iterate of vocabulary of the model as a pair of `(token, score)`. - pub fn iter(&self) -> UnigramIterator { + pub fn iter(&self) -> UnigramIterator<'_> { UnigramIterator { model: self, i: 0 } } diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs index dd06f7f02..7c7149d00 100644 --- a/tokenizers/src/models/unigram/trie.rs +++ b/tokenizers/src/models/unigram/trie.rs @@ -30,7 +30,7 @@ impl Trie