【Hackathon 7th No.43】TokenizerFast for Qwen2 (#9532)
* add qwen2 tokenizer fast
yinfan98 authored Dec 2, 2024
1 parent 2522bf8 commit 7a221cc
Showing 6 changed files with 190 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/transformers/auto/tokenizer.py
@@ -124,7 +124,7 @@
("ernie_vil", "ErnieViLTokenizer"),
("glm", "GLMGPT2Tokenizer"),
("qwen", "QWenTokenizer"),
("qwen2", "Qwen2Tokenizer"),
("qwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("yuan", "YuanTokenizer"),
]
)
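With the tuple entry above, AutoTokenizer can resolve either the slow or the fast class depending on whether the `tokenizers` package is installed. A minimal usage sketch; the `use_fast` keyword and the checkpoint name are assumptions, not taken from this diff:

```python
# Hedged sketch: `use_fast` and "Qwen/Qwen2-0.5B" are illustrative assumptions.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", use_fast=True)
# Expected to be Qwen2TokenizerFast when `tokenizers` is installed, otherwise Qwen2Tokenizer.
print(type(tokenizer).__name__)
```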
50 changes: 49 additions & 1 deletion paddlenlp/transformers/convert_slow_tokenizer.py
@@ -442,7 +442,55 @@ def pre_tokenizer(self, replacement, add_prefix_space):
return None


SLOW_TO_FAST_CONVERTERS = {"LlamaTokenizer": LlamaConverter, "BertTokenizer": BertConverter}
class Qwen2Converter(Converter):
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
if not vocab:
vocab = self.original_tokenizer.encoder
if not merges:
merges = list(self.original_tokenizer.bpe_ranks.keys())

tokenizer = Tokenizer(
BPE(
vocab=vocab,
merges=merges,
dropout=None,
unk_token=None,
continuing_subword_prefix="",
end_of_word_suffix="",
fuse_unk=False,
byte_fallback=False,
)
)

tokenizer.normalizer = normalizers.NFC()

tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(
Regex(
r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
),
behavior="isolated",
invert=False,
),
pre_tokenizers.ByteLevel(
add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
use_regex=False,
),
]
)

tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

return tokenizer


SLOW_TO_FAST_CONVERTERS = {
"LlamaTokenizer": LlamaConverter,
"BertTokenizer": BertConverter,
"Qwen2Tokenizer": Qwen2Converter,
}


def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
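The converter above builds a byte-level BPE `tokenizers.Tokenizer` from a slow `Qwen2Tokenizer`'s vocab and merge ranks. A minimal sketch of that conversion path, with an illustrative checkpoint name; the two encodings are expected to match for plain text since Qwen2 adds no BOS/EOS tokens by default:

```python
# Hedged sketch of the Qwen2Converter path; the checkpoint name is illustrative only.
from paddlenlp.transformers import Qwen2Tokenizer
from paddlenlp.transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-0.5B")
fast_backend = convert_slow_tokenizer(slow)  # a tokenizers.Tokenizer built by Qwen2Converter

text = "Hello world"
# Both lines should print the same token ids.
print(fast_backend.encode(text).ids)
print(slow(text)["input_ids"])
```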
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen2/__init__.py
@@ -17,3 +17,4 @@
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
from .tokenizer_fast import *
131 changes: 131 additions & 0 deletions paddlenlp/transformers/qwen2/tokenizer_fast.py
@@ -0,0 +1,131 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Qwen2."""

from typing import Optional, Tuple

from ..tokenizer_utils import AddedToken
from ..tokenizer_utils_fast import PretrainedTokenizerFast
from .tokenizer import Qwen2Tokenizer

VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
"tokenizer_file": "tokenizer.json",
}


MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}


class Qwen2TokenizerFast(PretrainedTokenizerFast):
"""
Construct a "fast" Qwen2 tokenizer (backed by PaddleNLP's *tokenizers* library). Based on byte-level
Byte-Pair-Encoding.
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```python
>>> from paddlenlp.transformers import Qwen2TokenizerFast
>>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
>>> tokenizer("Hello world")["input_ids"]
[9707, 1879]
>>> tokenizer(" Hello world")["input_ids"]
[21927, 1879]
```
This is expected.
This tokenizer inherits from [`PretrainedTokenizerFast`], which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`, *optional*):
Path to the vocabulary file.
merges_file (`str`, *optional*):
Path to the merges file.
tokenizer_file (`str`, *optional*):
Path to a [tokenizers](https://github.com/huggingface/tokenizers) file (generally with a .json extension) that
contains everything needed to load the tokenizer.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead. Not applicable to this tokenizer.
bos_token (`str`, *optional*):
The beginning of sequence token. Not applicable for this tokenizer.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding, for example when batching sequences of different lengths.
"""

vocab_files_names = VOCAB_FILES_NAMES
resource_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = Qwen2Tokenizer

def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
**kwargs,
):
# We need to at least pass vocab_file and merges_file to the base class
# in case a slow tokenizer needs to be initialized; the rest can be
# configured through files.
# Following GPT2TokenizerFast, also add unk_token, bos_token, eos_token, and pad_token.

bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(pad_token, str)
else pad_token
)

super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
**kwargs,
)

# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
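A short usage sketch for the new class; the checkpoint and output directory are illustrative. `save_vocabulary` delegates to the underlying BPE model and writes `vocab.json` and `merges.txt`:

```python
# Hedged sketch; "Qwen/Qwen2-0.5B" and "./qwen2_fast" are illustrative assumptions.
import os

from paddlenlp.transformers import Qwen2TokenizerFast

tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B")
print(tokenizer("Hello world")["input_ids"])

os.makedirs("./qwen2_fast", exist_ok=True)
# Exports the byte-level BPE vocabulary and merges via the tokenizers backend.
print(tokenizer.save_vocabulary("./qwen2_fast"))
```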
6 changes: 6 additions & 0 deletions paddlenlp/transformers/tokenizer_utils_base.py
@@ -1534,6 +1534,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"chat_template_file": CHAT_TEMPLATE_CONFIG_NAME,
}

if hasattr(cls, "vocab_files_names") and len(cls.resource_files_names) == 0:
cls.resource_files_names = copy.deepcopy(cls.vocab_files_names)
logger.error(
"The attribute 'vocab_files_names' is deprecated. Please use 'resource_files_names' instead.",
DeprecationWarning,
)
vocab_files_target = {**cls.resource_files_names, **additional_files_names}
# From HF Hub or AI Studio
if from_hf_hub or from_aistudio:
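The hunk above backfills `resource_files_names` from the HF-style `vocab_files_names`, so tokenizer classes that only define the latter still resolve their resource files. A standalone sketch of that fallback, using a hypothetical class:

```python
# Standalone sketch mirroring the fallback above; LegacyTokenizer is hypothetical.
import copy


class LegacyTokenizer:
    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    resource_files_names = {}


if hasattr(LegacyTokenizer, "vocab_files_names") and len(LegacyTokenizer.resource_files_names) == 0:
    LegacyTokenizer.resource_files_names = copy.deepcopy(LegacyTokenizer.vocab_files_names)

print(LegacyTokenizer.resource_files_names)
# {'vocab_file': 'vocab.json', 'merges_file': 'merges.txt'}
```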
3 changes: 2 additions & 1 deletion tests/transformers/qwen2/test_tokenizer.py
@@ -18,14 +18,15 @@
import os
import unittest

from paddlenlp.transformers import Qwen2Tokenizer
from paddlenlp.transformers import Qwen2Tokenizer, Qwen2TokenizerFast
from paddlenlp.transformers.qwen2.tokenizer import VOCAB_FILES_NAMES, bytes_to_unicode
from tests.transformers.test_tokenizer_common import TokenizerTesterMixin


class Qwen2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_id = "__internal_testing__/tiny-random-qwen2"
tokenizer_class = Qwen2Tokenizer
rust_tokenizer_class = Qwen2TokenizerFast
test_slow_tokenizer = True
space_between_special_tokens = False
from_pretrained_kwargs = None
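Setting `rust_tokenizer_class` lets the common tester mixin exercise slow/fast parity. A hedged sketch of the kind of check this enables, reusing the tiny checkpoint referenced above (not part of the committed test file):

```python
# Hedged sketch of a slow/fast parity check; class and test names are illustrative.
import unittest

from paddlenlp.transformers import Qwen2Tokenizer, Qwen2TokenizerFast


class Qwen2FastParityTest(unittest.TestCase):
    def test_slow_and_fast_produce_same_ids(self):
        slow = Qwen2Tokenizer.from_pretrained("__internal_testing__/tiny-random-qwen2")
        fast = Qwen2TokenizerFast.from_pretrained("__internal_testing__/tiny-random-qwen2")
        text = "lower newer"
        self.assertEqual(slow(text)["input_ids"], fast(text)["input_ids"])


if __name__ == "__main__":
    unittest.main()
```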
