Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bindings/node/lib/bindings/decoders.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,9 @@ export function ctcDecoder(
word_delimiter_token?: string,
cleanup?: boolean
): Decoder;

/**
 * Instantiate a new Sequence Decoder, which chains the given decoders and
 * applies them in order.
 * @param decoders The decoders to chain (required; pass `[]` for no-op chaining)
 */
export function sequenceDecoder(decoders: Decoder[]): Decoder;
1 change: 1 addition & 0 deletions bindings/node/lib/bindings/decoders.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ module.exports = {
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,
sequenceDecoder: native.decoders_Sequence,
};
30 changes: 29 additions & 1 deletion bindings/node/lib/bindings/decoders.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
import {
bpeDecoder,
ctcDecoder,
metaspaceDecoder,
sequenceDecoder,
wordPieceDecoder,
} from "./decoders";

describe("wordPieceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
Expand Down Expand Up @@ -42,3 +48,25 @@ describe("ctcDecoder", () => {
).toEqual("hello");
});
});

describe("sequenceDecoder", () => {
  it("accepts `empty list` as parameter", () => {
    expect(sequenceDecoder([])).toBeDefined();
  });
  // The CTC decoder first collapses the repeated tokens, then the Metaspace
  // decoder turns the "▁" word-boundary markers into spaces.
  it("decodes correctly", () => {
    expect(
      sequenceDecoder([ctcDecoder(), metaspaceDecoder()]).decode([
        "▁",
        "▁",
        "H",
        "H",
        "i",
        "i",
        "▁",
        "y",
        "o",
        "u",
      ])
    ).toEqual("Hi you");
  });
});
2 changes: 1 addition & 1 deletion bindings/node/native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 31 additions & 2 deletions bindings/node/native/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ pub struct Decoder {
}

impl tk::Decoder for Decoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
self.decoder
.as_ref()
.ok_or("Uninitialized Decoder")?
.decode(tokens)
.decode_chain(tokens)
}
}

Expand All @@ -42,6 +42,7 @@ declare_types! {
.map_err(|e| Error(format!("{}", e)))?;

Ok(cx.string(output).upcast())

}
}
}
Expand Down Expand Up @@ -115,12 +116,40 @@ fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}

/// sequence(decoders: Decoder[])
///
/// Build a `Sequence` decoder from a JS array of decoders; the decoders are
/// applied in the order given.
fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
    let decoders = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
    let mut sequence = Vec::with_capacity(decoders.len());

    decoders.into_iter().try_for_each(|decoder| {
        match decoder.downcast::<JsDecoder>().or_throw(&mut cx) {
            Ok(decoder) => {
                let guard = cx.lock();
                // NOTE(review): decoders whose inner wrapper is `None`
                // (uninitialized) are silently skipped rather than raising
                // a JS error — confirm this is the intended behavior.
                if let Some(decoder_arc) = &decoder.borrow(&guard).decoder {
                    let decoder: DecoderWrapper = (**decoder_arc).clone();
                    sequence.push(decoder);
                }
                Ok(())
            }
            Err(e) => Err(e),
        }
    })?;

    // Renamed from `pretok`: this is a decoder binding, not a pre-tokenizer
    // (the old name was a copy-paste from the pre_tokenizers module).
    let mut js_decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
    let guard = cx.lock();
    js_decoder.borrow_mut(&guard).decoder = Some(Arc::new(tk::DecoderWrapper::Sequence(
        tk::decoders::sequence::Sequence::new(sequence),
    )));
    Ok(js_decoder)
}

/// Register everything here
///
/// Each decoder constructor is exported under the name `<prefix>_<Name>`,
/// matching the lookups done in `lib/bindings/decoders.js`
/// (e.g. `native.decoders_Sequence`).
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
    m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
    m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
    m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
    m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
    m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
    m.export_function(&format!("{}_Sequence", prefix), sequence)?;
    Ok(())
}
2 changes: 1 addition & 1 deletion bindings/python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC
Sequence = decoders.Sequence
24 changes: 24 additions & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,30 @@ class Metaspace(Decoder):
"""
pass

class Sequence(Decoder):
    """
    Sequence Decoder

    Args:
        decoders (:obj:`List[Decoder]`):
            The decoders that need to be chained
    """

    def __init__(self, decoders):
        pass
    def decode(self, tokens):
        """
        Decode the given list of tokens to a final string

        Args:
            tokens (:obj:`List[str]`):
                The list of tokens to decode

        Returns:
            :obj:`str`: The decoded string
        """
        pass

class WordPiece(Decoder):
"""
WordPiece Decoder
Expand Down
56 changes: 50 additions & 6 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::sequence::Sequence;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
use tk::Decoder;
Expand Down Expand Up @@ -45,14 +46,17 @@ impl PyDecoder {
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
DecoderWrapper::Sequence(_) => {
Py::new(py, (PySequenceDecoder {}, base))?.into_py(py)
}
},
})
}
}

