Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
740d523
Add support for __all__ and potentailly deleting functions
ArthurZucker Oct 1, 2024
56d6056
updates
ArthurZucker Oct 1, 2024
240b127
update
ArthurZucker Oct 1, 2024
1d629b9
nits
ArthurZucker Oct 1, 2024
3d1fc14
remove dummies
ArthurZucker Oct 1, 2024
29fff49
fix warning
ArthurZucker Oct 1, 2024
5441aad
fixup
ArthurZucker Oct 1, 2024
ddcb736
style
ArthurZucker Oct 1, 2024
b774628
update
ArthurZucker Oct 1, 2024
0eee3b5
fixup
ArthurZucker Oct 1, 2024
47ea51c
skip copied from when # skip
ArthurZucker Oct 1, 2024
dbf3bd7
remove log
ArthurZucker Oct 1, 2024
ba207d0
bring dummies back
ArthurZucker Oct 1, 2024
584a9be
fixup
ArthurZucker Oct 1, 2024
69cc01c
remove copied from
ArthurZucker Oct 1, 2024
c2961ff
fixup
ArthurZucker Oct 1, 2024
5f95f93
remove warnings from `make fix-copies`
ArthurZucker Oct 1, 2024
f80c2e7
fix doc issues
ArthurZucker Oct 2, 2024
584296f
nits
ArthurZucker Oct 2, 2024
7e9269a
Better error message !
ArthurZucker Oct 2, 2024
73be261
add support for more flexible naming!
ArthurZucker Oct 2, 2024
84055a3
style
ArthurZucker Oct 2, 2024
93c1c7e
breaking style?
ArthurZucker Oct 2, 2024
de9ad8f
fix super() renaming issues
ArthurZucker Oct 8, 2024
c53fc87
del not needed when you don't call super().__init__()
ArthurZucker Oct 8, 2024
d970d11
style
ArthurZucker Oct 8, 2024
6e9a325
no more fmt on :)
ArthurZucker Oct 8, 2024
528b4f3
properly remove `self`
ArthurZucker Oct 8, 2024
885d577
Merge branch 'main' of github.com:huggingface/transformers into add-a…
ArthurZucker Oct 8, 2024
f108d98
fixup
ArthurZucker Oct 8, 2024
3c69624
fix
ArthurZucker Oct 8, 2024
426f315
doc nits
ArthurZucker Oct 8, 2024
db7b32e
add some doc 🫡
ArthurZucker Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/modular-transformers/modeling_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

import math
from typing import List, Optional, Tuple, Union

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ def __init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


__all__ = ["GemmaConfig"]
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,3 +1395,6 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]
181 changes: 180 additions & 1 deletion src/transformers/models/gemma/modular_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import sentencepiece as spm
import torch
import torch.utils.checkpoint
from torch import nn
Expand All @@ -27,6 +28,7 @@
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import is_torchdynamo_compiling, logging
from ..llama.modeling_llama import (
LlamaDecoderLayer,
Expand All @@ -38,6 +40,15 @@
apply_rotary_pos_emb,
repeat_kv,
)
from ..llama.tokenization_llama import LlamaTokenizer


if TYPE_CHECKING:
from ...tokenization_utils_base import TextInput

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -164,6 +175,164 @@ def __init__(
)


class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
"""
Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.

Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:

- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.

- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.

add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add an `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
extra spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not the default system prompt for Gemma should be used.
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to add spaces between special tokens.
"""

def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
use_default_system_prompt=False,
spaces_between_special_tokens=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)

PreTrainedTokenizer.__init__(
self,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
del self.add_prefix_space
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated
del self.legacy

def get_spm_processor(self):
raise AttributeError("Not needed for Gemma")

def unk_token_length(self):
raise AttributeError("Not needed for Gemma")

# skip
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Args:
text: TextInput
Simply calls PreTrainedTokenizer's method
"""
return PreTrainedTokenizer.tokenize(self, text, **kwargs)

# skip
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
"""
return self.sp_model.encode(text, out_type=str)

def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
sub_texts = []
current_sub_text = []
for ids in token_ids:
if skip_special_tokens and ids in self.all_special_ids:
continue
if ids in self._added_tokens_decoder:
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
sub_texts.append(self._added_tokens_decoder[ids].content)
current_sub_text = []
else:
current_sub_text.append(ids)
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))

if spaces_between_special_tokens:
sub_texts = " ".join(sub_texts)
else:
sub_texts = "".join(sub_texts)

return sub_texts.replace(SPIECE_UNDERLINE, " ")

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self._added_tokens_encoder:
out_string += self.sp_model.decode(current_sub_tokens) + token
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode(current_sub_tokens)
return out_string


class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
Expand Down Expand Up @@ -881,3 +1050,13 @@ def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.post_init()


__all__ = [
"GemmaConfig",
"GemmaTokenizer",
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]
Loading