Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tokenizer] Fix tokenizer of llama3.3 #9641

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,8 @@
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
expanded_attn_mask = expanded_attn_mask.astype(dtype)

Check warning on line 1605 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L1604-L1605

Added lines #L1604 - L1605 were not covered by tests
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里PaddleNLP-CI已经验证

return expanded_attn_mask

@paddle.jit.not_to_static
Expand Down
6 changes: 4 additions & 2 deletions paddlenlp/transformers/llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,11 @@ def __init__(
self.eos_token = ENDOFTEXT
self.bos_token_id = self.bod_id
self.eos_token_id = self.eod_id
self.pad_token = self.convert_ids_to_tokens(self.eos_token_id)
if "pad_token" not in kwargs:
self.pad_token = self.convert_ids_to_tokens(self.eos_token_id)
kwargs["pad_token"] = self.pad_token

super().__init__(pad_token=self.pad_token, **kwargs)
super().__init__(**kwargs)

def __len__(self) -> int:
return self.tokenizer.n_vocab
Expand Down
26 changes: 26 additions & 0 deletions tests/transformers/llama/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import tempfile
import unittest

from parameterized import parameterized_class

from paddlenlp.transformers.auto.tokenizer import AutoTokenizer
from paddlenlp.transformers.llama.tokenizer import LlamaTokenizer
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
Expand Down Expand Up @@ -213,6 +215,30 @@ def test_pretrained_model_lists(self):
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_resource_files_map.values())[0]), 1)


@parameterized_class(
["model_name_or_path"],
[
["facebook/llama-7b"],
["meta-llama/Meta-Llama-3.1-8B"],
["meta-llama/Llama-3.2-1B"],
["meta-llama/Llama-3.3-70B-Instruct"],
],
)
class LlamaTokenizationLoadTest(unittest.TestCase):
model_name_or_path: str = None

def get_tokenizer(self, **kwargs) -> PretrainedTokenizer:
tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, **kwargs)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
return tokenizer

def test_load_tokenizer(self):
tokenizer = self.get_tokenizer()
text = "lower newer"
tokenizer.tokenize(text, add_prefix_space=True)


class TikTokenIntegrationTests(unittest.TestCase):
"""
A class that regroups important test to make sure that we properly handle the special tokens.
Expand Down