src/transformers/models/prophetnet/configuration_prophetnet.py (27 additions, 26 deletions)

@@ -14,6 +14,7 @@
# limitations under the License.
""" ProphetNet model configuration"""

+from typing import Callable, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -105,32 +106,32 @@ class ProphetNetConfig(PretrainedConfig):

def __init__(
self,
-activation_dropout=0.1,
-activation_function="gelu",
-vocab_size=30522,
-hidden_size=1024,
-encoder_ffn_dim=4096,
-num_encoder_layers=12,
-num_encoder_attention_heads=16,
-decoder_ffn_dim=4096,
-num_decoder_layers=12,
-num_decoder_attention_heads=16,
-attention_dropout=0.1,
-dropout=0.1,
-max_position_embeddings=512,
-init_std=0.02,
-is_encoder_decoder=True,
-add_cross_attention=True,
-decoder_start_token_id=0,
-ngram=2,
-num_buckets=32,
-relative_max_distance=128,
-disable_ngram_loss=False,
-eps=0.0,
-use_cache=True,
-pad_token_id=0,
-bos_token_id=1,
-eos_token_id=2,
+activation_dropout: Optional[float] = 0.1,
+activation_function: Optional[Union[str, Callable]] = "gelu",
+vocab_size: Optional[int] = 30522,
+hidden_size: Optional[int] = 1024,
+encoder_ffn_dim: Optional[int] = 4096,
+num_encoder_layers: Optional[int] = 12,
+num_encoder_attention_heads: Optional[int] = 16,
+decoder_ffn_dim: Optional[int] = 4096,
+num_decoder_layers: Optional[int] = 12,
+num_decoder_attention_heads: Optional[int] = 16,
+attention_dropout: Optional[float] = 0.1,
+dropout: Optional[float] = 0.1,
+max_position_embeddings: Optional[int] = 512,
+init_std: Optional[float] = 0.02,
+is_encoder_decoder: Optional[bool] = True,
+add_cross_attention: Optional[bool] = True,
+decoder_start_token_id: Optional[int] = 0,
+ngram: Optional[int] = 2,
+num_buckets: Optional[int] = 32,
+relative_max_distance: Optional[int] = 128,
+disable_ngram_loss: Optional[bool] = False,
+eps: Optional[float] = 0.0,
+use_cache: Optional[bool] = True,
+pad_token_id: Optional[int] = 0,
+bos_token_id: Optional[int] = 1,
+eos_token_id: Optional[int] = 2,
**kwargs
):
self.vocab_size = vocab_size
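A minimal usage sketch of the annotated constructor (illustrative, not part of the diff; the values are the defaults listed above). Note that the hints follow the library's docstring convention of marking every defaulted parameter as Optional[...], even though None is not a meaningful value for most of them, so a strict type checker would expect plain int/float here.

from transformers import ProphetNetConfig

# Override a few of the now-annotated parameters; the hints document intent
# only and are not enforced at runtime.
config = ProphetNetConfig(
    vocab_size=30522,            # Optional[int]
    hidden_size=1024,            # Optional[int]
    activation_function="gelu",  # Optional[Union[str, Callable]]
    dropout=0.1,                 # Optional[float]
)
assert config.hidden_size == 1024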
src/transformers/models/prophetnet/modeling_prophetnet.py (6 additions, 6 deletions)

@@ -345,7 +345,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):

If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
-last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`):
+last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
@@ -590,7 +590,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
the forward function.
"""

-def __init__(self, config: ProphetNetConfig):
+def __init__(self, config: ProphetNetConfig) -> None:
self.max_length = config.max_position_embeddings
super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)

@@ -1407,7 +1407,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
embeddings instead of randomly initialized word embeddings.
"""

-def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
+def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
super().__init__(config)

self.ngram = config.ngram
@@ -1769,7 +1769,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetModel(ProphetNetPreTrainedModel):
-def __init__(self, config):
+def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

@@ -2106,7 +2106,7 @@ def get_decoder(self):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
-def __init__(self, config):
+def __init__(self, config: ProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
config.is_decoder = True
@@ -2341,7 +2341,7 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
classes.
"""

-def __init__(self, config):
+def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.decoder = ProphetNetDecoder(config)

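A short sketch of the annotated constructors in use (illustrative only; the tiny sizes are arbitrary, chosen just to keep the example light):

import torch
from transformers import ProphetNetConfig, ProphetNetModel
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoder

config = ProphetNetConfig(
    hidden_size=64,
    encoder_ffn_dim=128,
    decoder_ffn_dim=128,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_encoder_attention_heads=2,
    num_decoder_attention_heads=2,
)

# ProphetNetModel.__init__ now declares config: ProphetNetConfig.
model = ProphetNetModel(config)

# ProphetNetDecoder's word_embeddings is now Optional[nn.Embedding]:
# pass shared embeddings, or leave as None for a fresh embedding table.
decoder = ProphetNetDecoder(config, word_embeddings=None)

# A randomly initialized forward pass, just to show the model runs.
outputs = model(
    input_ids=torch.tensor([[1, 5, 7, 2]]),
    decoder_input_ids=torch.tensor([[0, 5]]),
)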
src/transformers/models/prophetnet/tokenization_prophetnet.py (19 additions, 16 deletions)

@@ -15,7 +15,7 @@

import collections
import os
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -111,17 +111,17 @@ class ProphetNetTokenizer(PreTrainedTokenizer):

def __init__(
self,
-vocab_file,
-do_lower_case=True,
-do_basic_tokenize=True,
-never_split=None,
-unk_token="[UNK]",
-sep_token="[SEP]",
-x_sep_token="[X_SEP]",
-pad_token="[PAD]",
-mask_token="[MASK]",
-tokenize_chinese_chars=True,
-strip_accents=None,
+vocab_file: str,
+do_lower_case: Optional[bool] = True,
+do_basic_tokenize: Optional[bool] = True,
+never_split: Optional[Iterable] = None,
+unk_token: Optional[str] = "[UNK]",
+sep_token: Optional[str] = "[SEP]",
+x_sep_token: Optional[str] = "[X_SEP]",
+pad_token: Optional[str] = "[PAD]",
+mask_token: Optional[str] = "[MASK]",
+tokenize_chinese_chars: Optional[bool] = True,
+strip_accents: Optional[bool] = None,
**kwargs
):
super().__init__(
@@ -177,21 +177,24 @@ def _tokenize(self, text):
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens

-def _convert_token_to_id(self, token):
+def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))

-def _convert_id_to_token(self, index):
+def _convert_id_to_token(self, index: int):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)

-def convert_tokens_to_string(self, tokens):
+def convert_tokens_to_string(self, tokens: List[str]):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string

def get_special_tokens_mask(
-self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+self,
+token_ids_0: List[int],
+token_ids_1: Optional[List[int]] = None,
+already_has_special_tokens: Optional[bool] = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
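And a sketch of the annotated tokenizer API (illustrative; it uses the public microsoft/prophetnet-large-uncased checkpoint, so the first call downloads the vocab):

from transformers import ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")

ids = tokenizer("hello world")["input_ids"]

# get_special_tokens_mask flags positions holding special tokens (1) versus
# regular tokens (0); ProphetNet appends a single [SEP] to each sequence.
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)  # expected: [0, 0, 1], with the trailing [SEP] flagged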