diff --git a/bindings/node/lib/bindings/decoders.d.ts b/bindings/node/lib/bindings/decoders.d.ts
index 9841f8eb4..2b5574243 100644
--- a/bindings/node/lib/bindings/decoders.d.ts
+++ b/bindings/node/lib/bindings/decoders.d.ts
@@ -49,3 +49,9 @@ export function ctcDecoder(
   word_delimiter_token?: string,
   cleanup?: boolean
 ): Decoder;
+
+/**
+ * Instantiate a new Sequence Decoder
+ * @param [decoders] The decoders to chain
+ */
+export function sequenceDecoder(decoders: Decoder[]): Decoder;
diff --git a/bindings/node/lib/bindings/decoders.js b/bindings/node/lib/bindings/decoders.js
index 8789a4484..6de417196 100644
--- a/bindings/node/lib/bindings/decoders.js
+++ b/bindings/node/lib/bindings/decoders.js
@@ -6,4 +6,5 @@ module.exports = {
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
   ctcDecoder: native.decoders_CTC,
+  sequenceDecoder: native.decoders_Sequence,
 };
diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts
index b23f243f0..709e0e73d 100644
--- a/bindings/node/lib/bindings/decoders.test.ts
+++ b/bindings/node/lib/bindings/decoders.test.ts
@@ -1,4 +1,10 @@
-import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
+import {
+  bpeDecoder,
+  ctcDecoder,
+  metaspaceDecoder,
+  sequenceDecoder,
+  wordPieceDecoder,
+} from "./decoders";

 describe("wordPieceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
@@ -42,3 +48,25 @@ describe("ctcDecoder", () => {
     ).toEqual("hello");
   });
 });
+
+describe("sequenceDecoder", () => {
+  it("accepts `empty list` as parameter", () => {
+    expect(sequenceDecoder([])).toBeDefined();
+  });
+  it("decodes correctly", () => {
+    expect(
+      sequenceDecoder([ctcDecoder(), metaspaceDecoder()]).decode([
+        "▁",
+        "▁",
+        "H",
+        "H",
+        "i",
+        "i",
+        "▁",
+        "y",
+        "o",
+        "u",
+      ])
+    ).toEqual("Hi you");
+  });
+});
diff --git a/bindings/node/native/Cargo.lock b/bindings/node/native/Cargo.lock
index f77a7f2e3..56a4a0b1f 100644
--- a/bindings/node/native/Cargo.lock
+++ b/bindings/node/native/Cargo.lock
@@ -1719,7 +1719,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"

 [[package]]
 name = "tokenizers"
-version = "0.11.3"
+version = "0.12.1"
 dependencies = [
  "aho-corasick",
  "cached-path",
diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs
index 9a01bc4dd..6817da58a 100644
--- a/bindings/node/native/src/decoders.rs
+++ b/bindings/node/native/src/decoders.rs
@@ -14,11 +14,11 @@ pub struct Decoder {
 }

 impl tk::Decoder for Decoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
         self.decoder
             .as_ref()
             .ok_or("Uninitialized Decoder")?
-            .decode(tokens)
+            .decode_chain(tokens)
     }
 }

@@ -42,6 +42,7 @@ declare_types! {
                 .map_err(|e| Error(format!("{}", e)))?;

             Ok(cx.string(output).upcast())
+            }
         }
     }

@@ -115,6 +116,33 @@ fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }

+/// sequence()
+fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let decoders = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
+    let mut sequence = Vec::with_capacity(decoders.len());
+
+    decoders.into_iter().try_for_each(|decoder| {
+        match decoder.downcast::<JsDecoder>().or_throw(&mut cx) {
+            Ok(decoder) => {
+                let guard = cx.lock();
+                if let Some(decoder_arc) = &decoder.borrow(&guard).decoder {
+                    let decoder: DecoderWrapper = (**decoder_arc).clone();
+                    sequence.push(decoder);
+                }
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    })?;
+
+    let mut pretok = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    pretok.borrow_mut(&guard).decoder = Some(Arc::new(tk::DecoderWrapper::Sequence(
+        tk::decoders::sequence::Sequence::new(sequence),
+    )));
+    Ok(pretok)
+}
+
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
@@ -122,5 +150,6 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
     m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
+    m.export_function(&format!("{}_Sequence", prefix), sequence)?;
     Ok(())
 }
diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock
index ff5b74a17..edb427592 100644
--- a/bindings/python/Cargo.lock
+++ b/bindings/python/Cargo.lock
@@ -1751,7 +1751,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"

 [[package]]
 name = "tokenizers"
-version = "0.11.3"
+version = "0.12.1"
 dependencies = [
  "aho-corasick",
  "cached-path",
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.py b/bindings/python/py_src/tokenizers/decoders/__init__.py
index ce1af33fe..37514c6a0 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.py
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.py
@@ -6,3 +6,4 @@
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
+Sequence = decoders.Sequence
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
index 832c9e710..7888e5543 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -126,6 +126,30 @@ class Metaspace(Decoder):
         """
         pass

+class Sequence(Decoder):
+    """
+    Sequence Decoder
+
+    Args:
+        decoders (:obj:`List[Decoder]`)
+            The decoders that need to be chained
+    """
+
+    def __init__(self, decoders):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class WordPiece(Decoder):
     """
     WordPiece Decoder
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index b793f3a0f..396843024 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -10,6 +10,7 @@ use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_level::ByteLevel;
 use tk::decoders::ctc::CTC;
 use tk::decoders::metaspace::Metaspace;
+use tk::decoders::sequence::Sequence;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
 use tk::Decoder;
@@ -45,14 +46,17 @@ impl PyDecoder {
                 DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
                 DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
                 DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
+                DecoderWrapper::Sequence(_) => {
+                    Py::new(py, (PySequenceDecoder {}, base))?.into_py(py)
+                }
             },
         })
     }
 }

 impl Decoder for PyDecoder {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
-        self.decoder.decode(tokens)
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        self.decoder.decode_chain(tokens)
     }
 }

@@ -325,6 +329,36 @@ impl PyCTCDecoder {
     }
 }

+/// Sequence Decoder
+///
+/// Args:
+///     decoders (:obj:`List[Decoder]`)
+///         The decoders that need to be chained
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name="Sequence")]
+#[pyo3(text_signature = "(self, decoders)")]
+pub struct PySequenceDecoder {}
+#[pymethods]
+impl PySequenceDecoder {
+    #[new]
+    #[args(decoders)]
+    fn new(decoders_py: &PyList) -> PyResult<(Self, PyDecoder)> {
+        let mut decoders: Vec<DecoderWrapper> = Vec::with_capacity(decoders_py.len());
+        for decoder_py in decoders_py.iter() {
+            let decoder: PyRef<PyDecoder> = decoder_py.extract()?;
+            let decoder = match &decoder.decoder {
+                PyDecoderWrapper::Wrapped(inner) => inner,
+                PyDecoderWrapper::Custom(_) => unimplemented!(),
+            };
+            decoders.push(decoder.read().unwrap().clone());
+        }
+        Ok((PySequenceDecoder {}, Sequence::new(decoders).into()))
+    }
+
+    fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
+        PyTuple::new(py, &[PyList::empty(py)])
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct CustomDecoder {
     inner: PyObject,
@@ -342,7 +376,17 @@ impl Decoder for CustomDecoder {
             let decoded = self
                 .inner
                 .call_method(py, "decode", (tokens,), None)?
-                .extract::<String>(py)?;
+                .extract(py)?;
             Ok(decoded)
         })
     }
+
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
+        Python::with_gil(|py| {
+            let decoded = self
+                .inner
+                .call_method(py, "decode_chain", (tokens,), None)?
+                .extract(py)?;
+            Ok(decoded)
+        })
+    }
 }
@@ -396,10 +440,10 @@ where
 }

 impl Decoder for PyDecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+    fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
         match self {
-            PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
-            PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),
+            PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
+            PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
         }
     }
 }
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index b1580b0e4..42dd6b7c1 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -90,6 +90,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<decoders::PyMetaspaceDec>()?;
     m.add_class::<decoders::PyBPEDecoder>()?;
     m.add_class::<decoders::PyCTCDecoder>()?;
+    m.add_class::<decoders::PySequenceDecoder>()?;
     Ok(())
 }

diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py
index 41e7187e7..bc4554147 100644
--- a/bindings/python/tests/bindings/test_decoders.py
+++ b/bindings/python/tests/bindings/test_decoders.py
@@ -2,7 +2,7 @@ import pickle
 import json

-from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
+from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC, Sequence


 class TestByteLevel:
@@ -150,3 +150,19 @@ def test_can_modify(self):

         decoder.cleanup = False
         assert decoder.cleanup == False
+
+
+class TestSequenceDecoder:
+    def test_instantiate(self):
+        assert Sequence([]) is not None
+        assert Sequence([CTC()]) is not None
+        assert isinstance(Sequence([]), Decoder)
+        assert isinstance(Sequence([]), Sequence)
+        serialized = pickle.dumps(Sequence([]))
+        assert isinstance(pickle.loads(serialized), Sequence)
+
+    def test_decoding(self):
+        decoder = Sequence([CTC(), Metaspace()])
+        initial = ["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]
+        expected = "Hi you"
+        assert decoder.decode(initial) == expected
diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs
index 1d4115cbb..813dc7083 100644
--- a/tokenizers/src/decoders/bpe.rs
+++ b/tokenizers/src/decoders/bpe.rs
@@ -24,7 +24,15 @@ impl Default for BPEDecoder {
 }

 impl Decoder for BPEDecoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
-        Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        let n = tokens.len() - 1;
+        Ok(tokens
+            .into_iter()
+            .enumerate()
+            .map(|(i, token)| {
+                let replacement = if i == n { "" } else { " " };
+                token.replace(&self.suffix, replacement)
+            })
+            .collect())
     }
 }
diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs
index 17f7ba16e..83ec261e9 100644
--- a/tokenizers/src/decoders/ctc.rs
+++ b/tokenizers/src/decoders/ctc.rs
@@ -42,16 +42,23 @@ impl Default for CTC {
 }

 impl Decoder for CTC {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
-        let mut output = tokens
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        Ok(tokens
             .into_iter()
             .dedup()
-            .join("")
-            .replace(&self.pad_token, "");
-        if self.cleanup {
-            output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
-        }
-        Ok(output)
+            .filter_map(|token| {
+                let mut replaced = token.replace(&self.pad_token, "");
+                if self.cleanup {
+                    replaced =
+                        wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
+                }
+                if replaced.is_empty() {
+                    None
+                } else {
+                    Some(replaced)
+                }
+            })
+            .collect())
     }
 }

@@ -66,8 +73,8 @@ mod tests {
             .map(|s| s.to_string())
             .collect();
         assert_eq!(
-            ctc_decoder.decode(id_to_string_result).unwrap(),
-            "hello".to_string()
+            ctc_decoder.decode_chain(id_to_string_result).unwrap(),
+            vec!["h", "e", "l", "l", "o"]
         );
     }
     #[test]
@@ -78,8 +85,8 @@
             .map(|s| s.to_string())
             .collect();
         assert_eq!(
-            ctc_decoder.decode(id_to_string_result).unwrap(),
-            "hello world".to_string()
+            ctc_decoder.decode_chain(id_to_string_result).unwrap(),
+            vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
         );
     }
     #[test]
@@ -87,8 +94,12 @@
        let ctc_decoder = CTC::default();
        let id_to_string_result = " A | | M A N | | | S A I D D | | T T O | | T H E E | | | U U N N I V E R R S E E | | S S I R R | | | I | E X I S T | | ".split(' ').map(|s| s.to_string()).collect();
        assert_eq!(
-            ctc_decoder.decode(id_to_string_result).unwrap(),
-            "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
+            ctc_decoder.decode_chain(id_to_string_result).unwrap(),
+            vec![
+                "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
+                "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
+                " ", "E", "X", "I", "S", "T", " "
+            ]
         );
     }
     #[test]
@@ -96,8 +107,14 @@
        let ctc_decoder = CTC::default();
        let id_to_string_result = " H I S S | | I N S T T A N C C T | | | | | P A N N N I C | | W A S | | F O L L L O O W E E D | | B Y | | | A | | S S S M M A L L L | | | S H H A R R P | B L L O W W | | | H I G H H | | O N | | H I S S | | C H H E S S T T | | | ".split(' ').map(|s| s.to_string()).collect();
        assert_eq!(
-            ctc_decoder.decode(id_to_string_result).unwrap(),
-            "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
+            ctc_decoder.decode_chain(id_to_string_result).unwrap(),
+            vec![
+                "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
+                "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
+                "B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
+                " ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
+                "S", " ", "C", "H", "E", "S", "T", " "
+            ]
         );
     }
 }
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index a571ef5bb..c5122c871 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -1,5 +1,6 @@
 pub mod bpe;
 pub mod ctc;
+pub mod sequence;
 pub mod wordpiece;

 // Re-export these as decoders
@@ -10,6 +11,7 @@ use serde::{Deserialize, Serialize};

 use crate::decoders::bpe::BPEDecoder;
 use crate::decoders::ctc::CTC;
+use crate::decoders::sequence::Sequence;
 use crate::decoders::wordpiece::WordPiece;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
@@ -23,16 +25,18 @@ pub enum DecoderWrapper {
     WordPiece(WordPiece),
     Metaspace(Metaspace),
     CTC(CTC),
+    Sequence(Sequence),
 }

 impl Decoder for DecoderWrapper {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
         match self {
-            Self::BPE(bpe) => bpe.decode(tokens),
-            Self::ByteLevel(bl) => bl.decode(tokens),
-            Self::Metaspace(ms) => ms.decode(tokens),
-            Self::WordPiece(wp) => wp.decode(tokens),
-            Self::CTC(ctc) => ctc.decode(tokens),
+            Self::BPE(bpe) => bpe.decode_chain(tokens),
+            Self::ByteLevel(bl) => bl.decode_chain(tokens),
+            Self::Metaspace(ms) => ms.decode_chain(tokens),
+            Self::WordPiece(wp) => wp.decode_chain(tokens),
+            Self::CTC(ctc) => ctc.decode_chain(tokens),
+            Self::Sequence(seq) => seq.decode_chain(tokens),
         }
     }
 }
@@ -42,3 +46,4 @@ impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
+impl_enum_from!(Sequence, DecoderWrapper, Sequence);
diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs
new file mode 100644
index 000000000..484df6c95
--- /dev/null
+++ b/tokenizers/src/decoders/sequence.rs
@@ -0,0 +1,47 @@
+use crate::decoders::DecoderWrapper;
+use crate::tokenizer::{Decoder, Result};
+use crate::utils::macro_rules_attribute;
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug)]
+#[macro_rules_attribute(impl_serde_type!)]
+pub struct Sequence {
+    decoders: Vec<DecoderWrapper>,
+}
+
+impl Sequence {
+    pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
+        Self { decoders }
+    }
+}
+
+impl Decoder for Sequence {
+    fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
+        for decoder in &self.decoders {
+            tokens = decoder.decode_chain(tokens)?;
+        }
+        Ok(tokens)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::decoders::ctc::CTC;
+    use crate::pre_tokenizers::metaspace::Metaspace;
+
+    #[test]
+    fn sequence_basic() {
+        let decoders = vec![
+            DecoderWrapper::CTC(CTC::default()),
+            DecoderWrapper::Metaspace(Metaspace::default()),
+        ];
+        let decoder = Sequence::new(decoders);
+        let tokens: Vec<String> = vec!["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]
+            .into_iter()
+            .map(|s| s.to_string())
+            .collect();
+        let out_tokens = decoder.decode(tokens).unwrap();
+        assert_eq!(out_tokens, "Hi you");
+    }
+}
diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs
index c9b92d925..952108d65 100644
--- a/tokenizers/src/decoders/wordpiece.rs
+++ b/tokenizers/src/decoders/wordpiece.rs
@@ -28,7 +28,7 @@ impl Default for WordPiece {
         }
     }
 }
-pub fn cleanup(dirty_input: String) -> String {
+pub fn cleanup(dirty_input: &str) -> String {
     dirty_input
         .replace(" .", ".")
         .replace(" ?", "?")
@@ -44,12 +44,21 @@ pub fn cleanup(dirty_input: String) -> String {
 }

 impl Decoder for WordPiece {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
-        let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
-        if self.cleanup {
-            output = cleanup(output);
-        }
-
-        Ok(output)
+    fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
+        tokens
+            .iter_mut()
+            .enumerate()
+            .map(|(i, token)| {
+                if token.starts_with(&self.prefix) {
+                    *token = token.replacen(&self.prefix, "", 1);
+                } else if i != 0 {
+                    *token = format!(" {}", token);
+                }
+                if self.cleanup {
+                    *token = cleanup(token);
+                }
+                Ok(token.to_string())
+            })
+            .collect::<Result<_>>()
     }
 }
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8f3d5fa8f..e58f1c23c 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -145,8 +145,11 @@ impl PreTokenizer for ByteLevel {

 /// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
 /// unicode counterpart, before merging everything back into a single String.
+/// This decoder will consume the tokens and merge them in one step, to alleviate
+/// the fact that a single decoded token might be a byte which is not representable
+/// as a String on its own.
 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
         let toks = tokens
             .into_iter()
             .flat_map(|t| {
@@ -159,8 +162,8 @@ impl Decoder for ByteLevel {
                     })
                     .unwrap_or_else(|| t.as_bytes().to_vec())
             })
-            .collect::<Vec<_>>();
-        Ok(String::from_utf8_lossy(&toks).into_owned())
+            .collect::<Vec<u8>>();
+        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
     }
 }

@@ -284,9 +287,8 @@ mod tests {
     fn decoding() {
         let bytelevel = ByteLevel::default().add_prefix_space(false);
         assert_eq!(
-            "Hello my friend, how is your day going?",
             bytelevel
-                .decode(
+                .decode_chain(
                     vec![
                         "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
                         "?"
                     ]
                     .into_iter()
                     .map(|s| s.into())
                     .collect::<Vec<String>>()
                 )
-                .unwrap()
+                .unwrap(),
+            vec!["Hello my friend, how is your day going?"]
         );
     }

@@ -347,7 +350,10 @@
             .iter()
             .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
             .collect::<Vec<String>>();
-        assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap());
+        assert_eq!(
+            sample,
+            bytelevel.decode_chain(separated_tokens).unwrap().join("")
+        );
         }
     }

@@ -534,7 +540,7 @@
         let byte_level = ByteLevel::default();
         assert_eq!(
             byte_level
-                .decode(vec![
+                .decode_chain(vec![
                     "Hello".into(),
                     "Ġthere".into(),
                     "Ġdear".into(),
@@ -543,7 +549,7 @@
                     "Ġfriend!".into(),
                     "Ġ[PA D]".into(),
                     "[PA D]".into()
                 ])
                 .unwrap(),
-            "Hello there dear friend! [PA D]"
+            vec!["Hello there dear friend! [PA D]"]
         );
     }
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index 6df63df36..07472b049 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -77,23 +77,27 @@ impl PreTokenizer for Metaspace {
 }

 impl Decoder for Metaspace {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
         Ok(tokens
             .iter()
-            .flat_map(|t| t.chars())
             .enumerate()
-            .filter_map(|(i, c)| {
-                if c == self.replacement {
-                    if i == 0 && self.add_prefix_space {
-                        None
-                    } else {
-                        Some(' ')
-                    }
-                } else {
-                    Some(c)
-                }
+            .map(|(i, token)| {
+                token
+                    .chars()
+                    .flat_map(|c| {
+                        if c == self.replacement {
+                            if i == 0 && self.add_prefix_space {
+                                None
+                            } else {
+                                Some(' ')
+                            }
+                        } else {
+                            Some(c)
+                        }
+                    })
+                    .collect::<String>()
             })
-            .collect::<String>())
+            .collect())
     }
 }

@@ -188,8 +192,8 @@ fn decode() {
         let decoder = Metaspace::new('▁', true);
         let res = decoder
-            .decode(vec!["▁Hey".into(), "▁friend!".into()])
+            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
             .unwrap();
-        assert_eq!(&res, "Hey friend!")
+        assert_eq!(res, vec!["Hey", " friend!"])
     }
 }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index a76725090..2595bfd9c 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -119,9 +119,13 @@ impl dyn PostProcessor {
     }
 }

-/// A `Decoder` has the responsibility to merge the given `Vec<String>` in a `String`.
+/// A `Decoder` changes the raw tokens into a more readable form.
 pub trait Decoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<String>;
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let results = self.decode_chain(tokens)?;
+        Ok(results.join(""))
+    }
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>>;
 }

 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
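
Usage sketch (not part of the diff): a minimal Python example of the new `Sequence` decoder, mirroring the behaviour exercised by `TestSequenceDecoder.test_decoding` above. It assumes a `tokenizers` build that includes this change.

from tokenizers.decoders import CTC, Metaspace, Sequence

# Chain the decoders: CTC first collapses repeated tokens and strips pad tokens,
# then Metaspace turns the "▁" markers back into spaces.
decoder = Sequence([CTC(), Metaspace()])

tokens = ["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]
print(decoder.decode(tokens))  # -> "Hi you"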