diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 62451be24247..5ec1cee2e692 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -125,6 +125,7 @@ ], "models": [], # Models + "models.megatron": ["MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronConfig", "MegatronTokenizer"], "models.wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", "Wav2Vec2Tokenizer"], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], @@ -279,6 +280,7 @@ # tokenziers-backed objects if is_tokenizers_available(): # Fast tokenizers + _import_structure["models.megatron"].append("MegatronTokenizerFast") _import_structure["models.convbert"].append("ConvBertTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") @@ -367,6 +369,15 @@ _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure + _import_structure["models.megatron"].extend( + [ + "MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST", + "MegatronForCausalLM", + "MegatronForSequenceClassification", + "MegatronModel", + ] + ) + _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1322,6 +1333,7 @@ from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer from .models.marian import MarianConfig from .models.mbart import MBartConfig + from .models.megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig, MegatronTokenizer from .models.mmbt import MMBTConfig from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer @@ -1435,6 +1447,7 @@ from .models.longformer import LongformerTokenizerFast from .models.lxmert import LxmertTokenizerFast from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast + from .models.megatron import MegatronTokenizerFast from .models.mobilebert import MobileBertTokenizerFast from .models.mpnet import MPNetTokenizerFast from .models.mt5 import MT5TokenizerFast @@ -1733,6 +1746,12 @@ MBartForSequenceClassification, MBartModel, ) + from .models.megatron import ( + MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST, + MegatronForCausalLM, + MegatronForSequenceClassification, + MegatronModel, + ) from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mobilebert import ( MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f7f9a9e58ded..c54198223aea 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -47,6 +47,7 @@ lxmert, marian, mbart, + megatron, mmbt, mobilebert, mpnet, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index a5dcff3b3aab..848e398ef352 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -46,6 +46,7 @@ from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig from ..marian.configuration_marian import MarianConfig from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig +from ..megatron.configuration_megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig from ..mobilebert.configuration_mobilebert import MobileBertConfig from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig from ..mt5.configuration_mt5 import MT5Config @@ -74,6 +75,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LED_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -118,6 +120,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("megatron", MegatronConfig), ("wav2vec2", Wav2Vec2Config), ("convbert", ConvBertConfig), ("led", LEDConfig), @@ -168,6 +171,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("megatron", "Megatron"), ("wav2vec2", "Wav2Vec2"), ("convbert", "ConvBERT"), ("led", "LED"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 64c527e1ec8a..b389b4dff56f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -21,8 +21,6 @@ from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings from ...utils import logging - -# Add modeling imports here from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, @@ -66,8 +64,6 @@ CamembertForTokenClassification, CamembertModel, ) - -# Add modeling imports here from ..convbert.modeling_convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, @@ -158,6 +154,14 @@ MBartForSequenceClassification, MBartModel, ) + +# Add modeling imports here +# Add modeling imports here +from ..megatron.modeling_megatron import ( + MegatronForCausalLM, + MegatronForSequenceClassification, + MegatronModel, +) from ..mobilebert.modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -276,6 +280,7 @@ LxmertConfig, MarianConfig, MBartConfig, + MegatronConfig, MobileBertConfig, MPNetConfig, MT5Config, @@ -304,6 +309,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (MegatronConfig, MegatronModel), (Wav2Vec2Config, Wav2Vec2Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), @@ -424,6 +430,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Causal LM mapping + (MegatronConfig, MegatronForCausalLM), (CamembertConfig, CamembertForCausalLM), (XLMRobertaConfig, XLMRobertaForCausalLM), (RobertaConfig, RobertaForCausalLM), @@ -501,6 +508,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping + (MegatronConfig, MegatronForSequenceClassification), (ConvBertConfig, ConvBertForSequenceClassification), (LEDConfig, LEDForSequenceClassification), (DistilBertConfig, DistilBertForSequenceClassification), diff --git a/src/transformers/models/megatron/__init__.py b/src/transformers/models/megatron/__init__.py new file mode 100644 index 000000000000..ef3573283313 --- /dev/null +++ b/src/transformers/models/megatron/__init__.py @@ -0,0 +1,74 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available + + +_import_structure = { + "configuration_megatron": ["MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronConfig"], + "tokenization_megatron": ["MegatronTokenizer"], +} + +if is_tokenizers_available(): + _import_structure["tokenization_megatron_fast"] = ["MegatronTokenizerFast"] + +if is_torch_available(): + _import_structure["modeling_megatron"] = [ + "MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST", + "MegatronForCausalLM", + "MegatronForSequenceClassification", + "MegatronModel", + "MegatronPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig + from .tokenization_megatron import MegatronTokenizer + + if is_tokenizers_available(): + from .tokenization_megatron_fast import MegatronTokenizerFast + + if is_torch_available(): + from .modeling_megatron import ( + MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST, + MegatronForCausalLM, + MegatronForSequenceClassification, + MegatronModel, + MegatronPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/megatron/configuration_megatron.py b/src/transformers/models/megatron/configuration_megatron.py new file mode 100644 index 000000000000..92750e4c9996 --- /dev/null +++ b/src/transformers/models/megatron/configuration_megatron.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Megatron model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "megatron-11b": "https://huggingface.co/anton-l/megatron-11b/resolve/main/config.json", + # See all Megatron models at https://huggingface.co/models?filter=megatron +} + + +class MegatronConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.MegatronModel`. It is used to + instantiate an Megatron model according to the specified arguments, defining the model architecture. Instantiating + a configuration with the defaults will yield a similar configuration to that of the Megatron `megatron-11b + `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50265): + Vocabulary size of the Megatron model. Defines the number of different tokens that can be represented by + the :obj:`inputs_ids` passed when calling :class:`~transformers.MegatronModel` or + :class:`~transformers.TFMegatronModel`. + d_model (:obj:`int`, `optional`, defaults to 1024): + Dimensionality of the layers and the pooler layer. + decoder_layers (:obj:`int`, `optional`, defaults to 12): + Number of decoder layers. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (:obj:`int`, `optional`, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + + Example:: + + >>> from transformers import MegatronModel, MegatronConfig + + >>> # Initializing a Megatron megatron-11b style configuration + >>> configuration = MegatronConfig() + + >>> # Initializing a model from the megatron-11b style configuration + >>> model = MegatronModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "megatron" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=51200, + max_position_embeddings=1024, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=False, + is_decoder=True, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=0, + classifier_dropout=0.0, + scale_embedding=True, + gradient_checkpointing=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + is_decoder=is_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + @property + def num_attention_heads(self) -> int: + return self.decoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model diff --git a/src/transformers/models/megatron/convert_megatron_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/megatron/convert_megatron_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..5451148dddc1 --- /dev/null +++ b/src/transformers/models/megatron/convert_megatron_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert fairseq Megatron checkpoint.""" + + +import argparse +import glob +import os +from pathlib import Path + +import fairseq +import torch +from fairseq import tasks +from fairseq.data.encoders.gpt2_bpe import get_encoder +from fairseq.data.encoders.gpt2_bpe_utils import Encoder +from fairseq.dataclass.utils import convert_namespace_to_omegaconf, overwrite_args_by_name +from fairseq.models.transformer_lm import TransformerLanguageModel +from packaging import version + +from transformers import MegatronConfig, MegatronForCausalLM, MegatronTokenizer +from transformers.utils import logging + + +if version.parse(fairseq.__version__) < version.parse("0.9.0"): + raise Exception("requires fairseq >= 0.9.0") + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SAMPLE_TEXT = " Hello world! 💁 (╯°□°)╯︵ ┻━┻)" + + +def load_fairseq_checkpoint(checkpoint_path): + """ + Load a model-parallel checkpoint to CPU. + """ + # the checkpoint is partitioned into N_GPUS model-parallel shards + shard_paths = sorted(glob.glob(os.path.join(checkpoint_path, "model-model_part-*.pt"))) + + # load all checkpoints into memory + state_shards = [torch.load(p, map_location="cpu") for p in shard_paths] + + # the model config is copied across shards, so we can use any one of them + output_state = {"args": state_shards[0]["args"]} + print(output_state["args"]) + output_model = {} + + # group model parameters by key + model_shards = {k: [st["model"][k] for st in state_shards] for k in state_shards[0]["model"].keys()} + + for k, params in model_shards.items(): + if ".version" in k: + continue + elif "_float_tensor" in k: + # embed_positions are copied across shards + output_model[k] = params[0] + elif "out_proj.weight" in k or "fc2.weight" in k: + # row parallel weights + output_model[k] = torch.cat(params, dim=1) + elif "out_proj.bias" in k or "fc2.bias" in k: + # biases are copied across shards + output_model[k] = params[0] + elif "_proj" in k or "fc1" in k: + # column parallel weights + output_model[k] = torch.cat(params, dim=0) + elif "_norm" in k: + # norms are copied across shards + output_model[k] = params[0] + elif "embed_tokens" in k: + output_model[k] = torch.cat(params, dim=0) + else: + raise ValueError(f"We don't know how to merge the parameter {k}") + + output_state["model"] = output_model + return output_state + + +def load_fairseq_model(checkpoint_path): + logger.info("Converting fairseq's model_parallel checkpoint into a single model...") + state = load_fairseq_checkpoint(checkpoint_path) + cfg = convert_namespace_to_omegaconf(state["args"]) + # specify the path to `dict.txt` which is needed for embeddings initialization + overwrite_args_by_name(cfg, {"task": {"data": checkpoint_path}}) + task = tasks.setup_task(cfg.task) + logger.info("Initializing faiseq model from state dict...") + model = TransformerLanguageModel.build_model(cfg.model, task).eval() + model.load_state_dict(state["model"]) + model.upgrade_state_dict(model.state_dict()) + return cfg, model, task + + +def fairseq_tokenize(text, tokenizer, task): + # The raw tokenizer does not contain special tokens + # and the bpe indices are not the ones that model expects. + # Fairseq rearranges them according to frequency and saves + # the id mappings into dict.txt + tokens = " ".join([str(t) for t in tokenizer.encode(text)]) + bpe_sentence = " " + tokens + " " + # encode bpe tokens into model's input ids + tokens = task.source_dictionary.encode_line(bpe_sentence, append_eos=False) + return tokens.long() + + +@torch.no_grad() +def convert_fairseq_checkpoint(checkpoint_path, pytorch_dump_path): + """ + Convert tokenizer files. Copy/paste/tweak model's weights to our BART structure. + """ + Path(pytorch_dump_path).mkdir(exist_ok=True) + encoder_path = os.path.join(checkpoint_path, "encoder.json") + bpe_path = os.path.join(checkpoint_path, "vocab.bpe") + + logger.info("Loading the fairseq model...") + fs_cfg, fs_model, fs_task = load_fairseq_model(checkpoint_path) + fs_tokenizer: Encoder = get_encoder(encoder_path, bpe_path) + + fs_tokens = fairseq_tokenize(SAMPLE_TEXT, fs_tokenizer, fs_task) + fs_tokens = fs_tokens.unsqueeze(0) + + tokenizer = MegatronTokenizer.from_pretrained("megatron-11b") + hf_tokens = tokenizer.encode(SAMPLE_TEXT, return_tensors="pt") + assert torch.eq(fs_tokens, hf_tokens).all() + + state_dict = fs_model.state_dict() + state_dict.pop("decoder.version", None) + # megatron uses fixed sinusoidal embeddings, this is just a dummy tensor + state_dict.pop("decoder.embed_positions._float_tensor", None) + # equivalent to lm_head with tied embeddings + state_dict.pop("decoder.output_projection.weight", None) + + config = MegatronConfig( + vocab_size=51200, + max_position_embeddings=fs_cfg.model.max_target_positions, + decoder_layers=fs_cfg.model.decoder_layers, + decoder_ffn_dim=fs_cfg.model.decoder_ffn_embed_dim, + decoder_attention_heads=fs_cfg.model.decoder_attention_heads, + decoder_layerdrop=fs_cfg.model.decoder_layerdrop, + is_decoder=True, + activation_function=fs_cfg.model.activation_fn, + d_model=fs_cfg.model.decoder_embed_dim, + dropout=fs_cfg.model.dropout, + attention_dropout=fs_cfg.model.attention_dropout, + decoder_start_token_id=0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + scale_embedding=not fs_cfg.model.no_scale_embedding, + ) + logger.info("Initializing the 🤗 Megatron model") + model = MegatronForCausalLM(config).eval() + logger.info("Loading fairseq's state disct into Megatron") + missing, extra = model.model.load_state_dict(state_dict, strict=False) + unexpected_missing = [k for k in missing if k != "decoder.embed_positions.weight"] + assert unexpected_missing == [], f"Missing key(s) in state_dict: {unexpected_missing}" + assert extra == [], f"Extra keys in the original state_dict: {extra}" + + fairseq_outputs = fs_model.extract_features(fs_tokens, encoder_out=None)[0] + new_model_outputs = model.model(hf_tokens)[0] + + # Check results + assert fairseq_outputs.shape == new_model_outputs.shape + assert torch.eq(fairseq_outputs, new_model_outputs).all() + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--fairseq_path", type=str, help="Path to the unpacked fairseq Megatron model checkpoint.") + parser.add_argument("--pytorch_dump_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_fairseq_checkpoint(args.fairseq_path, args.pytorch_dump_path) diff --git a/src/transformers/models/megatron/modeling_megatron.py b/src/transformers/models/megatron/modeling_megatron.py new file mode 100755 index 000000000000..f7d7a228be3d --- /dev/null +++ b/src/transformers/models/megatron/modeling_megatron.py @@ -0,0 +1,1214 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Megatron model. """ + +import copy +import math +import random +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_megatron import MegatronConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MegatronConfig" +_TOKENIZER_FOR_DOC = "MegatronTokenizer" + +MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "megatron-11b", + # See all Megatron models at https://huggingface.co/models?filter=megatron +] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +# Copied from transformers.models.fsmt.modeling_fsmt.SinusoidalPositionalEmbedding +class SinusoidalPositionalEmbedding(nn.Embedding): + """ + This module produces sinusoidal positional embeddings of any length. + + We don't want to save the weight of this embedding since it's not trained (deterministic) and it can be huge. + + Padding symbols are ignored. + + These embeddings get automatically extended in forward if more positions is needed. + """ + + def __init__(self, num_positions, embedding_dim, padding_idx): + self.make_weight(num_positions, embedding_dim, padding_idx) + + def make_weight(self, num_positions, embedding_dim, padding_idx): + weight = self.get_embedding(num_positions, embedding_dim, padding_idx) + if not hasattr(self, "weight"): + # in ___init__ + super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight) + else: + # in forward + weight = weight.to(self.weight.device) + self.weight = nn.Parameter(weight) + self.weight.detach_() + self.weight.requires_grad = False + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + @staticmethod + def make_positions(tensor, padding_idx: int): + """ + Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx + + def forward( + self, + input, + incremental_state: Optional[Any] = None, + timestep: Optional[Tensor] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weight.size(0): + # expand embeddings if needed + self.make_weight(max_pos, self.embedding_dim, self.padding_idx) + positions = self.make_positions(input, self.padding_idx) + return super().forward(positions) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Megatron +class MegatronAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class MegatronDecoderLayer(nn.Module): + def __init__(self, config: MegatronConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = MegatronAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + # TODO: remove + # self.encoder_attn = MegatronAttention( + # self.embed_dim, + # config.decoder_attention_heads, + # dropout=config.attention_dropout, + # is_decoder=True, + # ) + # self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + # if encoder_hidden_states is not None: + # residual = hidden_states + # hidden_states = self.encoder_attn_layer_norm(hidden_states) + # + # # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + # cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + # hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + # hidden_states=hidden_states, + # key_value_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # layer_head_mask=layer_head_mask, + # past_key_value=cross_attn_past_key_value, + # output_attentions=output_attentions, + # ) + # hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + # hidden_states = residual + hidden_states + # + # # add cross-attn to positions 3,4 of present_key_value tuple + # present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->Megatron +class MegatronClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class MegatronPreTrainedModel(PreTrainedModel): + config_class = MegatronConfig + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + "model.decoder.embed_positions.weight", + ] + _keys_to_ignore_on_save = [ + "model.decoder.embed_positions.weight", + ] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, SinusoidalPositionalEmbedding): + pass + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +MEGATRON_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.MegatronConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +MEGATRON_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.MegatronTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +MEGATRON_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class MegatronDecoder(MegatronPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`MegatronDecoderLayer` + + Args: + config: MegatronConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: MegatronConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + embed_dim = self.embed_tokens.embedding_dim + + self.embed_positions = SinusoidalPositionalEmbedding( + config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx + ) + self.layers = nn.ModuleList([MegatronDecoderLayer(config) for _ in range(config.decoder_layers)]) + # TODO: remove + # self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.init_weights() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.MegatronTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_ids) + + hidden_states = inputs_embeds + positions + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare Megatron Model outputting raw hidden-states without any specific head on top.", + MEGATRON_START_DOCSTRING, +) +class MegatronModel(MegatronPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.decoder = MegatronDecoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(MEGATRON_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="megatron-11b", + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=head_mask, + encoder_head_mask=None, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return decoder_outputs + + +@add_start_docstrings( + """ + Megatron model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + MEGATRON_START_DOCSTRING, +) +class MegatronForSequenceClassification(MegatronPreTrainedModel): + def __init__(self, config: MegatronConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = MegatronModel(config) + self.classification_head = MegatronClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + @add_start_docstrings_to_model_forward(MEGATRON_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="anton-l/megatron-11b", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id) + + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.decoder_hidden_states, + attentions=outputs.decoder_attentions, + ) + + +@add_start_docstrings( + """ + The Megatron model with a language modeling head on top (linear layer with weights tied to the input embeddings). + """, + MEGATRON_START_DOCSTRING, +) +class MegatronForCausalLM(MegatronPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = MegatronModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.MegatronTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import MegatronTokenizer, MegatronForCausalLM + + >>> tokenizer = MegatronTokenizer.from_pretrained('megatron-11b') + >>> model = MegatronForCausalLM.from_pretrained('megatron-11b', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/megatron/tokenization_megatron.py b/src/transformers/models/megatron/tokenization_megatron.py new file mode 100644 index 000000000000..53c21a285d55 --- /dev/null +++ b/src/transformers/models/megatron/tokenization_megatron.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Megatron.""" +from ...utils import logging +from ..bart.tokenization_bart import BartTokenizer + + +logger = logging.get_logger(__name__) + +# vocab and merges same as roberta +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "megatron-11b": "https://huggingface.co/roberta-large/resolve/main/vocab.json", + }, + "merges_file": { + "megatron-11b": "https://huggingface.co/roberta-large/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "megatron-11b": 1024, +} + + +class MegatronTokenizer(BartTokenizer): + """ + Construct a Megatron tokenizer. + + :class:`~transformers.MegatronTokenizer` is identical to :class:`~transformers.BartTokenizer` and runs end-to-end + tokenization: punctuation splitting and wordpiece. + + Refer to superclass :class:`~transformers.BartTokenizer` for usage examples and documentation concerning + parameters. + """ + + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES diff --git a/src/transformers/models/megatron/tokenization_megatron_fast.py b/src/transformers/models/megatron/tokenization_megatron_fast.py new file mode 100644 index 000000000000..ec039a6ddb43 --- /dev/null +++ b/src/transformers/models/megatron/tokenization_megatron_fast.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Megatron.""" +from ...utils import logging +from ..bart.tokenization_bart_fast import BartTokenizerFast +from .tokenization_megatron import MegatronTokenizer + + +logger = logging.get_logger(__name__) + +# vocab and merges same as roberta +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "megatron-11b": "https://huggingface.co/roberta-large/resolve/main/vocab.json", + }, + "merges_file": { + "megatron-11b": "https://huggingface.co/roberta-large/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "megatron-11b": 1024, +} + + +class MegatronTokenizerFast(BartTokenizerFast): + r""" + Construct a "fast" Megatron tokenizer (backed by HuggingFace's `tokenizers` library). + + :class:`~transformers.MegatronTokenizerFast` is identical to :class:`~transformers.BartTokenizerFast` and runs + end-to-end tokenization: punctuation splitting and wordpiece. + + Refer to superclass :class:`~transformers.BartTokenizerFast` for usage examples and documentation concerning + parameters. + """ + + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = MegatronTokenizer diff --git a/tests/test_modeling_megatron.py b/tests/test_modeling_megatron.py new file mode 100644 index 000000000000..b941c58e3bcd --- /dev/null +++ b/tests/test_modeling_megatron.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Megatron model. """ + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + MegatronConfig, + MegatronForCausalLM, + MegatronForSequenceClassification, + MegatronModel, + MegatronTokenizer, + ) + from transformers.models.megatron.modeling_megatron import MegatronDecoder, MegatronEncoder + + +def prepare_megatron_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +@require_torch +class MegatronModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( + 3, + ) + input_ids[:, -1] = self.eos_token_id # Eos Token + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = MegatronConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + ) + inputs_dict = prepare_megatron_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = MegatronModel(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = MegatronModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = MegatronEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ + 0 + ] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = MegatronDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +@require_torch +class MegatronModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + MegatronModel, + MegatronForSequenceClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = MegatronModelTester(self) + self.config_tester = ConfigTester(self, config_class=MegatronConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + # MegatronForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (MegatronModel,): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = MegatronForCausalLM(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(input_ids, attention_mask=attention_mask) + model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + +def assert_tensors_close(a, b, atol=1e-12, prefix=""): + """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item() + if a.numel() > 100: + msg = f"tensor values are {pct_different:.1%} percent different." + else: + msg = f"{a} != {b}" + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +TOLERANCE = 1e-4 + + +@require_torch +@require_sentencepiece +@require_tokenizers +@slow +class MegatronModelIntegrationTests(unittest.TestCase): + @cached_property + def default_tokenizer(self): + return MegatronTokenizer.from_pretrained("megatron-11b") + + def test_inference_no_head(self): + model = MegatronModel.from_pretrained("megatron-11b").to(torch_device) + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]]) + inputs_dict = prepare_megatron_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, 1024)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_inference_head(self): + model = MegatronForCausalLM.from_pretrained("megatron-11b").to(torch_device) + + # change to intended input + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + inputs_dict = prepare_megatron_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, model.config.vocab_size)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_seq_to_seq_generation(self): + hf = MegatronForCausalLM.from_pretrained("megatron-11b").to(torch_device) + tok = MegatronTokenizer.from_pretrained("megatron-11b") + + batch_input = [ + # string 1, + # string 2, + # string 3, + # string 4, + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tok.batch_encode_plus( + batch_input, + max_length=512, + padding="max_length", + truncation_strategy="only_first", + truncation=True, + return_tensors="pt", + ) + + hypotheses_batch = hf.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=2, + ) + + EXPECTED = [ + # here expected 1, + # here expected 2, + # here expected 3, + # here expected 4, + ] + + generated = tok.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == EXPECTED + + +class MegatronStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = MegatronConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = MegatronDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = MegatronDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class MegatronStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (MegatronDecoder, MegatronForCausalLM) if is_torch_available() else () + all_generative_model_classes = (MegatronForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = MegatronStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=MegatronConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/utils/check_repo.py b/utils/check_repo.py index f9c25dabc194..288cf4e89afd 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -30,6 +30,9 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = [ # models to ignore for not tested + "MegatronEncoder", # Building part of bigger (tested) model. + "MegatronDecoder", # Building part of bigger (tested) model. + "MegatronDecoderWrapper", # Building part of bigger (tested) model. "LEDEncoder", # Building part of bigger (tested) model. "LEDDecoder", # Building part of bigger (tested) model. "BartDecoderWrapper", # Building part of bigger (tested) model. @@ -75,6 +78,9 @@ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = [ # models to ignore for model xxx mapping + "MegatronEncoder", + "MegatronDecoder", + "MegatronDecoderWrapper", "LEDEncoder", "LEDDecoder", "BartDecoder",