diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index a5a7ec377..2ac1e49ec 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -685,6 +685,19 @@ impl PyDecodeStream {
         ))
         .into()
     }
+    /// Support `copy.copy`: return a clone of this stream.
+    fn __copy__(&self) -> Self {
+        self.clone()
+    }
+
+    /// Support `copy.deepcopy`: return a clone of this stream.
+    ///
+    /// `_memo` is the memo dict passed by `copy.deepcopy`; it is not used.
+    // NOTE(review): assumes `Clone` for `PyDecodeStream` copies all state
+    // rather than sharing it -- confirm against the struct definition.
+    fn __deepcopy__(&self, _memo: &Bound<'_, PyDict>) -> Self {
+        self.clone()
+    }
 }
 
 #[cfg(test)]
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index b28104e91..96b75b24d 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -1,4 +1,5 @@
 import pickle
+import copy
 import concurrent.futures
 import pytest
 import numpy as np
@@ -374,6 +375,35 @@ def test_decode(self):
         stream = DecodeStream(ids=[0, 1, 2])
         assert stream.step(tokenizer, 3) == " john"
 
+    def test_decode_stream_copy_and_prefix_ids(self):
+        """Check DecodeStream copies and construction from prefix ids."""
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_tokens(["my", "name", "is", "john"])
+        token_ids = [0, 1, 2, 3]
+
+        stream = DecodeStream(skip_special_tokens=False)
+        assert stream.step(tokenizer, token_ids[0]) == "my"
+        assert stream.step(tokenizer, token_ids[1]) == " name"
+        # Copy mid-stream: every copy must continue exactly like the original,
+        # even after the original has already consumed the same token.
+        stream_copy = copy.copy(stream)
+        stream_deepcopy = copy.deepcopy(stream)
+        assert stream.step(tokenizer, token_ids[2]) == " is"
+        assert stream_copy.step(tokenizer, token_ids[2]) == " is"
+        assert stream_deepcopy.step(tokenizer, token_ids[2]) == " is"
+        assert stream.step(tokenizer, token_ids[3]) == " john"
+        assert stream_copy.step(tokenizer, token_ids[3]) == " john"
+        assert stream_deepcopy.step(tokenizer, token_ids[3]) == " john"
+
+        # A stream prefilled with all but the last id must emit the same final
+        # chunk as a stream that was stepped through every id.
+        stream_steps = DecodeStream([])
+        last_chunk = None
+        for tid in token_ids:
+            last_chunk = stream_steps.step(tokenizer, tid)
+        stream_prefill = DecodeStream(token_ids[:-1])
+        assert stream_prefill.step(tokenizer, token_ids[-1]) == last_chunk
+
     def test_decode_stream_fallback(self):
         tokenizer = Tokenizer.from_pretrained("gpt2")
         # tokenizer.decode([255]) fails because its a fallback