impl Decoder for PyDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
self.decoder.decode(tokens)
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
self.decoder.decode_chain(tokens)
}
}

Expand Down Expand Up @@ -325,6 +329,36 @@ impl PyCTCDecoder {
}
}

/// Sequence Decoder
///
/// Args:
/// decoders (:obj:`List[Decoder]`)
/// The decoders that need to be chained
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name="Sequence")]
#[pyo3(text_signature = "(self, decoders)")]
pub struct PySequenceDecoder {}
#[pymethods]
impl PySequenceDecoder {
#[new]
#[args(decoders)]
fn new(decoders_py: &PyList) -> PyResult<(Self, PyDecoder)> {
let mut decoders: Vec<DecoderWrapper> = Vec::with_capacity(decoders_py.len());
for decoder_py in decoders_py.iter() {
let decoder: PyRef<PyDecoder> = decoder_py.extract()?;
let decoder = match &decoder.decoder {
PyDecoderWrapper::Wrapped(inner) => inner,
PyDecoderWrapper::Custom(_) => unimplemented!(),
};
decoders.push(decoder.read().unwrap().clone());
}
Ok((PySequenceDecoder {}, Sequence::new(decoders).into()))
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
PyTuple::new(py, &[PyList::empty(py)])
}
}

#[derive(Clone)]
pub(crate) struct CustomDecoder {
inner: PyObject,
Expand All @@ -342,7 +376,17 @@ impl Decoder for CustomDecoder {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
.extract::<String>(py)?;
.extract(py)?;
Ok(decoded)
})
}

/// Invoke the Python object's `decode_chain` method with the given tokens
/// and convert its return value back into a `Vec<String>`.
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
    Python::with_gil(|py| {
        let result = self.inner.call_method(py, "decode_chain", (tokens,), None)?;
        Ok(result.extract(py)?)
    })
}
Expand Down Expand Up @@ -396,10 +440,10 @@ where
}

impl Decoder for PyDecoderWrapper {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
match self {
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<decoders::PyMetaspaceDec>()?;
m.add_class::<decoders::PyBPEDecoder>()?;
m.add_class::<decoders::PyCTCDecoder>()?;
m.add_class::<decoders::PySequenceDecoder>()?;
Ok(())
}

Expand Down
18 changes: 17 additions & 1 deletion bindings/python/tests/bindings/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
import json

from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC, Sequence


class TestByteLevel:
Expand Down Expand Up @@ -150,3 +150,19 @@ def test_can_modify(self):

decoder.cleanup = False
assert decoder.cleanup == False


class TestSequenceDecoder:
    def test_instantiate(self):
        # Construction works with an empty list and with a single decoder.
        assert Sequence([]) is not None
        assert Sequence([CTC()]) is not None
        empty = Sequence([])
        assert isinstance(empty, Decoder)
        assert isinstance(empty, Sequence)
        # Round-trips through pickle.
        restored = pickle.loads(pickle.dumps(Sequence([])))
        assert isinstance(restored, Sequence)

    def test_decoding(self):
        # CTC collapses the repeated tokens, then Metaspace maps "▁" to spaces.
        chained = Sequence([CTC(), Metaspace()])
        tokens = ["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]
        assert chained.decode(tokens) == "Hi you"
12 changes: 10 additions & 2 deletions tokenizers/src/decoders/bpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ impl Default for BPEDecoder {
}

impl Decoder for BPEDecoder {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
let n = tokens.len() - 1;
Ok(tokens
.into_iter()
.enumerate()
.map(|(i, token)| {
let replacement = if i == n { "" } else { " " };
token.replace(&self.suffix, replacement)
})
.collect())
}
}
Loading